From 5a8e492294458b4a2f3992c9e795d13a97164bb5 Mon Sep 17 00:00:00 2001 From: Will McGugan Date: Thu, 13 Oct 2022 16:43:59 +0100 Subject: [PATCH] depth first search --- sandbox/will/screens_focus.css | 9 ++++ sandbox/will/screens_focus.py | 20 +++++++++ src/textual/app.py | 4 +- src/textual/dom.py | 80 +++++++++++++++++++++++++++------- src/textual/widget.py | 4 -- tests/test_dom.py | 26 +++++++++++ 6 files changed, 122 insertions(+), 21 deletions(-) create mode 100644 sandbox/will/screens_focus.css create mode 100644 sandbox/will/screens_focus.py diff --git a/sandbox/will/screens_focus.css b/sandbox/will/screens_focus.css new file mode 100644 index 000000000..dd2a3ab26 --- /dev/null +++ b/sandbox/will/screens_focus.css @@ -0,0 +1,9 @@ + Focusable { + padding: 3 6; + background: blue 20%; + } + + Focusable :focus { + border: solid red; + } + diff --git a/sandbox/will/screens_focus.py b/sandbox/will/screens_focus.py new file mode 100644 index 000000000..2d35f5470 --- /dev/null +++ b/sandbox/will/screens_focus.py @@ -0,0 +1,20 @@ +from textual.app import App, ComposeResult +from textual.widgets import Static, Footer + + +class Focusable(Static, can_focus=True): + pass + + +class ScreensFocusApp(App): + def compose(self) -> ComposeResult: + yield Focusable("App - one") + yield Focusable("App - two") + yield Focusable("App - three") + yield Focusable("App - four") + yield Footer() + + +app = ScreensFocusApp(css_path="screens_focus.css") +if __name__ == "__main__": + app.run() diff --git a/src/textual/app.py b/src/textual/app.py index d286dcffe..0b3789623 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -1394,7 +1394,9 @@ class App(Generic[ReturnType], DOMNode): if parent is not None: parent.refresh(layout=True) - remove_widgets = list(widget.walk_children(Widget, with_self=True)) + remove_widgets = list( + widget.walk_children(Widget, with_self=True, method="depth") + ) for child in remove_widgets: self._unregister(child) for child in remove_widgets: diff --git a/src/textual/dom.py b/src/textual/dom.py index f3234ac86..1d3712e38 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -1,7 +1,9 @@ from __future__ import annotations +from collections import deque from inspect import getfile import re +import sys from typing import ( cast, ClassVar, @@ -40,10 +42,23 @@ if TYPE_CHECKING: from .screen import Screen from .widget import Widget +if sys.version_info >= (3, 8): + from typing import Literal, Iterable, Sequence +else: + from typing_extensions import Literal + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: # pragma: no cover + from typing_extensions import TypeAlias + _re_identifier = re.compile(IDENTIFIER) +WalkMethod: TypeAlias = Literal["depth", "breadth"] + + class BadIdentifier(Exception): """raised by check_identifiers.""" @@ -617,11 +632,14 @@ class DOMNode(MessagePump): filter_type: type[WalkType], *, with_self: bool = True, + method: WalkMethod = "breadth", ) -> Iterable[WalkType]: ... @overload - def walk_children(self, *, with_self: bool = True) -> Iterable[DOMNode]: + def walk_children( + self, *, with_self: bool = True, method: WalkMethod = "breadth" + ) -> Iterable[DOMNode]: ... def walk_children( @@ -629,6 +647,7 @@ class DOMNode(MessagePump): filter_type: type[WalkType] | None = None, *, with_self: bool = True, + method: WalkMethod = "breadth", ) -> Iterable[DOMNode | WalkType]: """Generate descendant nodes. @@ -636,29 +655,58 @@ class DOMNode(MessagePump): filter_type (type[WalkType] | None, optional): Filter only this type, or None for no filter. Defaults to None. with_self (bool, optional): Also yield self in addition to descendants. Defaults to True. + method (Literal["breadth", "depth"], optional): One of "depth" or "breadth". Defaults to "breadth". Returns: Iterable[DOMNode | WalkType]: An iterable of nodes. """ - stack: list[Iterator[DOMNode]] = [iter(self.children)] - pop = stack.pop - push = stack.append - check_type = filter_type or DOMNode + def walk_breadth_first() -> Iterable[DOMNode]: + """Walk the tree breadth first (parent's first).""" + stack: list[Iterator[DOMNode]] = [iter(self.children)] + pop = stack.pop + push = stack.append + check_type = filter_type or DOMNode - if with_self and isinstance(self, check_type): - yield self + if with_self and isinstance(self, check_type): + yield self - while stack: - node = next(stack[-1], None) - if node is None: - pop() - else: - if isinstance(node, check_type): - yield node - if node.children: - push(iter(node.children)) + while stack: + node = next(stack[-1], None) + if node is None: + pop() + else: + if isinstance(node, check_type): + yield node + if node.children: + push(iter(node.children)) + + def walk_depth_first() -> Iterable[DOMNode]: + """Walk the tree depth first (children first).""" + depth_stack: list[tuple[DOMNode, Iterator[DOMNode]]] = ( + [(self, iter(self.children))] + if with_self + else [(node, iter(node.children)) for node in reversed(self.children)] + ) + pop = depth_stack.pop + push = depth_stack.append + check_type = filter_type or DOMNode + + while depth_stack: + node, iter_nodes = pop() + child_widget = next(iter_nodes, None) + if child_widget is None: + if isinstance(node, check_type): + yield node + else: + push((node, iter_nodes)) + push((child_widget, iter(child_widget.children))) + + if method == "depth": + yield from walk_depth_first() + else: + yield from walk_breadth_first() def get_child(self, id: str) -> DOMNode: """Return the first child (immediate descendent) of this node with the given ID. diff --git a/src/textual/widget.py b/src/textual/widget.py index 6f190ec09..23858b64e 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -1857,10 +1857,6 @@ class Widget(DOMNode): await self.action(binding.action) return True - def _on_compose(self, event: events.Compose) -> None: - widgets = self.compose() - self.app.mount_all(widgets) - def _on_mount(self, event: events.Mount) -> None: widgets = self.compose() self.mount(*widgets) diff --git a/tests/test_dom.py b/tests/test_dom.py index e4254f6e5..3a35768d3 100644 --- a/tests/test_dom.py +++ b/tests/test_dom.py @@ -75,3 +75,29 @@ def test_validate(): node.remove_class("1") with pytest.raises(BadIdentifier): node.toggle_class("1") + + +def test_walk_children(parent): + children = [node.id for node in parent.walk_children(with_self=False)] + assert children == ["child1", "grandchild1", "child2"] + + +def test_walk_children_with_self(parent): + children = [node.id for node in parent.walk_children(with_self=True)] + assert children == ["parent", "child1", "grandchild1", "child2"] + + +def test_walk_children_depth(parent): + children = [ + node.id for node in parent.walk_children(with_self=False, method="depth") + ] + print(children) + assert children == ["grandchild1", "child1", "child2"] + + +def test_walk_children_with_self_depth(parent): + children = [ + node.id for node in parent.walk_children(with_self=True, method="depth") + ] + print(children) + assert children == ["grandchild1", "child1", "child2", "parent"]