diff --git a/CHANGELOG.md b/CHANGELOG.md index 21c79a734..dfd75338d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Added `TreeNode.parent` -- a read-only property for accessing a node's parent https://github.com/Textualize/textual/issues/1397 - Added public `TreeNode` label access via `TreeNode.label` https://github.com/Textualize/textual/issues/1396 - Added read-only public access to the children of a `TreeNode` via `TreeNode.children` https://github.com/Textualize/textual/issues/1398 +- Added a `Tree.NodeHighlighted` message, giving a `on_tree_node_highlighted` event handler https://github.com/Textualize/textual/issues/1400 ### Changed diff --git a/src/textual/widgets/_tree.py b/src/textual/widgets/_tree.py index bc431e143..a41928941 100644 --- a/src/textual/widgets/_tree.py +++ b/src/textual/widgets/_tree.py @@ -330,44 +330,46 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): ), } - class NodeSelected(Generic[EventTreeDataType], Message, bubble=True): + class NodeMessage(Generic[EventTreeDataType], Message, bubble=True): + """Base class for events sent when something happens with a node. + + Attributes: + TreeNode[EventTreeDataType]: The node involved in the event. + """ + + def __init__( + self, sender: MessageTarget, node: TreeNode[EventTreeDataType] + ) -> None: + self.node = node + super().__init__(sender) + + class NodeSelected(NodeMessage[EventTreeDataType]): """Event sent when a node is selected. Attributes: TreeNode[EventTreeDataType]: The node that was selected. """ - def __init__( - self, sender: MessageTarget, node: TreeNode[EventTreeDataType] - ) -> None: - self.node = node - super().__init__(sender) - - class NodeExpanded(Generic[EventTreeDataType], Message, bubble=True): + class NodeExpanded(NodeMessage[EventTreeDataType]): """Event sent when a node is expanded. Attributes: TreeNode[EventTreeDataType]: The node that was expanded. """ - def __init__( - self, sender: MessageTarget, node: TreeNode[EventTreeDataType] - ) -> None: - self.node = node - super().__init__(sender) - - class NodeCollapsed(Generic[EventTreeDataType], Message, bubble=True): + class NodeCollapsed(NodeMessage[EventTreeDataType]): """Event sent when a node is collapsed. Attributes: TreeNode[EventTreeDataType]: The node that was collapsed. """ - def __init__( - self, sender: MessageTarget, node: TreeNode[EventTreeDataType] - ) -> None: - self.node = node - super().__init__(sender) + class NodeHighlighted(NodeMessage[EventTreeDataType]): + """Event sent when a node is highlighted. + + Attributes: + TreeNode[EventTreeDataType]: The node that was collapsed. + """ def __init__( self, @@ -577,6 +579,8 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): self._refresh_node(node) node._selected = True self._cursor_node = node + if previous_node != node: + self.post_message_no_wait(self.NodeHighlighted(self, node)) def watch_guide_depth(self, guide_depth: int) -> None: self._invalidate() diff --git a/tests/tree/test_tree_messages.py b/tests/tree/test_tree_messages.py new file mode 100644 index 000000000..f271d4e42 --- /dev/null +++ b/tests/tree/test_tree_messages.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Any +from textual.app import App, ComposeResult +from textual.widgets import Tree +from textual.message import Message + + +class MyTree(Tree[None]): + pass + + +class TreeApp(App[None]): + """Test tree app.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.messages: list[str] = [] + + def compose(self) -> ComposeResult: + """Compose the child widgets.""" + yield MyTree("Root") + + def on_mount(self) -> None: + self.query_one(MyTree).root.add("Child") + self.query_one(MyTree).focus() + + def record(self, event: Message) -> None: + self.messages.append(event.__class__.__name__) + + def on_tree_node_selected(self, event: Tree.NodeSelected[None]) -> None: + self.record(event) + + def on_tree_node_expanded(self, event: Tree.NodeExpanded[None]) -> None: + self.record(event) + + def on_tree_node_collapsed(self, event: Tree.NodeCollapsed[None]) -> None: + self.record(event) + + def on_tree_node_highlighted(self, event: Tree.NodeHighlighted[None]) -> None: + self.record(event) + + +async def test_tree_node_selected_message() -> None: + """Selecting a node should result in a selected message being emitted.""" + async with TreeApp().run_test() as pilot: + await pilot.press("enter") + await pilot.pause(2 / 100) + assert pilot.app.messages == ["NodeExpanded", "NodeSelected"] + + +async def test_tree_node_expanded_message() -> None: + """Expanding a node should result in an expanded message being emitted.""" + async with TreeApp().run_test() as pilot: + await pilot.press("enter") + await pilot.pause(2 / 100) + assert pilot.app.messages == ["NodeExpanded", "NodeSelected"] + + +async def test_tree_node_collapsed_message() -> None: + """Collapsing a node should result in a collapsed message being emitted.""" + async with TreeApp().run_test() as pilot: + await pilot.press("enter", "enter") + await pilot.pause(2 / 100) + assert pilot.app.messages == [ + "NodeExpanded", + "NodeSelected", + "NodeCollapsed", + "NodeSelected", + ] + + +async def test_tree_node_highlighted_message() -> None: + """Highlighting a node should result in a highlighted message being emitted.""" + async with TreeApp().run_test() as pilot: + await pilot.press("enter", "down") + await pilot.pause(2 / 100) + assert pilot.app.messages == ["NodeExpanded", "NodeSelected", "NodeHighlighted"]