added API to nodes

This commit is contained in:
Will McGugan
2022-11-18 17:09:47 +00:00
parent fed9e7a939
commit dfd7b6c8d9
2 changed files with 90 additions and 37 deletions

View File

@@ -14,7 +14,11 @@ print(data)
class TreeApp(App):
BINDINGS = [("a", "add", "Add node")]
BINDINGS = [
("a", "add", "Add node"),
("c", "clear", "Clear"),
("t", "toggle_root", "Toggle root"),
]
def compose(self) -> ComposeResult:
yield Header()
@@ -24,7 +28,17 @@ class TreeApp(App):
def action_add(self) -> None:
tree = self.query_one(Tree)
tree.add_json(data)
json_node = tree.root.add("JSON")
tree.root.expand()
tree.add_json(json_node, data)
def action_clear(self) -> None:
tree = self.query_one(Tree)
tree.clear()
def action_toggle_root(self) -> None:
tree = self.query_one(Tree)
tree.show_root = not tree.show_root
if __name__ == "__main__":

View File

@@ -39,7 +39,7 @@ class _TreeLine:
return self.path[-1]
def get_line_width(self, guide_depth: int) -> int:
return (len(self.path)) + self.path[-1].label.cell_len - guide_depth
return (len(self.path)) + self.path[-1]._label.cell_len - guide_depth
@rich.repr.auto
@@ -58,7 +58,7 @@ class TreeNode(Generic[TreeDataType]):
self._tree = tree
self._parent = parent
self.id = id
self.label = label
self._label = label
self.data: TreeDataType = data if data is not None else tree._data_factory()
self._expanded = expanded
self.children: list[TreeNode] = []
@@ -68,22 +68,32 @@ class TreeNode(Generic[TreeDataType]):
self._allow_expand = allow_expand
def __rich_repr__(self) -> rich.repr.Result:
yield self.label.plain
yield self._label.plain
yield self.data
def _reset(self) -> None:
self._hover = False
self._selected = False
@property
def expanded(self) -> bool:
return self._expanded
@expanded.setter
def expanded(self, expanded: bool) -> None:
self._expanded = expanded
def expand(self) -> None:
"""Expand a node (show its children)."""
self._expanded = True
self._tree.invalidate()
def collapse(self) -> None:
"""Collapse the node (hide children)."""
self._expanded = False
self._tree.invalidate()
def toggle(self) -> None:
self._expanded = not self._expanded
self._tree.invalidate()
@property
def is_expanded(self) -> bool:
"""Check if the node is expanded."""
return self._expanded
@property
def last(self) -> bool:
"""Check if this is the last child.
@@ -102,7 +112,7 @@ class TreeNode(Generic[TreeDataType]):
label: TextType,
data: TreeDataType | None = None,
*,
expanded: bool = True,
expand: bool = False,
allow_expand: bool = True,
) -> TreeNode[TreeDataType]:
"""Add a node to the sub-tree.
@@ -120,7 +130,7 @@ class TreeNode(Generic[TreeDataType]):
else:
text_label = label
node = self._tree._add_node(self, text_label, data)
node._expanded = expanded
node._expanded = expand
node._allow_expand = allow_expand
self.children.append(node)
self._tree.invalidate()
@@ -237,7 +247,7 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
self._nodes: dict[NodeID, TreeNode[TreeDataType]] = {}
self._current_id = 0
self.root = self._add_node(None, text_label, data)
self.root._expanded = True
self._line_cache: LRUCache[LineCacheKey, list[Segment]] = LRUCache(1024)
self._tree_lines_cached: list[_TreeLine] | None = None
@@ -263,9 +273,10 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
parent: TreeNode[TreeDataType] | None,
label: Text,
data: TreeDataType | None,
expand: bool = False,
) -> TreeNode[TreeDataType]:
node_data = data if data is not None else self._data_factory()
node = TreeNode(self, parent, self._new_id(), label, node_data)
node = TreeNode(self, parent, self._new_id(), label, node_data, expanded=expand)
self._nodes[node.id] = node
self._updates += 1
return node
@@ -281,11 +292,14 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
Returns:
Text: A Rich Text object containing the label.
"""
node_label = node.label.copy()
node_label = node._label.copy()
node_label.stylize(style)
if node._allow_expand:
prefix = ("" if node.expanded else "", base_style)
prefix = (
"" if node.is_expanded else "",
base_style + Style.from_meta({"toggle": True}),
)
else:
prefix = ("", base_style)
@@ -296,7 +310,7 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
"""Clear all nodes under root."""
self._tree_lines_cached = None
self._current_id = 0
root_label = self.root.label
root_label = self.root._label
root_data = self.root.data
self.root = TreeNode(
self,
@@ -309,7 +323,23 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
self._updates += 1
self.refresh()
def add_json(self, json_data: object) -> None:
def get_node_at_line(self, line_no: int) -> TreeNode[TreeDataType] | None:
"""Get the node for a given line.
Args:
line_no (int): A line number.
Returns:
TreeNode[TreeDataType] | None: A tree node, or ``None`` if there is no node at that line.
"""
try:
line = self._tree_lines[line_no]
except IndexError:
return None
else:
return line.node
def add_json(self, node: TreeNode, json_data: object) -> None:
from rich.highlighter import ReprHighlighter
@@ -317,12 +347,12 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
def add_node(name: str, node: TreeNode, data: object) -> None:
if isinstance(data, dict):
node.label = Text(f"{{}} {name}")
node._label = Text(f"{{}} {name}")
for key, value in data.items():
new_node = node.add("")
add_node(key, new_node, value)
elif isinstance(data, list):
node.label = Text(f"[] {name}")
node._label = Text(f"[] {name}")
for index, value in enumerate(data):
new_node = node.add("")
add_node(str(index), new_node, value)
@@ -334,10 +364,10 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
)
else:
label = Text(repr(data))
node.label = label
node._label = label
add_node("", self.root, json_data)
self.invalidate()
add_node(node._label, node, json_data)
# self.invalidate()
def validate_cursor_line(self, value: int) -> int:
return clamp(value, 0, len(self._tree_lines) - 1)
@@ -350,7 +380,7 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
self._tree_lines_cached = None
self._updates += 1
self.root._reset()
self.refresh()
self.refresh(layout=True)
def _on_mouse_move(self, event: events.MouseMove):
meta = event.style.meta
@@ -460,9 +490,14 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
self._tree_lines_cached = lines
guide_depth = self.guide_depth
width = max([line.get_line_width(guide_depth) for line in lines])
if lines:
width = max([line.get_line_width(guide_depth) for line in lines])
else:
width = self.size.width
self.virtual_size = Size(width, len(lines))
if self.cursor_line >= len(lines):
self.cursor_line = -1
def render_line(self, y: int) -> list[Segment]:
width = self.size.width
@@ -496,7 +531,7 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
y == self.hover_line,
y == self.cursor_line,
self.has_focus,
tuple((node._hover, node._selected, node.expanded) for node in line.path),
tuple((node._hover, node._selected, node._expanded) for node in line.path),
)
if cache_key in self._line_cache:
segments = self._line_cache[cache_key]
@@ -583,27 +618,31 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
guides.append(label)
segments = list(guides.render(self.app.console))
segments = line_pad(
segments, 0, self.virtual_size.width - guides.cell_len, line_style
)
pad_width = max(self.virtual_size.width, width)
segments = line_pad(segments, 0, pad_width - guides.cell_len, line_style)
self._line_cache[cache_key] = segments
segments = line_crop(segments, x1, x2, width)
return segments
def _on_resize(self) -> None:
def _on_resize(self, event: events.Resize) -> None:
self._line_cache.grow(event.size.height)
self.invalidate()
self._line_cache.grow(self.size.height * 2)
async def _on_click(self, event: events.Click) -> None:
meta = event.style.meta
if "line" in meta:
cursor_line = meta["line"]
if self.cursor_line == cursor_line:
await self.action("select_cursor")
if meta.get("toggle", False):
node = self.get_node_at_line(cursor_line)
if node is not None:
node.toggle()
else:
self.cursor_line = cursor_line
if self.cursor_line == cursor_line:
await self.action("select_cursor")
else:
self.cursor_line = cursor_line
def action_cursor_up(self) -> None:
if self.cursor_line == -1:
@@ -622,5 +661,5 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
else:
node = line.path[-1]
if self.auto_expand:
node.expanded = not node.expanded
node.toggle()
self.emit_no_wait(self.NodeSelected(self, line.path[-1]))