Merge pull request #744 from Textualize/tree-fix

tree fix
This commit is contained in:
Will McGugan
2022-09-07 19:45:16 +01:00
committed by GitHub
9 changed files with 209 additions and 97 deletions

View File

@@ -30,7 +30,7 @@ App > Screen {
overflow-y: auto;
height: 20;
margin: 1 2;
background: $panel;
background: $surface;
padding: 1 2;
}

View File

@@ -5,16 +5,10 @@ from textual.widgets import DirectoryTree
class TreeApp(App):
DEFAULT_CSS = """
Screen {
overflow: auto;
}
"""
def compose(self):
tree = DirectoryTree("~/projects")
yield Container(tree)
tree.focus()
app = TreeApp()

View File

@@ -899,11 +899,12 @@ class App(Generic[ReturnType], DOMNode):
self.log.system(f"{self.screen} is active")
return previous_screen
def set_focus(self, widget: Widget | None) -> None:
def set_focus(self, widget: Widget | None, scroll_visible: bool = False) -> None:
"""Focus (or unfocus) a widget. A focused widget will receive key events first.
Args:
widget (Widget): [description]
widget (Widget): Widget to focus.
scroll_visible (bool, optional): Scroll widget in to view.
"""
if widget == self.focused:
# Widget is already focused
@@ -924,7 +925,8 @@ class App(Generic[ReturnType], DOMNode):
# Change focus
self.focused = widget
# Send focus event
self.screen.scroll_to_widget(widget)
if scroll_visible:
self.screen.scroll_to_widget(widget)
widget.post_message_no_wait(events.Focus(self))
widget.emit_no_wait(events.DescendantFocus(self))

View File

@@ -488,7 +488,8 @@ class DOMNode(MessagePump):
"""Get a Rich Style object for this DOMNode."""
_, _, background, color = self.colors
style = (
Style.from_color(color.rich_color, background.rich_color) + self.text_style
Style.from_color((background + color).rich_color, background.rich_color)
+ self.text_style
)
return style

View File

@@ -522,10 +522,12 @@ class MessagePump(metaclass=MessagePumpMeta):
Args:
event (events.Key): A key event.
"""
key_method = getattr(self, f"key_{event.key_name}", None)
key_method = getattr(self, f"key_{event.key_name}", None) or getattr(
self, f"_key_{event.key_name}", None
)
if key_method is not None:
if await invoke(key_method, event):
event.prevent_default()
await invoke(key_method, event)
event.prevent_default()
async def on_timer(self, event: events.Timer) -> None:
event.prevent_default()

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
import rich.repr
from .geometry import Region
from ._types import CallbackType
from .message import Message
@@ -39,7 +40,7 @@ class Layout(Message, verbose=True):
@rich.repr.auto
class InvokeLater(Message, verbose=True):
class InvokeLater(Message, verbose=True, bubble=False):
def __init__(self, sender: MessagePump, callback: CallbackType) -> None:
self.callback = callback
super().__init__(sender)
@@ -48,11 +49,12 @@ class InvokeLater(Message, verbose=True):
yield "callback", self.callback
# TODO: This should really be an Event
@rich.repr.auto
class CursorMove(Message):
def __init__(self, sender: MessagePump, line: int) -> None:
self.line = line
class ScrollToRegion(Message, bubble=False):
"""Ask the parent to scroll a given region in to view."""
def __init__(self, sender: MessagePump, region: Region) -> None:
self.region = region
super().__init__(sender)

View File

@@ -4,7 +4,6 @@ from asyncio import Lock
from fractions import Fraction
from itertools import islice
from operator import attrgetter
from types import GeneratorType
from typing import TYPE_CHECKING, ClassVar, Collection, Iterable, NamedTuple
import rich.repr
@@ -611,6 +610,18 @@ class Widget(DOMNode):
except errors.NoWidget:
return Region()
@property
def container_viewport(self) -> Region:
"""The viewport region (parent window)
Returns:
Region: The region that contains this widget.
"""
if self.parent is None:
return self.size.region
assert isinstance(self.parent, Widget)
return self.parent.region
@property
def virtual_region(self) -> Region:
"""The widget region relative to it's container. Which may not be visible,
@@ -1080,7 +1091,7 @@ class Widget(DOMNode):
self.scroll_relative(
delta.x or None,
delta.y or None,
animate=animate,
animate=animate if (abs(delta_y) > 1 or delta_x) else False,
duration=0.2,
)
return delta
@@ -1093,13 +1104,19 @@ class Widget(DOMNode):
def __init_subclass__(
cls,
can_focus: bool = False,
can_focus_children: bool = True,
can_focus: bool | None = None,
can_focus_children: bool | None = None,
inherit_css: bool = True,
) -> None:
base = cls.__mro__[0]
super().__init_subclass__(inherit_css=inherit_css)
cls.can_focus = can_focus
cls.can_focus_children = can_focus_children
if issubclass(base, Widget):
cls.can_focus = base.can_focus if can_focus is None else can_focus
cls.can_focus_children = (
base.can_focus_children
if can_focus_children is None
else can_focus_children
)
def __rich_repr__(self) -> rich.repr.Result:
yield "id", self.id, None
@@ -1529,49 +1546,52 @@ class Widget(DOMNode):
if self.has_focus:
self.app._reset_focus(self)
def key_home(self) -> bool:
def _on_scroll_to_region(self, message: messages.ScrollToRegion) -> None:
self.scroll_to_region(message.region, animate=True)
def _key_home(self) -> bool:
if self._allow_scroll:
self.scroll_home()
return True
return False
def key_end(self) -> bool:
def _key_end(self) -> bool:
if self._allow_scroll:
self.scroll_end()
return True
return False
def key_left(self) -> bool:
def _key_left(self) -> bool:
if self.allow_horizontal_scroll:
self.scroll_left()
return True
return False
def key_right(self) -> bool:
def _key_right(self) -> bool:
if self.allow_horizontal_scroll:
self.scroll_right()
return True
return False
def key_down(self) -> bool:
def _key_down(self) -> bool:
if self.allow_vertical_scroll:
self.scroll_down()
return True
return False
def key_up(self) -> bool:
def _key_up(self) -> bool:
if self.allow_vertical_scroll:
self.scroll_up()
return True
return False
def key_pagedown(self) -> bool:
def _key_pagedown(self) -> bool:
if self.allow_vertical_scroll:
self.scroll_page_down()
return True
return False
def key_pageup(self) -> bool:
def _key_pageup(self) -> bool:
if self.allow_vertical_scroll:
self.scroll_page_up()
return True

