added node expanding

This commit is contained in:
Will McGugan
2022-11-16 15:15:42 +00:00
parent 22f37871d9
commit 3d35a602b5

View File

@@ -15,7 +15,7 @@ from ..geometry import clamp, Region, Size
from .._loop import loop_last
from .._cache import LRUCache
from ..message import Message
from ..reactive import reactive
from ..reactive import reactive, var
from .._segment_tools import line_crop, line_pad
from .._types import MessageTarget
from .._typing import TypeAlias
@@ -47,7 +47,7 @@ class _TreeLine:
class TreeNode(Generic[TreeDataType]):
def __init__(
self,
tree: Tree,
tree: Tree[TreeDataType],
parent: TreeNode[TreeDataType] | None,
id: NodeID,
label: Text,
@@ -60,7 +60,7 @@ class TreeNode(Generic[TreeDataType]):
self.id = id
self.label = label
self.data: TreeDataType = data
self.expanded = expanded
self._expanded = expanded
self.children: list[TreeNode] = []
self._hover = False
@@ -70,6 +70,15 @@ class TreeNode(Generic[TreeDataType]):
yield self.label.plain
yield self.data
@property
def expanded(self) -> bool:
return self._expanded
@expanded.setter
def expanded(self, expanded: bool) -> None:
self._expanded = expanded
self._tree.invalidate()
@property
def last(self) -> bool:
"""Check if this is the last child.
@@ -101,7 +110,7 @@ class TreeNode(Generic[TreeDataType]):
else:
text_label = label
node = self._tree._add_node(self, text_label, data)
node.expanded = expanded
node._expanded = expanded
self.children.append(node)
self._tree.invalidate()
return node
@@ -165,9 +174,10 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
"tree--highlight-line",
}
hover_line: reactive[int] = reactive(-1, repaint=False)
cursor_line: reactive[int] = reactive(0, repaint=False)
guide_depth: reactive[int] = reactive(4, repaint=False, init=False)
hover_line = var(-1)
cursor_line = var(-1)
guide_depth = var(4, init=False)
auto_expand = var(True)
LINES: dict[str, tuple[str, str, str, str]] = {
"default": (
@@ -216,7 +226,7 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
self._nodes: dict[NodeID, TreeNode[TreeDataType]] = {}
self._current_id = 0
self.root = self._add_node(None, text_label, data)
self.root.expanded = True
self.root._expanded = True
self._line_cache: LRUCache[LineCacheKey, list[Segment]] = LRUCache(1024)
self._tree_lines_cached: list[_TreeLine] | None = None
@@ -228,6 +238,17 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
self._updates += 1
return node
def render_label(self, node: TreeNode[TreeDataType]) -> Text:
"""Render a label for the given node. Override this to modify how labels are rendered.
Args:
node (TreeNode[TreeDataType]): A tree node.
Returns:
Text: A Rich Text object containing the label.
"""
return node.label
def clear(self) -> None:
"""Clear all nodes under root."""
self._tree_lines_cached = None
@@ -349,7 +370,7 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
def add_node(path: list[TreeNode], node: TreeNode, last: bool) -> None:
child_path = [*path, node]
add_line(_TreeLine(child_path, last))
if node.expanded:
if node._expanded:
for last, child in loop_last(node.children):
add_node(child_path, child, last=last)
@@ -455,7 +476,8 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
guides.append(terminator, style=guide_style)
else:
guides.append(cross, style=guide_style)
label = line.path[-1].label.copy()
label = self.render_label(line.path[-1]).copy()
label.stylize(self.get_component_rich_style("tree--label", partial=True))
if self.hover_line == y:
label.stylize(
@@ -478,14 +500,18 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
self.invalidate()
self._line_cache.grow(self.size.height * 2)
def _on_click(self, event: events.Click) -> None:
async def _on_click(self, event: events.Click) -> None:
meta = event.style.meta
if "line" in meta:
self.cursor_line = meta["line"]
cursor_line = meta["line"]
if self.cursor_line == cursor_line:
await self.action("select_cursor")
else:
self.cursor_line = cursor_line
def action_cursor_up(self) -> None:
if self.cursor_line == -1:
self.cursor_line = len(self._tree_lines)
self.cursor_line = len(self._tree_lines) - 1
else:
self.cursor_line -= 1
@@ -498,5 +524,7 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
except IndexError:
pass
else:
node = line.path[-1]
if self.auto_expand:
node.expanded = not node.expanded
self.emit_no_wait(self.NodeSelected(self, line.path[-1]))
self.app.bell()