diff --git a/src/textual/widgets/_tree.py b/src/textual/widgets/_tree.py index 2f3d073cf..8f2304ae8 100644 --- a/src/textual/widgets/_tree.py +++ b/src/textual/widgets/_tree.py @@ -560,12 +560,17 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): label = self.render_label(node, NULL_STYLE, NULL_STYLE) return label.cell_len - def clear(self) -> None: - """Clear all nodes under root.""" + def clear(self, label: TextType | None = None) -> None: + """Clear all nodes under root. + + Args: + label: An optional new label for the root node. If not provided + the current root node's label will be used. + """ self._line_cache.clear() self._tree_lines_cached = None self._current_id = 0 - root_label = self.root._label + root_label = self.root._label if label is None else label root_data = self.root.data self.root = TreeNode( self, diff --git a/tests/tree/test_tree_clearing.py b/tests/tree/test_tree_clearing.py new file mode 100644 index 000000000..56a4515d6 --- /dev/null +++ b/tests/tree/test_tree_clearing.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from textual.app import App, ComposeResult +from textual.widgets import Tree + + +class VerseBody: + pass + + +class VerseStar(VerseBody): + pass + + +class VersePlanet(VerseBody): + pass + + +class VerseMoon(VerseBody): + pass + + +class TestTree(Tree[VerseBody]): + pass + + +class TreeClearApp(App[None]): + """Tree clearing test app.""" + + def compose(self) -> ComposeResult: + yield TestTree("White Sun", data=VerseStar()) + + def on_mount(self) -> None: + tree = self.query_one(TestTree) + node = tree.root.add("Londinium", VersePlanet()) + node.add_leaf("Balkerne", VerseMoon()) + node.add_leaf("Colchester", VerseMoon()) + node = tree.root.add("Sihnon", VersePlanet()) + node.add_leaf("Airen", VerseMoon()) + node.add_leaf("Xiaojie", VerseMoon()) + + +async def test_tree_simple_clear() -> None: + """Clearing a tree should keep the old label and data.""" + async with TreeClearApp().run_test() as pilot: + tree = pilot.app.query_one(TestTree) + assert len(tree.root.children) > 1 + pilot.app.query_one(TestTree).clear() + assert len(tree.root.children) == 0 + assert str(tree.root.label) == "White Sun" + assert isinstance(tree.root.data, VerseStar) + + +async def test_tree_new_label_clear() -> None: + """Clearing a tree with a new label should use the new label and keep the old data.""" + async with TreeClearApp().run_test() as pilot: + tree = pilot.app.query_one(TestTree) + assert len(tree.root.children) > 1 + pilot.app.query_one(TestTree).clear("Jiangyin") + assert len(tree.root.children) == 0 + assert str(tree.root.label) == "Jiangyin" + assert isinstance(tree.root.data, VerseStar)