View File

@@ -11,9 +11,8 @@ from rich.text import Text
from .. import events
from ..message import Message
from ..reactive import Reactive
from .._types import MessageTarget
from ._tree_control import TreeControl, TreeClick, TreeNode, NodeID
from ._tree_control import TreeControl, TreeNode
@dataclass
@@ -44,21 +43,6 @@ class DirectoryTree(TreeControl[DirEntry]):
super().__init__(label, data, name=name, id=id, classes=classes)
self.root.tree.guide_style = "black"
has_focus: Reactive[bool] = Reactive(False)
def on_focus(self) -> None:
self.has_focus = True
def on_blur(self) -> None:
self.has_focus = False
async def watch_hover_node(self, hover_node: NodeID) -> None:
for node in self.nodes.values():
node.tree.guide_style = (
"bold not dim red" if node.id == hover_node else "black"
)
self.refresh()
def render_node(self, node: TreeNode[DirEntry]) -> RenderableType:
return self.render_tree_label(
node,
@@ -99,13 +83,17 @@ class DirectoryTree(TreeControl[DirEntry]):
label.stylize("dim")
if is_cursor and has_focus:
label.stylize("reverse")
cursor_style = self.get_component_styles("tree--cursor").rich_style
label.stylize(cursor_style)
icon_label = Text(f"{icon} ", no_wrap=True, overflow="ellipsis") + label
icon_label.apply_meta(meta)
return icon_label
async def on_mount(self, event: events.Mount) -> None:
def on_styles_updated(self) -> None:
self.render_tree_label.cache_clear()
def on_mount(self) -> None:
self.call_later(self.load_directory, self.root)
async def load_directory(self, node: TreeNode[DirEntry]):
@@ -114,21 +102,23 @@ class DirectoryTree(TreeControl[DirEntry]):
list(scandir(path)), key=lambda entry: (not entry.is_dir(), entry.name)
)
for entry in directory:
await node.add(entry.name, DirEntry(entry.path, entry.is_dir()))
node.add(entry.name, DirEntry(entry.path, entry.is_dir()))
node.loaded = True
await node.expand()
node.expand()
self.refresh(layout=True)
async def on_tree_click(self, message: TreeClick[DirEntry]) -> None:
async def on_tree_control_node_selected(
self, message: TreeControl.NodeSelected[DirEntry]
) -> None:
dir_entry = message.node.data
if not dir_entry.is_dir:
await self.emit(FileClick(self, dir_entry.path))
else:
if not message.node.loaded:
await self.load_directory(message.node)
await message.node.expand()
message.node.expand()
else:
await message.node.toggle()
message.node.toggle()
if __name__ == "__main__":

