This commit is contained in:
Will McGugan
2022-09-07 15:11:34 +01:00
parent 07724489bf
commit 385a02a1e1
8 changed files with 173 additions and 76 deletions

View File

@@ -5,16 +5,10 @@ from textual.widgets import DirectoryTree
class TreeApp(App): class TreeApp(App):
DEFAULT_CSS = """
Screen {
overflow: auto;
}
"""
def compose(self): def compose(self):
tree = DirectoryTree("~/projects") tree = DirectoryTree("~/projects")
yield Container(tree) yield Container(tree)
tree.focus()
app = TreeApp() app = TreeApp()

View File

@@ -240,7 +240,7 @@ class App(Generic[ReturnType], DOMNode):
title: Reactive[str] = Reactive("Textual") title: Reactive[str] = Reactive("Textual")
sub_title: Reactive[str] = Reactive("") sub_title: Reactive[str] = Reactive("")
dark: Reactive[bool] = Reactive(False) dark: Reactive[bool] = Reactive(True)
@property @property
def devtools_enabled(self) -> bool: def devtools_enabled(self) -> bool:

View File

@@ -488,7 +488,8 @@ class DOMNode(MessagePump):
"""Get a Rich Style object for this DOMNode.""" """Get a Rich Style object for this DOMNode."""
_, _, background, color = self.colors _, _, background, color = self.colors
style = ( style = (
Style.from_color(color.rich_color, background.rich_color) + self.text_style Style.from_color((background + color).rich_color, background.rich_color)
+ self.text_style
) )
return style return style

View File

@@ -522,10 +522,12 @@ class MessagePump(metaclass=MessagePumpMeta):
Args: Args:
event (events.Key): A key event. event (events.Key): A key event.
""" """
key_method = getattr(self, f"key_{event.key_name}", None) key_method = getattr(self, f"key_{event.key_name}", None) or getattr(
self, f"_key_{event.key_name}", None
)
if key_method is not None: if key_method is not None:
if await invoke(key_method, event): await invoke(key_method, event)
event.prevent_default() event.prevent_default()
async def on_timer(self, event: events.Timer) -> None: async def on_timer(self, event: events.Timer) -> None:
event.prevent_default() event.prevent_default()

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
import rich.repr import rich.repr
from .geometry import Region
from ._types import CallbackType from ._types import CallbackType
from .message import Message from .message import Message
@@ -39,7 +40,7 @@ class Layout(Message, verbose=True):
@rich.repr.auto @rich.repr.auto
class InvokeLater(Message, verbose=True): class InvokeLater(Message, verbose=True, bubble=False):
def __init__(self, sender: MessagePump, callback: CallbackType) -> None: def __init__(self, sender: MessagePump, callback: CallbackType) -> None:
self.callback = callback self.callback = callback
super().__init__(sender) super().__init__(sender)
@@ -48,11 +49,10 @@ class InvokeLater(Message, verbose=True):
yield "callback", self.callback yield "callback", self.callback
# TODO: This should really be an Event
@rich.repr.auto @rich.repr.auto
class CursorMove(Message): class ScrollToRegion(Message, bubble=False):
def __init__(self, sender: MessagePump, line: int) -> None: def __init__(self, sender: MessagePump, region: Region) -> None:
self.line = line self.region = region
super().__init__(sender) super().__init__(sender)

View File

