diff --git a/CHANGELOG.md b/CHANGELOG.md index bf55fa898..7e3a58db0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Added `OptionList.add_options` https://github.com/Textualize/textual/pull/2508 +### Added + +- Added `TreeNode.is_root` https://github.com/Textualize/textual/pull/2510 +- Added `TreeNode.remove_children` https://github.com/Textualize/textual/pull/2510 +- Added `TreeNode.remove` https://github.com/Textualize/textual/pull/2510 + ## [0.23.0] - 2023-05-03 ### Fixed diff --git a/src/textual/widgets/_tree.py b/src/textual/widgets/_tree.py index 85255cf1e..0f3998914 100644 --- a/src/textual/widgets/_tree.py +++ b/src/textual/widgets/_tree.py @@ -184,6 +184,11 @@ class TreeNode(Generic[TreeDataType]): self._parent._children and self._parent._children[-1] == self, ) + @property + def is_root(self) -> bool: + """Is this node the root of the tree?""" + return self == self._tree.root + @property def allow_expand(self) -> bool: """Is this node allowed to expand?""" @@ -344,6 +349,47 @@ class TreeNode(Generic[TreeDataType]): node = self.add(label, data, expand=False, allow_expand=False) return node + class RemoveRootError(Exception): + """Exception raised when trying to remove a tree's root node.""" + + def _remove_children(self) -> None: + """Remove child nodes of this node. + + Note: + This is the internal support method for `remove_children`. Call + `remove_children` to ensure the tree gets refreshed. + """ + for child in reversed(self._children): + child._remove() + + def _remove(self) -> None: + """Remove the current node and all its children. + + Note: + This is the internal support method for `remove`. Call `remove` + to ensure the tree gets refreshed. + """ + self._remove_children() + assert self._parent is not None + del self._parent._children[self._parent._children.index(self)] + del self._tree._tree_nodes[self.id] + + def remove(self) -> None: + """Remove this node from the tree. + + Raises: + TreeNode.RemoveRootError: If there is an attempt to remove the root. + """ + if self.is_root: + raise self.RemoveRootError("Attempt to remove the root node of a Tree.") + self._remove() + self._tree._invalidate() + + def remove_children(self) -> None: + """Remove any child nodes of this node.""" + self._remove_children() + self._tree._invalidate() + class Tree(Generic[TreeDataType], ScrollView, can_focus=True): """A widget for displaying and navigating data in a tree.""" @@ -814,6 +860,8 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): self._cursor_node = node if previous_node != node: self.post_message(self.NodeHighlighted(node)) + else: + self._cursor_node = None def watch_guide_depth(self, guide_depth: int) -> None: self._invalidate() diff --git a/tests/tree/test_tree_clearing.py b/tests/tree/test_tree_clearing.py index 87543c4c8..bd868ee6d 100644 --- a/tests/tree/test_tree_clearing.py +++ b/tests/tree/test_tree_clearing.py @@ -1,7 +1,10 @@ from __future__ import annotations +import pytest + from textual.app import App, ComposeResult from textual.widgets import Tree +from textual.widgets.tree import TreeNode class VerseBody: @@ -71,3 +74,37 @@ async def test_tree_reset_with_label_and_data() -> None: assert len(tree.root.children) == 0 assert str(tree.root.label) == "Jiangyin" assert isinstance(tree.root.data, VersePlanet) + + +async def test_remove_node(): + async with TreeClearApp().run_test() as pilot: + tree = pilot.app.query_one(VerseTree) + assert len(tree.root.children) == 2 + tree.root.children[0].remove() + assert len(tree.root.children) == 1 + + +async def test_remove_node_children(): + async with TreeClearApp().run_test() as pilot: + tree = pilot.app.query_one(VerseTree) + assert len(tree.root.children) == 2 + assert len(tree.root.children[0].children) == 2 + tree.root.children[0].remove_children() + assert len(tree.root.children) == 2 + assert len(tree.root.children[0].children) == 0 + + +async def test_tree_remove_children_of_root(): + """Test removing the children of the root.""" + async with TreeClearApp().run_test() as pilot: + tree = pilot.app.query_one(VerseTree) + assert len(tree.root.children) > 1 + tree.root.remove_children() + assert len(tree.root.children) == 0 + + +async def test_attempt_to_remove_root(): + """Attempting to remove the root should be an error.""" + async with TreeClearApp().run_test() as pilot: + with pytest.raises(TreeNode.RemoveRootError): + pilot.app.query_one(VerseTree).root.remove()