View File

@@ -5,21 +5,24 @@ from typing import ClassVar, Generic, Iterator, NewType, TypeVar
import rich.repr
from rich.console import RenderableType
from rich.style import Style
from rich.text import Text, TextType
from rich.tree import Tree
from ..geometry import Region, Size
from .. import events
from ..reactive import Reactive
from .._types import MessageTarget
from ..widget import Widget
from ..message import Message
from ..messages import CursorMove
from .. import messages
NodeID = NewType("NodeID", int)
NodeDataType = TypeVar("NodeDataType")
EventNodeDataType = TypeVar("EventNodeDataType")
@rich.repr.auto
@@ -141,16 +144,16 @@ class TreeNode(Generic[NodeDataType]):
sibling = node
return None
async def expand(self, expanded: bool = True) -> None:
def expand(self, expanded: bool = True) -> None:
self._expanded = expanded
self._tree.expanded = expanded
self._control.refresh(layout=True)
async def toggle(self) -> None:
await self.expand(not self._expanded)
def toggle(self) -> None:
self.expand(not self._expanded)
async def add(self, label: TextType, data: NodeDataType) -> None:
await self._control.add(self.id, label, data=data)
def add(self, label: TextType, data: NodeDataType) -> None:
self._control.add(self.id, label, data=data)
self._control.refresh(layout=True)
self._empty = False
@@ -158,36 +161,56 @@ class TreeNode(Generic[NodeDataType]):
return self._control.render_node(self)
@rich.repr.auto
class TreeClick(Generic[NodeDataType], Message, bubble=True):
def __init__(self, sender: MessageTarget, node: TreeNode[NodeDataType]) -> None:
self.node = node
super().__init__(sender)
def __rich_repr__(self) -> rich.repr.Result:
yield "node", self.node
class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
DEFAULT_CSS = """
TreeControl {
background: $panel;
color: $text-panel;
background: $surface;
color: $text-surface;
height: auto;
width: 100%;
}
TreeControl > .tree--guides {
color: $secondary;
color: $success;
}
TreeControl > .tree--guides-highlight {
color: $success;
text-style: uu;
}
TreeControl > .tree--guides-cursor {
color: $secondary;
text-style: bold;
}
TreeControl > .tree--labels {
color: $text-panel;
}
TreeControl > .tree--cursor {
background: $secondary;
color: $text-secondary;
}
"""
COMPONENT_CLASSES: ClassVar[set[str]] = {
"tree--guides",
"tree--guides-highlight",
"tree--guides-cursor",
"tree--labels",
"tree--cursor",
}
class NodeSelected(Generic[EventNodeDataType], Message, bubble=False):
def __init__(
self, sender: MessageTarget, node: TreeNode[EventNodeDataType]
) -> None:
self.node = node
super().__init__(sender)
def __init__(
self,
label: TextType,
@@ -202,6 +225,7 @@ class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
self.node_id = NodeID(0)
self.nodes: dict[NodeID, TreeNode[NodeDataType]] = {}
self._tree = Tree(label)
self.root: TreeNode[NodeDataType] = TreeNode(
None, self.node_id, self, self._tree, label, data
)
@@ -215,22 +239,30 @@ class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
cursor_line: Reactive[int] = Reactive(0)
show_cursor: Reactive[bool] = Reactive(False)
def watch_show_cursor(self, value: bool) -> None:
self.emit_no_wait(CursorMove(self, self.cursor_line))
def watch_cursor_line(self, value: int) -> None:
if self.show_cursor:
self.emit_no_wait(CursorMove(self, value + self.gutter.top))
line_region = Region(0, value, self.size.width, 1)
self.emit_no_wait(messages.ScrollToRegion(self, line_region))
async def add(
def get_content_height(self, container: Size, viewport: Size, width: int) -> int:
def get_size(tree: Tree) -> int:
return 1 + sum(
get_size(child) if child.expanded else 1 for child in tree.children
)
size = get_size(self._tree)
return size
def add(
self,
node_id: NodeID,
label: TextType,
data: NodeDataType,
) -> None:
parent = self.nodes[node_id]
self.node_id = NodeID(self.node_id + 1)
child_tree = parent._tree.add(label)
child_tree.guide_style = self._guide_style
child_node: TreeNode[NodeDataType] = TreeNode(
parent, self.node_id, self, child_tree, label, data
)
@@ -267,12 +299,29 @@ class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
return None
def render(self) -> RenderableType:
self._tree.guide_style = self._component_styles["tree--guides"].node.rich_style
guide_style = self._guide_style
def update_guide_style(tree: Tree) -> None:
tree.guide_style = guide_style
for child in tree.children:
if child.expanded:
update_guide_style(child)
update_guide_style(self._tree)
if self.hover_node is not None:
hover = self.nodes.get(self.hover_node)
if hover is not None:
hover._tree.guide_style = self._highlight_guide_style
if self.cursor is not None and self.show_cursor:
cursor = self.nodes.get(self.cursor)
if cursor is not None:
cursor._tree.guide_style = self._cursor_guide_style
return self._tree
def render_node(self, node: TreeNode[NodeDataType]) -> RenderableType:
label_style = self.get_component_styles("tree--labels").rich_style
label = (
Text(node.label, no_wrap=True, overflow="ellipsis")
Text(node.label, no_wrap=True, style=label_style, overflow="ellipsis")
if isinstance(node.label, str)
else node.label
)
@@ -281,33 +330,85 @@ class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
label.apply_meta({"@click": f"click_label({node.id})", "tree_node": node.id})
return label
async def action_click_label(self, node_id: NodeID) -> None:
def action_click_label(self, node_id: NodeID) -> None:
node = self.nodes[node_id]
self.cursor = node.id
self.cursor_line = self.find_cursor() or 0
self.show_cursor = False
await self.post_message(TreeClick(self, node))
self.show_cursor = True
self.post_message_no_wait(self.NodeSelected(self, node))
async def on_mouse_move(self, event: events.MouseMove) -> None:
def on_mount(self) -> None:
self._tree.guide_style = self._guide_style
@property
def _guide_style(self) -> Style:
return self.get_component_styles("tree--guides").rich_style
@property
def _highlight_guide_style(self) -> Style:
return self.get_component_styles("tree--guides-highlight").rich_style
@property
def _cursor_guide_style(self) -> Style:
return self.get_component_styles("tree--guides-cursor").rich_style
def on_mouse_move(self, event: events.MouseMove) -> None:
self.hover_node = event.style.meta.get("tree_node")
async def on_key(self, event: events.Key) -> None:
await self.dispatch_key(event)
async def key_down(self, event: events.Key) -> None:
def key_down(self, event: events.Key) -> None:
event.stop()
await self.cursor_down()
self.cursor_down()
async def key_up(self, event: events.Key) -> None:
def key_up(self, event: events.Key) -> None:
event.stop()
await self.cursor_up()
self.cursor_up()
async def key_enter(self, event: events.Key) -> None:
def key_pagedown(self) -> None:
assert self.parent is not None
height = self.container_viewport.height
cursor = self.cursor
cursor_line = self.cursor_line
for _ in range(height):
cursor_node = self.nodes[cursor]
next_node = cursor_node.next_node
if next_node is not None:
cursor_line += 1
cursor = next_node.id
self.cursor = cursor
self.cursor_line = cursor_line
def key_pageup(self) -> None:
assert self.parent is not None
height = self.container_viewport.height
cursor = self.cursor
cursor_line = self.cursor_line
for _ in range(height):
cursor_node = self.nodes[cursor]
previous_node = cursor_node.previous_node
if previous_node is not None:
cursor_line -= 1
cursor = previous_node.id
self.cursor = cursor
self.cursor_line = cursor_line
def key_home(self) -> None:
self.cursor_line = 0
self.cursor = NodeID(0)
def key_end(self) -> None:
self.cursor = self.nodes[NodeID(0)].children[-1].id
self.cursor_line = self.find_cursor() or 0
def key_enter(self, event: events.Key) -> None:
cursor_node = self.nodes[self.cursor]
event.stop()
await self.post_message(TreeClick(self, cursor_node))
self.post_message_no_wait(self.NodeSelected(self, cursor_node))
async def cursor_down(self) -> None:
def cursor_down(self) -> None:
if not self.show_cursor:
self.show_cursor = True
return
@@ -317,7 +418,7 @@ class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
self.cursor_line += 1
self.cursor = next_node.id
async def cursor_up(self) -> None:
def cursor_up(self) -> None:
if not self.show_cursor:
self.show_cursor = True
return