From 3d35a602b5cd471101bd812a9da7db61a319c2c1 Mon Sep 17 00:00:00 2001 From: Will McGugan Date: Wed, 16 Nov 2022 15:15:42 +0000 Subject: [PATCH] added node expanding --- src/textual/widgets/_tree.py | 56 +++++++++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 14 deletions(-) diff --git a/src/textual/widgets/_tree.py b/src/textual/widgets/_tree.py index 7ce6b957b..2d42a7d3c 100644 --- a/src/textual/widgets/_tree.py +++ b/src/textual/widgets/_tree.py @@ -15,7 +15,7 @@ from ..geometry import clamp, Region, Size from .._loop import loop_last from .._cache import LRUCache from ..message import Message -from ..reactive import reactive +from ..reactive import reactive, var from .._segment_tools import line_crop, line_pad from .._types import MessageTarget from .._typing import TypeAlias @@ -47,7 +47,7 @@ class _TreeLine: class TreeNode(Generic[TreeDataType]): def __init__( self, - tree: Tree, + tree: Tree[TreeDataType], parent: TreeNode[TreeDataType] | None, id: NodeID, label: Text, @@ -60,7 +60,7 @@ class TreeNode(Generic[TreeDataType]): self.id = id self.label = label self.data: TreeDataType = data - self.expanded = expanded + self._expanded = expanded self.children: list[TreeNode] = [] self._hover = False @@ -70,6 +70,15 @@ class TreeNode(Generic[TreeDataType]): yield self.label.plain yield self.data + @property + def expanded(self) -> bool: + return self._expanded + + @expanded.setter + def expanded(self, expanded: bool) -> None: + self._expanded = expanded + self._tree.invalidate() + @property def last(self) -> bool: """Check if this is the last child. @@ -101,7 +110,7 @@ class TreeNode(Generic[TreeDataType]): else: text_label = label node = self._tree._add_node(self, text_label, data) - node.expanded = expanded + node._expanded = expanded self.children.append(node) self._tree.invalidate() return node @@ -165,9 +174,10 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): "tree--highlight-line", } - hover_line: reactive[int] = reactive(-1, repaint=False) - cursor_line: reactive[int] = reactive(0, repaint=False) - guide_depth: reactive[int] = reactive(4, repaint=False, init=False) + hover_line = var(-1) + cursor_line = var(-1) + guide_depth = var(4, init=False) + auto_expand = var(True) LINES: dict[str, tuple[str, str, str, str]] = { "default": ( @@ -216,7 +226,7 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): self._nodes: dict[NodeID, TreeNode[TreeDataType]] = {} self._current_id = 0 self.root = self._add_node(None, text_label, data) - self.root.expanded = True + self.root._expanded = True self._line_cache: LRUCache[LineCacheKey, list[Segment]] = LRUCache(1024) self._tree_lines_cached: list[_TreeLine] | None = None @@ -228,6 +238,17 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): self._updates += 1 return node + def render_label(self, node: TreeNode[TreeDataType]) -> Text: + """Render a label for the given node. Override this to modify how labels are rendered. + + Args: + node (TreeNode[TreeDataType]): A tree node. + + Returns: + Text: A Rich Text object containing the label. + """ + return node.label + def clear(self) -> None: """Clear all nodes under root.""" self._tree_lines_cached = None @@ -349,7 +370,7 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): def add_node(path: list[TreeNode], node: TreeNode, last: bool) -> None: child_path = [*path, node] add_line(_TreeLine(child_path, last)) - if node.expanded: + if node._expanded: for last, child in loop_last(node.children): add_node(child_path, child, last=last) @@ -455,7 +476,8 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): guides.append(terminator, style=guide_style) else: guides.append(cross, style=guide_style) - label = line.path[-1].label.copy() + + label = self.render_label(line.path[-1]).copy() label.stylize(self.get_component_rich_style("tree--label", partial=True)) if self.hover_line == y: label.stylize( @@ -478,14 +500,18 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): self.invalidate() self._line_cache.grow(self.size.height * 2) - def _on_click(self, event: events.Click) -> None: + async def _on_click(self, event: events.Click) -> None: meta = event.style.meta if "line" in meta: - self.cursor_line = meta["line"] + cursor_line = meta["line"] + if self.cursor_line == cursor_line: + await self.action("select_cursor") + else: + self.cursor_line = cursor_line def action_cursor_up(self) -> None: if self.cursor_line == -1: - self.cursor_line = len(self._tree_lines) + self.cursor_line = len(self._tree_lines) - 1 else: self.cursor_line -= 1 @@ -498,5 +524,7 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True): except IndexError: pass else: + node = line.path[-1] + if self.auto_expand: + node.expanded = not node.expanded self.emit_no_wait(self.NodeSelected(self, line.path[-1])) - self.app.bell()