@@ -611,6 +611,18 @@ class Widget(DOMNode):
except errors.NoWidget: except errors.NoWidget:
return Region() return Region()
@property
def container_viewport(self) -> Region:
"""The viewport region (parent window)
Returns:
Region: The region that contains this widget.
"""
if self.parent is None:
return self.size.region
assert isinstance(self.parent, Widget)
return self.parent.region
@property @property
def virtual_region(self) -> Region: def virtual_region(self) -> Region:
"""The widget region relative to it's container. Which may not be visible, """The widget region relative to it's container. Which may not be visible,
@@ -1070,6 +1082,7 @@ class Widget(DOMNode):
window = self.content_region.at_offset(self.scroll_offset) window = self.content_region.at_offset(self.scroll_offset)
if spacing is not None: if spacing is not None:
window = window.shrink(spacing) window = window.shrink(spacing)
self.log(window=window, region=region)
delta_x, delta_y = Region.get_scroll_to_visible(window, region) delta_x, delta_y = Region.get_scroll_to_visible(window, region)
scroll_x, scroll_y = self.scroll_offset scroll_x, scroll_y = self.scroll_offset
delta = Offset( delta = Offset(
@@ -1080,7 +1093,7 @@ class Widget(DOMNode):
self.scroll_relative( self.scroll_relative(
delta.x or None, delta.x or None,
delta.y or None, delta.y or None,
animate=animate, animate=animate if abs(delta_y) > 1 else False,
duration=0.2, duration=0.2,
) )
return delta return delta
@@ -1093,13 +1106,21 @@ class Widget(DOMNode):
def __init_subclass__( def __init_subclass__(
cls, cls,
can_focus: bool = False, can_focus: bool | None = None,
can_focus_children: bool = True, can_focus_children: bool | None = None,
inherit_css: bool = True, inherit_css: bool = True,
) -> None: ) -> None:
base = cls.__mro__[0]
super().__init_subclass__(inherit_css=inherit_css) super().__init_subclass__(inherit_css=inherit_css)
cls.can_focus = can_focus if issubclass(base, Widget):
cls.can_focus_children = can_focus_children
cls.can_focus = base.can_focus if can_focus is None else can_focus
cls.can_focus_children = (
base.can_focus_children
if can_focus_children is None
else can_focus_children
)
def __rich_repr__(self) -> rich.repr.Result: def __rich_repr__(self) -> rich.repr.Result:
yield "id", self.id, None yield "id", self.id, None
@@ -1529,49 +1550,52 @@ class Widget(DOMNode):
if self.has_focus: if self.has_focus:
self.app._reset_focus(self) self.app._reset_focus(self)
def key_home(self) -> bool: def _on_scroll_to_region(self, message: messages.ScrollToRegion) -> None:
self.scroll_to_region(message.region, animate=True)
def _key_home(self) -> bool:
if self._allow_scroll: if self._allow_scroll:
self.scroll_home() self.scroll_home()
return True return True
return False return False
def key_end(self) -> bool: def _key_end(self) -> bool:
if self._allow_scroll: if self._allow_scroll:
self.scroll_end() self.scroll_end()
return True return True
return False return False
def key_left(self) -> bool: def _key_left(self) -> bool:
if self.allow_horizontal_scroll: if self.allow_horizontal_scroll:
self.scroll_left() self.scroll_left()
return True return True
return False return False
def key_right(self) -> bool: def _key_right(self) -> bool:
if self.allow_horizontal_scroll: if self.allow_horizontal_scroll:
self.scroll_right() self.scroll_right()
return True return True
return False return False
def key_down(self) -> bool: def _key_down(self) -> bool:
if self.allow_vertical_scroll: if self.allow_vertical_scroll:
self.scroll_down() self.scroll_down()
return True return True
return False return False
def key_up(self) -> bool: def _key_up(self) -> bool:
if self.allow_vertical_scroll: if self.allow_vertical_scroll:
self.scroll_up() self.scroll_up()
return True return True
return False return False
def key_pagedown(self) -> bool: def _key_pagedown(self) -> bool:
if self.allow_vertical_scroll: if self.allow_vertical_scroll:
self.scroll_page_down() self.scroll_page_down()
return True return True
return False return False
def key_pageup(self) -> bool: def _key_pageup(self) -> bool:
if self.allow_vertical_scroll: if self.allow_vertical_scroll:
self.scroll_page_up() self.scroll_page_up()
return True return True

View File

@@ -44,21 +44,6 @@ class DirectoryTree(TreeControl[DirEntry]):
super().__init__(label, data, name=name, id=id, classes=classes) super().__init__(label, data, name=name, id=id, classes=classes)
self.root.tree.guide_style = "black" self.root.tree.guide_style = "black"
has_focus: Reactive[bool] = Reactive(False)
def on_focus(self) -> None:
self.has_focus = True
def on_blur(self) -> None:
self.has_focus = False
async def watch_hover_node(self, hover_node: NodeID) -> None:
for node in self.nodes.values():
node.tree.guide_style = (
"bold not dim red" if node.id == hover_node else "black"
)
self.refresh()
def render_node(self, node: TreeNode[DirEntry]) -> RenderableType: def render_node(self, node: TreeNode[DirEntry]) -> RenderableType:
return self.render_tree_label( return self.render_tree_label(
node, node,
@@ -99,13 +84,14 @@ class DirectoryTree(TreeControl[DirEntry]):
label.stylize("dim") label.stylize("dim")
if is_cursor and has_focus: if is_cursor and has_focus:
label.stylize("reverse") cursor_style = self.get_component_styles("tree--cursor").rich_style
label.stylize(cursor_style)
icon_label = Text(f"{icon} ", no_wrap=True, overflow="ellipsis") + label icon_label = Text(f"{icon} ", no_wrap=True, overflow="ellipsis") + label
icon_label.apply_meta(meta) icon_label.apply_meta(meta)
return icon_label return icon_label
async def on_mount(self, event: events.Mount) -> None: def on_mount(self) -> None:
self.call_later(self.load_directory, self.root) self.call_later(self.load_directory, self.root)
async def load_directory(self, node: TreeNode[DirEntry]): async def load_directory(self, node: TreeNode[DirEntry]):
@@ -113,22 +99,23 @@ class DirectoryTree(TreeControl[DirEntry]):
directory = sorted( directory = sorted(
list(scandir(path)), key=lambda entry: (not entry.is_dir(), entry.name) list(scandir(path)), key=lambda entry: (not entry.is_dir(), entry.name)
) )
self.log(directory)
for entry in directory: for entry in directory:
await node.add(entry.name, DirEntry(entry.path, entry.is_dir())) node.add(entry.name, DirEntry(entry.path, entry.is_dir()))
node.loaded = True node.loaded = True
await node.expand() node.expand()
self.refresh(layout=True) self.refresh(layout=True)
async def on_tree_click(self, message: TreeClick[DirEntry]) -> None: async def on_tree_control_node_selected(self, message: TreeClick[DirEntry]) -> None:
dir_entry = message.node.data dir_entry = message.node.data
if not dir_entry.is_dir: if not dir_entry.is_dir:
await self.emit(FileClick(self, dir_entry.path)) await self.emit(FileClick(self, dir_entry.path))
else: else:
if not message.node.loaded: if not message.node.loaded:
await self.load_directory(message.node) await self.load_directory(message.node)
await message.node.expand() message.node.expand()
else: else:
await message.node.toggle() message.node.toggle()
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -8,12 +8,13 @@ from rich.console import RenderableType
from rich.text import Text, TextType from rich.text import Text, TextType
from rich.tree import Tree from rich.tree import Tree
from ..geometry import Region
from .. import events from .. import events
from ..reactive import Reactive from ..reactive import Reactive
from .._types import MessageTarget from .._types import MessageTarget
from ..widget import Widget from ..widget import Widget
from ..message import Message from ..message import Message
from ..messages import CursorMove from .. import messages
NodeID = NewType("NodeID", int) NodeID = NewType("NodeID", int)
@@ -141,16 +142,16 @@ class TreeNode(Generic[NodeDataType]):
sibling = node sibling = node
return None return None
async def expand(self, expanded: bool = True) -> None: def expand(self, expanded: bool = True) -> None:
self._expanded = expanded self._expanded = expanded
self._tree.expanded = expanded self._tree.expanded = expanded
self._control.refresh(layout=True) self._control.refresh(layout=True)
async def toggle(self) -> None: def toggle(self) -> None:
await self.expand(not self._expanded) self.expand(not self._expanded)
async def add(self, label: TextType, data: NodeDataType) -> None: def add(self, label: TextType, data: NodeDataType) -> None:
await self._control.add(self.id, label, data=data) self._control.add(self.id, label, data=data)
self._control.refresh(layout=True) self._control.refresh(layout=True)
self._empty = False self._empty = False
@@ -178,16 +179,37 @@ class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
} }
TreeControl > .tree--guides { TreeControl > .tree--guides {
color: $success;
}
TreeControl > .tree--guides-highlight {
color: $secondary; color: $secondary;
text-style: bold;
}
TreeControl > .tree--labels {
color: $text-panel;
}
TreeControl > .tree--cursor {
background: $secondary;
color: $text-secondary;
} }
""" """
COMPONENT_CLASSES: ClassVar[set[str]] = { COMPONENT_CLASSES: ClassVar[set[str]] = {
"tree--guides", "tree--guides",
"tree--guides-highlight",
"tree--labels", "tree--labels",
"tree--cursor",
} }
class NodeSelected(Message, bubble=False):
def __init__(self, sender: MessageTarget, node: TreeNode[NodeDataType]) -> None:
self.node = node
super().__init__(sender)
def __init__( def __init__(
self, self,
label: TextType, label: TextType,
@@ -202,6 +224,7 @@ class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
self.node_id = NodeID(0) self.node_id = NodeID(0)
self.nodes: dict[NodeID, TreeNode[NodeDataType]] = {} self.nodes: dict[NodeID, TreeNode[NodeDataType]] = {}
self._tree = Tree(label) self._tree = Tree(label)
self.root: TreeNode[NodeDataType] = TreeNode( self.root: TreeNode[NodeDataType] = TreeNode(
None, self.node_id, self, self._tree, label, data None, self.node_id, self, self._tree, label, data
) )
@@ -216,21 +239,43 @@ class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
show_cursor: Reactive[bool] = Reactive(False) show_cursor: Reactive[bool] = Reactive(False)
def watch_show_cursor(self, value: bool) -> None: def watch_show_cursor(self, value: bool) -> None:
self.emit_no_wait(CursorMove(self, self.cursor_line)) line_region = Region(0, self.cursor_line, self.size.width, 1)
self.emit_no_wait(messages.ScrollToRegion(self, line_region))
def watch_cursor_line(self, value: int) -> None: def watch_cursor_line(self, value: int) -> None:
if self.show_cursor: line_region = Region(0, value, self.size.width, 1)
self.emit_no_wait(CursorMove(self, value + self.gutter.top)) self.emit_no_wait(messages.ScrollToRegion(self, line_region))
async def add( def watch_hover_node(self, previous_hover_node: NodeID, hover_node: NodeID) -> None:
previous_hover = self.nodes.get(previous_hover_node)
if previous_hover is not None:
previous_hover._tree.guide_style = self._guide_style
hover = self.nodes.get(hover_node)
if hover is not None:
hover._tree.guide_style = self._highlight_guide_style
self.refresh()
def watch_cursor(self, previous_cursor_node: NodeID, cursor_node: NodeID) -> None:
previous_cursor = self.nodes.get(previous_cursor_node)
if previous_cursor is not None:
previous_cursor._tree.guide_style = self._guide_style
cursor = self.nodes.get(cursor_node)
if cursor is not None:
cursor._tree.guide_style = self._highlight_guide_style
self.refresh()
def add(
self, self,
node_id: NodeID, node_id: NodeID,
label: TextType, label: TextType,
data: NodeDataType, data: NodeDataType,
) -> None: ) -> None:
parent = self.nodes[node_id] parent = self.nodes[node_id]
self.node_id = NodeID(self.node_id + 1) self.node_id = NodeID(self.node_id + 1)
child_tree = parent._tree.add(label) child_tree = parent._tree.add(label)
child_tree.guide_style = self._guide_style
child_node: TreeNode[NodeDataType] = TreeNode( child_node: TreeNode[NodeDataType] = TreeNode(
parent, self.node_id, self, child_tree, label, data parent, self.node_id, self, child_tree, label, data
) )
@@ -267,12 +312,12 @@ class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
return None return None
def render(self) -> RenderableType: def render(self) -> RenderableType:
self._tree.guide_style = self._component_styles["tree--guides"].node.rich_style
return self._tree return self._tree
def render_node(self, node: TreeNode[NodeDataType]) -> RenderableType: def render_node(self, node: TreeNode[NodeDataType]) -> RenderableType:
label_style = self.get_component_styles("tree--labels").rich_style
label = ( label = (
Text(node.label, no_wrap=True, overflow="ellipsis") Text(node.label, no_wrap=True, style=label_style, overflow="ellipsis")
if isinstance(node.label, str) if isinstance(node.label, str)
else node.label else node.label
) )
@@ -281,33 +326,77 @@ class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
label.apply_meta({"@click": f"click_label({node.id})", "tree_node": node.id}) label.apply_meta({"@click": f"click_label({node.id})", "tree_node": node.id})
return label return label
async def action_click_label(self, node_id: NodeID) -> None: def action_click_label(self, node_id: NodeID) -> None:
node = self.nodes[node_id] node = self.nodes[node_id]
self.cursor = node.id self.cursor = node.id
self.cursor_line = self.find_cursor() or 0 self.cursor_line = self.find_cursor() or 0
self.show_cursor = False self.show_cursor = True
await self.post_message(TreeClick(self, node)) self.post_message_no_wait(self.NodeSelected(self, node))
async def on_mouse_move(self, event: events.MouseMove) -> None: def on_mount(self) -> None:
self._guide_style = self.get_component_styles("tree--guides").rich_style
self._highlight_guide_style = self.get_component_styles(
"tree--guides-highlight"
).rich_style
self._tree.guide_style = self._guide_style
def on_mouse_move(self, event: events.MouseMove) -> None:
self.hover_node = event.style.meta.get("tree_node") self.hover_node = event.style.meta.get("tree_node")
async def on_key(self, event: events.Key) -> None: async def on_key(self, event: events.Key) -> None:
await self.dispatch_key(event) await self.dispatch_key(event)
async def key_down(self, event: events.Key) -> None: def key_down(self, event: events.Key) -> None:
event.stop() event.stop()
await self.cursor_down() self.cursor_down()
async def key_up(self, event: events.Key) -> None: def key_up(self, event: events.Key) -> None:
event.stop() event.stop()
await self.cursor_up() self.cursor_up()
async def key_enter(self, event: events.Key) -> None: def key_pagedown(self) -> None:
assert self.parent is not None
height = self.container_viewport.height
cursor = self.cursor
cursor_line = self.cursor_line
for _ in range(height):
cursor_node = self.nodes[cursor]
next_node = cursor_node.next_node
if next_node is not None:
cursor_line += 1
cursor = next_node.id
self.cursor = cursor
self.cursor_line = cursor_line
def key_pageup(self) -> None:
assert self.parent is not None
height = self.container_viewport.height
cursor = self.cursor
cursor_line = self.cursor_line
for _ in range(height):
cursor_node = self.nodes[cursor]
previous_node = cursor_node.previous_node
if previous_node is not None:
cursor_line -= 1
cursor = previous_node.id
self.cursor = cursor
self.cursor_line = cursor_line
def key_home(self) -> None:
self.cursor_line = 0
self.cursor = NodeID(0)
def key_end(self) -> None:
self.cursor = self.nodes[NodeID(0)].children[-1].id
self.cursor_line = self.find_cursor() or 0
def key_enter(self, event: events.Key) -> None:
cursor_node = self.nodes[self.cursor] cursor_node = self.nodes[self.cursor]
event.stop() event.stop()
await self.post_message(TreeClick(self, cursor_node)) self.post_message_no_wait(self.NodeSelected(self, cursor_node))
async def cursor_down(self) -> None: def cursor_down(self) -> None:
if not self.show_cursor: if not self.show_cursor:
self.show_cursor = True self.show_cursor = True
return return
@@ -317,7 +406,7 @@ class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
self.cursor_line += 1 self.cursor_line += 1
self.cursor = next_node.id self.cursor = next_node.id
async def cursor_up(self) -> None: def cursor_up(self) -> None:
if not self.show_cursor: if not self.show_cursor:
self.show_cursor = True self.show_cursor = True
return return