diff --git a/src/textual/widgets/_tree.py b/src/textual/widgets/_tree.py index f6c3be5b0..0f1d99af7 100644 --- a/src/textual/widgets/_tree.py +++ b/src/textual/widgets/_tree.py @@ -513,6 +513,9 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): else: return line.node + class UnknownID(Exception): + """Exception raised when referring to an unknown `TreeNode` ID.""" + def get_node_by_id(self, node_id: NodeID) -> TreeNode[TreeDataType]: """Get a tree node by its ID. @@ -521,8 +524,14 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): Returns: TreeNode[TreeDataType]: The node associated with that ID. + + Raises: + Tree.UnknownID: Raised if the `TreeNode` ID is unknown. """ - return self._nodes[node_id] + try: + return self._nodes[node_id] + except KeyError: + raise self.UnknownID(f"Unknown TreeNode ID: {node_id}") from None def validate_cursor_line(self, value: int) -> int: """Prevent cursor line from going outside of range.""" diff --git a/tests/tree/test_tree_get_node_by_id.py b/tests/tree/test_tree_get_node_by_id.py index fbe01a150..d9eca97af 100644 --- a/tests/tree/test_tree_get_node_by_id.py +++ b/tests/tree/test_tree_get_node_by_id.py @@ -12,5 +12,5 @@ def test_get_tree_node_by_id() -> None: assert tree.get_node_by_id(tree.root.id).id == tree.root.id assert tree.get_node_by_id(child.id).id == child.id assert tree.get_node_by_id(grandchild.id).id == grandchild.id - with pytest.raises(KeyError): + with pytest.raises(Tree.UnknownID): tree.get_node_by_id(cast(NodeID, grandchild.id + 1000))