diff --git a/CHANGELOG.md b/CHANGELOG.md index 90bb8ddc2..b94054b1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Added `TreeNode.parent` -- a read-only property for accessing a node's parent https://github.com/Textualize/textual/issues/1397 - Added public `TreeNode` label access via `TreeNode.label` https://github.com/Textualize/textual/issues/1396 - Added read-only public access to the children of a `TreeNode` via `TreeNode.children` https://github.com/Textualize/textual/issues/1398 +- Added `Tree.get_node_by_id` to allow getting a node by its ID ### Changed diff --git a/src/textual/widgets/_tree.py b/src/textual/widgets/_tree.py index bc431e143..f6c3be5b0 100644 --- a/src/textual/widgets/_tree.py +++ b/src/textual/widgets/_tree.py @@ -513,6 +513,17 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): else: return line.node + def get_node_by_id(self, node_id: NodeID) -> TreeNode[TreeDataType]: + """Get a tree node by its ID. + + Args: + node_id (NodeID): The ID of the node to get. + + Returns: + TreeNode[TreeDataType]: The node associated with that ID. + """ + return self._nodes[node_id] + def validate_cursor_line(self, value: int) -> int: """Prevent cursor line from going outside of range.""" return clamp(value, 0, len(self._tree_lines) - 1) diff --git a/tests/tree/test_tree_get_node_by_id.py b/tests/tree/test_tree_get_node_by_id.py new file mode 100644 index 000000000..fbe01a150 --- /dev/null +++ b/tests/tree/test_tree_get_node_by_id.py @@ -0,0 +1,16 @@ +import pytest +from typing import cast +from textual.widgets import Tree +from textual.widgets._tree import NodeID + + +def test_get_tree_node_by_id() -> None: + """It should be possible to get a TreeNode by its ID.""" + tree = Tree[None]("Anakin") + child = tree.root.add("Leia") + grandchild = child.add("Ben") + assert tree.get_node_by_id(tree.root.id).id == tree.root.id + assert tree.get_node_by_id(child.id).id == child.id + assert tree.get_node_by_id(grandchild.id).id == grandchild.id + with pytest.raises(KeyError): + tree.get_node_by_id(cast(NodeID, grandchild.id + 1000))