From 6ec37ce82f153822264fa3fdcf9f5986bde14d73 Mon Sep 17 00:00:00 2001 From: Will McGugan Date: Sat, 21 Aug 2021 11:19:06 +0100 Subject: [PATCH] add keys to tree control --- examples/code_viewer.py | 2 +- src/textual/_ansi_sequences.py | 5 +- src/textual/_xterm_parser.py | 1 - src/textual/events.py | 1 - src/textual/geometry.py | 8 +- src/textual/keys.py | 9 +- src/textual/messages.py | 7 ++ src/textual/reactive.py | 1 + src/textual/widget.py | 22 ++-- src/textual/widgets/_directory_tree.py | 70 +++++++++++- src/textual/widgets/_scroll_view.py | 16 ++- src/textual/widgets/_tree_control.py | 146 ++++++++++++++++++++++--- tests/test_geometry.py | 5 + 13 files changed, 248 insertions(+), 45 deletions(-) diff --git a/examples/code_viewer.py b/examples/code_viewer.py index b5418c428..5b044ca42 100644 --- a/examples/code_viewer.py +++ b/examples/code_viewer.py @@ -41,7 +41,7 @@ class MyApp(App): # Note the directory is also in a scroll view await self.view.dock( - ScrollView(self.directory), edge="left", size=32, name="sidebar" + ScrollView(self.directory), edge="left", size=64, name="sidebar" ) await self.view.dock(self.body, edge="top") diff --git a/src/textual/_ansi_sequences.py b/src/textual/_ansi_sequences.py index af3f2cae4..e24dd1d01 100644 --- a/src/textual/_ansi_sequences.py +++ b/src/textual/_ansi_sequences.py @@ -1,10 +1,11 @@ -from typing import Dict, Tuple, Union +from typing import Dict, Tuple from .keys import Keys # Mapping of vt100 escape codes to Keys. ANSI_SEQUENCES: Dict[str, Tuple[Keys, ...]] = { # Control keys. + "\r": (Keys.Enter,), "\x00": (Keys.ControlAt,), # Control-At (Also for Ctrl-Space) "\x01": (Keys.ControlA,), # Control-A (home) "\x02": (Keys.ControlB,), # Control-B (emacs cursor left) @@ -18,7 +19,7 @@ ANSI_SEQUENCES: Dict[str, Tuple[Keys, ...]] = { "\x0a": (Keys.ControlJ,), # Control-J (10) (Identical to '\n') "\x0b": (Keys.ControlK,), # Control-K (delete until end of line; vertical tab) "\x0c": (Keys.ControlL,), # Control-L (clear; form feed) - "\x0d": (Keys.ControlM,), # Control-M (13) (Identical to '\r') + # "\x0d": (Keys.ControlM,), # Control-M (13) (Identical to '\r') "\x0e": (Keys.ControlN,), # Control-N (14) (history forward) "\x0f": (Keys.ControlO,), # Control-O (15) "\x10": (Keys.ControlP,), # Control-P (16) (history back) diff --git a/src/textual/_xterm_parser.py b/src/textual/_xterm_parser.py index b6b2dea7e..7a2194e97 100644 --- a/src/textual/_xterm_parser.py +++ b/src/textual/_xterm_parser.py @@ -91,7 +91,6 @@ class XTermParser(Parser[events.Event]): on_token(event) break else: - keys = get_ansi_sequence(character, None) if keys is not None: for key in keys: diff --git a/src/textual/events.py b/src/textual/events.py index bb7cc6431..22424d691 100644 --- a/src/textual/events.py +++ b/src/textual/events.py @@ -1,6 +1,5 @@ from __future__ import annotations -from asyncio import Event from typing import Awaitable, Callable, Type, TYPE_CHECKING, TypeVar import rich.repr diff --git a/src/textual/geometry.py b/src/textual/geometry.py index 651b60faa..ff124280c 100644 --- a/src/textual/geometry.py +++ b/src/textual/geometry.py @@ -131,10 +131,10 @@ class Size(NamedTuple): class Region(NamedTuple): """Defines a rectangular region.""" - x: int - y: int - width: int - height: int + x: int = 0 + y: int = 0 + width: int = 0 + height: int = 0 @classmethod def from_corners(cls, x1: int, y1: int, x2: int, y2: int) -> Region: diff --git a/src/textual/keys.py b/src/textual/keys.py index 662ac2fe3..49ea46c1c 100644 --- a/src/textual/keys.py +++ b/src/textual/keys.py @@ -16,6 +16,7 @@ class Keys(str, Enum): Escape = "escape" # Also Control-[ ShiftEscape = "shift+escape" + Return = "return" ControlAt = "ctrl+@" # Also Control-Space. @@ -186,10 +187,10 @@ class Keys(str, Enum): Ignore = "" # Some 'Key' aliases (for backwardshift+compatibility). - ControlSpace = ControlAt - Tab = ControlI - Enter = ControlM - Backspace = ControlH + ControlSpace = "ctrl-at" + Tab = "tab" + Enter = "enter" + Backspace = "backspace" # ShiftControl was renamed to ControlShift in # 888fcb6fa4efea0de8333177e1bbc792f3ff3c24 (20 Feb 2020). diff --git a/src/textual/messages.py b/src/textual/messages.py index 6263fa4e6..1261ecca3 100644 --- a/src/textual/messages.py +++ b/src/textual/messages.py @@ -30,3 +30,10 @@ class UpdateMessage(Message, verbosity=3): class LayoutMessage(Message, verbosity=3): def can_replace(self, message: Message) -> bool: return isinstance(message, LayoutMessage) + + +@rich.repr.auto +class CursorMoveMessage(Message, bubble=True): + def __init__(self, sender: MessagePump, line: int) -> None: + self.line = line + super().__init__(sender) diff --git a/src/textual/reactive.py b/src/textual/reactive.py index 6777af5c6..4ef0818ae 100644 --- a/src/textual/reactive.py +++ b/src/textual/reactive.py @@ -70,6 +70,7 @@ class Reactive(Generic[ReactiveType]): value = validate_function(value) if current_value != value or self._first: + self._first = False setattr(obj, self.internal_name, value) self.check_watchers(obj, name, current_value) diff --git a/src/textual/widget.py b/src/textual/widget.py index 866b86493..19d6eca4b 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -24,6 +24,7 @@ from rich.text import TextType from . import events from ._animator import BoundAnimator +from ._callback import invoke from ._context import active_app from .geometry import Size from .message import Message @@ -43,6 +44,14 @@ class RenderCache(NamedTuple): size: Size lines: Lines + @property + def cursor_line(self) -> int | None: + for index, line in enumerate(self.lines): + for text, style, control in line: + if style and style._meta and style.meta.get("cursor", False): + return index + return None + @rich.repr.auto class Widget(MessagePump): @@ -164,23 +173,18 @@ class Widget(MessagePump): def _update_size(self, size: Size) -> None: self._size = size - def render_lines(self) -> RenderCache: + def render_lines(self) -> None: width, height = self.size renderable = self.render_styled() options = self.console.options.update_dimensions(width, height) lines = self.console.render_lines(renderable, options) self.render_cache = RenderCache(self.size, lines) - return self.render_cache - - def render_lines_free(self, width: int) -> RenderCache: + def render_lines_free(self, width: int) -> None: renderable = self.render_styled() - options = self.console.options.update(width=width, height=None) - lines = self.console.render_lines(renderable, options) self.render_cache = RenderCache(Size(width, len(lines)), lines) - return self.render_cache def _get_lines(self) -> Lines: """Get render lines for given dimensions. @@ -193,7 +197,7 @@ class Widget(MessagePump): Lines: [description] """ if self.render_cache is None: - self.render_cache = self.render_lines() + self.render_lines() lines = self.render_cache.lines return lines @@ -297,7 +301,7 @@ class Widget(MessagePump): key_method = getattr(self, f"key_{event.key}", None) if key_method is not None: - await key_method() + await invoke(key_method, event) async def on_mouse_down(self, event: events.MouseUp) -> None: await self.broker_event("mouse.down", event) diff --git a/src/textual/widgets/_directory_tree.py b/src/textual/widgets/_directory_tree.py index 650c4f1db..ef43e6eaf 100644 --- a/src/textual/widgets/_directory_tree.py +++ b/src/textual/widgets/_directory_tree.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +from functools import lru_cache from os import scandir import os.path @@ -11,6 +12,7 @@ from rich.tree import Tree from .. import events from ..message import Message +from ..reactive import Reactive from .._types import MessageTarget from . import TreeControl, TreeClick, TreeNode, NodeID @@ -36,6 +38,14 @@ class DirectoryTree(TreeControl[DirEntry]): super().__init__(label, name=name, data=data) self.root.tree.guide_style = "black" + has_focus: Reactive[bool] = Reactive(False) + + def on_focus(self) -> None: + self.has_focus = True + + def on_blur(self) -> None: + self.has_focus = False + async def watch_hover_node(self, hover_node: NodeID) -> None: for node in self.nodes.values(): node.tree.guide_style = ( @@ -44,13 +54,62 @@ class DirectoryTree(TreeControl[DirEntry]): self.refresh(layout=True) def render_node(self, node: TreeNode[DirEntry]) -> RenderableType: - meta = {"@click": f"click_label({node.id})", "tree_node": node.id} + # TODO: Optimize / cache this + return self.render_tree_label( + node, + node.data.is_dir, + node.expanded, + node.is_cursor, + node.id == self.hover_node, + self.has_focus, + ) + # meta = { + # "@click": f"click_label({node.id})", + # "tree_node": node.id, + # "cursor": node.is_cursor, + # } + # label = Text(node.label) if isinstance(node.label, str) else node.label + # if node.id == self.hover_node: + # label.stylize("underline") + # if node.data.is_dir: + # label.stylize("bold magenta") + # icon = "📂" if node.expanded else "📁" + # else: + # label.stylize("bright_green") + # icon = "📄" + # label.highlight_regex(r"\..*$", "green") + + # if label.plain.startswith("."): + # label.stylize("dim") + + # if node.is_cursor and self.has_focus: + # label.stylize("reverse") + + # icon_label = Text(f"{icon} ", no_wrap=True, overflow="ellipsis") + label + # icon_label.apply_meta(meta) + # return icon_label + + @lru_cache(maxsize=1024 * 32) + def render_tree_label( + self, + node: TreeNode[DirEntry], + is_dir: bool, + expanded: bool, + is_cursor: bool, + is_hover: bool, + has_focus: bool, + ) -> RenderableType: + meta = { + "@click": f"click_label({node.id})", + "tree_node": node.id, + "cursor": node.is_cursor, + } label = Text(node.label) if isinstance(node.label, str) else node.label - if node.id == self.hover_node: + if is_hover: label.stylize("underline") - if node.data.is_dir: + if is_dir: label.stylize("bold magenta") - icon = "📂" if node.expanded else "📁" + icon = "📂" if expanded else "📁" else: label.stylize("bright_green") icon = "📄" @@ -59,6 +118,9 @@ class DirectoryTree(TreeControl[DirEntry]): if label.plain.startswith("."): label.stylize("dim") + if is_cursor and has_focus: + label.stylize("reverse") + icon_label = Text(f"{icon} ", no_wrap=True, overflow="ellipsis") + label icon_label.apply_meta(meta) return icon_label diff --git a/src/textual/widgets/_scroll_view.py b/src/textual/widgets/_scroll_view.py index dcee85e64..1e2d80e4a 100644 --- a/src/textual/widgets/_scroll_view.py +++ b/src/textual/widgets/_scroll_view.py @@ -1,5 +1,4 @@ from __future__ import annotations -from logging import PlaceHolder from rich.console import RenderableType from rich.style import StyleType @@ -8,11 +7,9 @@ from rich.style import StyleType from .. import events from ..layouts.grid import GridLayout from ..message import Message -from ..messages import UpdateMessage +from ..messages import CursorMoveMessage from ..scrollbar import ScrollTo, ScrollBar -from ..geometry import clamp, Offset, Size -from ..page import Page -from ..reactive import watch +from ..geometry import clamp from ..view import View from ..widget import Widget @@ -121,6 +118,12 @@ class ScrollView(View): self.target_x += self.size.width self.animate("x", self.target_x, speed=120, easing="out_cubic") + def scroll_in_to_view(self, line: int) -> None: + if line < self.y: + self.target_y = line + elif line > self.y + self.size.height: + self.target_y = line - self.size.height + async def on_mouse_scroll_up(self, event: events.MouseScrollUp) -> None: self.scroll_up() @@ -191,3 +194,6 @@ class ScrollView(View): self.refresh() if self.layout.show_row("hscroll", virtual_size.width > self.size.width): self.refresh() + + async def message_cursor_move(self, message: CursorMoveMessage) -> None: + self.scroll_in_to_view(message.line) diff --git a/src/textual/widgets/_tree_control.py b/src/textual/widgets/_tree_control.py index 5fadc033a..687a39b76 100644 --- a/src/textual/widgets/_tree_control.py +++ b/src/textual/widgets/_tree_control.py @@ -1,20 +1,21 @@ from __future__ import annotations -from typing import Any, Generic, NewType, TypeVar +from functools import lru_cache -from rich.console import Console, ConsoleOptions, RenderableType +from typing import Generic, NewType, TypeVar -from rich.style import Style, StyleType -from rich.styled import Styled +from rich.console import RenderableType from rich.text import Text, TextType from rich.tree import Tree -from rich.padding import Padding, PaddingDimensions +from rich.padding import PaddingDimensions from .. import log +from .. import events from ..reactive import Reactive from .._types import MessageTarget from ..widget import Widget from ..message import Message +from ..messages import CursorMoveMessage NodeID = NewType("NodeID", int) @@ -26,12 +27,14 @@ NodeDataType = TypeVar("NodeDataType") class TreeNode(Generic[NodeDataType]): def __init__( self, + parent: TreeNode[NodeDataType] | None, node_id: NodeID, control: TreeControl, tree: Tree, label: TextType, data: NodeDataType, ) -> None: + self.parent = parent self._node_id = node_id self._control = control self._tree = tree @@ -41,6 +44,7 @@ class TreeNode(Generic[NodeDataType]): self._expanded = False self._empty = False self._tree.expanded = False + self.children: list[TreeNode] = [] @property def id(self) -> NodeID: @@ -58,14 +62,88 @@ class TreeNode(Generic[NodeDataType]): def expanded(self) -> bool: return self._expanded + @property + def is_cursor(self) -> bool: + return self.control.cursor == self.id + @property def tree(self) -> Tree: return self._tree + @property + def next_node(self) -> TreeNode[NodeDataType] | None: + """The next node in the tree, or None if at the end.""" + + if self.expanded and self.children: + return self.children[0] + else: + + sibling = self.next_sibling + if sibling is not None: + return sibling + + node = self + while True: + if node.parent is None: + return None + sibling = node.parent.next_sibling + if sibling is not None: + return sibling + else: + node = node.parent + + @property + def previous_node(self) -> TreeNode[NodeDataType] | None: + """The previous node in the tree, or None if at the end.""" + + sibling = self.previous_sibling + if sibling is not None: + + def last_sibling(node) -> TreeNode[NodeDataType]: + if node.expanded and node.children: + return last_sibling(node.children[-1]) + else: + return ( + node.children[-1] if (node.children and node.expanded) else node + ) + + return last_sibling(sibling) + + if self.parent is None: + return None + return self.parent + + @property + def next_sibling(self) -> TreeNode[NodeDataType] | None: + """The next sibling, or None if last sibling.""" + if self.parent is None: + return None + iter_siblings = iter(self.parent.children) + try: + for node in iter_siblings: + if node is self: + return next(iter_siblings) + except StopIteration: + return None + + @property + def previous_sibling(self) -> TreeNode[NodeDataType] | None: + """Previous sibling or None if first sibling.""" + if self.parent is None: + return None + iter_siblings = iter(self.parent.children) + sibling: TreeNode[NodeDataType] | None = None + + for node in iter_siblings: + if node is self: + return sibling + sibling = node + return None + async def expand(self, expanded: bool = True) -> None: self._expanded = expanded self._tree.expanded = expanded - self._control.refresh() + self._control.refresh(layout=True) async def toggle(self) -> None: await self.expand(not self._expanded) @@ -100,14 +178,17 @@ class TreeControl(Generic[NodeDataType], Widget): self.nodes: dict[NodeID, TreeNode[NodeDataType]] = {} self._tree = Tree(label) self.root: TreeNode[NodeDataType] = TreeNode( - self._node_id, self, self._tree, label, data + None, self._node_id, self, self._tree, label, data ) + self._tree.label = self.root self.nodes[NodeID(self._node_id)] = self.root super().__init__(name=name) self.padding = padding + self.cursor = self.root.id hover_node: Reactive[NodeID | None] = Reactive(None) + cursor: Reactive[NodeID] = Reactive(NodeID(0), layout=True) async def add( self, @@ -119,33 +200,70 @@ class TreeControl(Generic[NodeDataType], Widget): self._node_id = NodeID(self._node_id + 1) child_tree = parent._tree.add(label) child_node: TreeNode[NodeDataType] = TreeNode( - self._node_id, self, child_tree, label, data + parent, self._node_id, self, child_tree, label, data ) + parent.children.append(child_node) child_tree.label = child_node self.nodes[self._node_id] = child_node - self.refresh() + self.refresh(layout=True) def render(self) -> RenderableType: return self._tree def render_node(self, node: TreeNode[NodeDataType]) -> RenderableType: - meta = {"@click": f"click_label({node.id})", "tree_node": node.id} - label = Text(node.label) if isinstance(node.label, str) else node.label + label = ( + Text(node.label, no_wrap=True, overflow="ellipsis") + if isinstance(node.label, str) + else node.label + ) if node.id == self.hover_node: label.stylize("underline") - label.apply_meta(meta) - label.no_wrap = True - label.overflow = "ellipsis" + label.apply_meta( + { + "@click": f"click_label({node.id})", + "tree_node": node.id, + "cursor": node.is_cursor, + } + ) return label async def action_click_label(self, node_id: NodeID) -> None: node = self.nodes[node_id] + self.cursor = node.id await self.post_message(TreeClick(self, node)) async def on_mouse_move(self, event: events.MouseMove) -> None: self.hover_node = event.style.meta.get("tree_node") + async def on_key(self, event: events.Key) -> None: + await self.dispatch_key(event) + + async def key_down(self, event: events.Key) -> None: + await self.cursor_down() + event.stop() + + async def key_up(self, event: events.Key) -> None: + await self.cursor_up() + event.stop() + + async def key_enter(self, event: events.Key) -> None: + cursor_node = self.nodes[self.cursor] + event.stop() + await self.post_message(TreeClick(self, cursor_node)) + + async def cursor_down(self) -> None: + cursor_node = self.nodes[self.cursor] + next_node = cursor_node.next_node + if next_node is not None: + self.hover_node = self.cursor = next_node.id + + async def cursor_up(self) -> None: + cursor_node = self.nodes[self.cursor] + previous_node = cursor_node.previous_node + if previous_node is not None: + self.hover_node = self.cursor = previous_node.id + if __name__ == "__main__": diff --git a/tests/test_geometry.py b/tests/test_geometry.py index 71c16317b..40fd1e564 100644 --- a/tests/test_geometry.py +++ b/tests/test_geometry.py @@ -87,6 +87,11 @@ def test_point_blend(): assert Offset(1, 2).blend(Offset(3, 4), 0.5) == Offset(2, 3) +def test_region_null(): + assert Region() == Region(0, 0, 0, 0) + assert not Region() + + def test_region_from_origin(): assert Region.from_origin(Offset(3, 4), (5, 6)) == Region(3, 4, 5, 6)