From 7eb5119fe0a6e05e7478d14fa3a96e630acd6a9b Mon Sep 17 00:00:00 2001 From: Will McGugan Date: Thu, 13 Oct 2022 20:51:03 +0100 Subject: [PATCH] real breadth first, and tests --- src/textual/app.py | 11 ++++--- src/textual/css/styles.py | 2 +- src/textual/dom.py | 57 ++++++++++++++++++---------------- tests/test_dom.py | 64 ++++++++++++++++++++++++++++++--------- 4 files changed, 89 insertions(+), 45 deletions(-) diff --git a/src/textual/app.py b/src/textual/app.py index 0b3789623..7cf010f3c 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -37,7 +37,7 @@ from .css.stylesheet import Stylesheet from .design import ColorSystem from .devtools.client import DevtoolsClient, DevtoolsConnectionError, DevtoolsLog from .devtools.redirect_output import StdoutRedirector -from .dom import DOMNode +from .dom import DOMNode, NoScreen from .driver import Driver from .drivers.headless_driver import HeadlessDriver from .features import FeatureFlag, parse_features @@ -1142,7 +1142,10 @@ class App(Generic[ReturnType], DOMNode): Args: widget (Widget): A Widget to unregister """ - widget.screen._reset_focus(widget) + try: + widget.screen._reset_focus(widget) + except NoScreen: + pass if isinstance(widget._parent, Widget): widget._parent.children._remove(widget) @@ -1394,8 +1397,8 @@ class App(Generic[ReturnType], DOMNode): if parent is not None: parent.refresh(layout=True) - remove_widgets = list( - widget.walk_children(Widget, with_self=True, method="depth") + remove_widgets = widget.walk_children( + Widget, with_self=True, method="depth", reverse=True ) for child in remove_widgets: self._unregister(child) diff --git a/src/textual/css/styles.py b/src/textual/css/styles.py index d59bea2f6..dd4b8d4b6 100644 --- a/src/textual/css/styles.py +++ b/src/textual/css/styles.py @@ -573,7 +573,7 @@ class Styles(StylesBase): if self.node is not None: self.node.refresh(layout=layout) if children: - for child in self.node.walk_children(with_self=False): + for child in self.node.walk_children(with_self=False, reverse=True): child.refresh(layout=layout) def reset(self) -> None: diff --git a/src/textual/dom.py b/src/textual/dom.py index 1d3712e38..caea2c1f0 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -632,13 +632,18 @@ class DOMNode(MessagePump): filter_type: type[WalkType], *, with_self: bool = True, - method: WalkMethod = "breadth", + method: WalkMethod = "depth", + reverse: bool = False, ) -> Iterable[WalkType]: ... @overload def walk_children( - self, *, with_self: bool = True, method: WalkMethod = "breadth" + self, + *, + with_self: bool = True, + method: WalkMethod = "depth", + reverse: bool = False, ) -> Iterable[DOMNode]: ... @@ -647,7 +652,8 @@ class DOMNode(MessagePump): filter_type: type[WalkType] | None = None, *, with_self: bool = True, - method: WalkMethod = "breadth", + method: WalkMethod = "depth", + reverse: bool = False, ) -> Iterable[DOMNode | WalkType]: """Generate descendant nodes. @@ -655,14 +661,15 @@ class DOMNode(MessagePump): filter_type (type[WalkType] | None, optional): Filter only this type, or None for no filter. Defaults to None. with_self (bool, optional): Also yield self in addition to descendants. Defaults to True. - method (Literal["breadth", "depth"], optional): One of "depth" or "breadth". Defaults to "breadth". + method (Literal["breadth", "depth"], optional): One of "depth" or "breadth". Defaults to "depth". + reverse (bool, optional): Reverse the order (bottom up). Defaults to False Returns: Iterable[DOMNode | WalkType]: An iterable of nodes. """ - def walk_breadth_first() -> Iterable[DOMNode]: + def walk_depth_first() -> Iterable[DOMNode]: """Walk the tree breadth first (parent's first).""" stack: list[Iterator[DOMNode]] = [iter(self.children)] pop = stack.pop @@ -682,31 +689,29 @@ class DOMNode(MessagePump): if node.children: push(iter(node.children)) - def walk_depth_first() -> Iterable[DOMNode]: + def walk_breadth_first() -> Iterable[DOMNode]: """Walk the tree depth first (children first).""" - depth_stack: list[tuple[DOMNode, Iterator[DOMNode]]] = ( - [(self, iter(self.children))] - if with_self - else [(node, iter(node.children)) for node in reversed(self.children)] - ) - pop = depth_stack.pop - push = depth_stack.append - check_type = filter_type or DOMNode + queue: deque[DOMNode] = deque() + popleft = queue.popleft + extend = queue.extend - while depth_stack: - node, iter_nodes = pop() - child_widget = next(iter_nodes, None) - if child_widget is None: - if isinstance(node, check_type): - yield node - else: - push((node, iter_nodes)) - push((child_widget, iter(child_widget.children))) + if with_self: + yield self + queue.extend(self.children) - if method == "depth": - yield from walk_depth_first() + while queue: + node = popleft() + yield node + extend(node.children) + + node_generator = ( + walk_depth_first() if method == "depth" else walk_breadth_first() + ) + + if reverse: + yield from reversed(list(node_generator)) else: - yield from walk_breadth_first() + yield from node_generator def get_child(self, id: str) -> DOMNode: """Return the first child (immediate descendent) of this node with the given ID. diff --git a/tests/test_dom.py b/tests/test_dom.py index 3a35768d3..5a713193a 100644 --- a/tests/test_dom.py +++ b/tests/test_dom.py @@ -77,27 +77,63 @@ def test_validate(): node.toggle_class("1") -def test_walk_children(parent): - children = [node.id for node in parent.walk_children(with_self=False)] - assert children == ["child1", "grandchild1", "child2"] +@pytest.fixture +def search(): + """ + a + / \ + b c + / / \ + d e f + """ + a = DOMNode(id="a") + b = DOMNode(id="b") + c = DOMNode(id="c") + d = DOMNode(id="d") + e = DOMNode(id="e") + f = DOMNode(id="f") + + a._add_child(b) + a._add_child(c) + b._add_child(d) + c._add_child(e) + c._add_child(f) + + yield a -def test_walk_children_with_self(parent): - children = [node.id for node in parent.walk_children(with_self=True)] - assert children == ["parent", "child1", "grandchild1", "child2"] - - -def test_walk_children_depth(parent): +def test_walk_children_depth(search): children = [ - node.id for node in parent.walk_children(with_self=False, method="depth") + node.id for node in search.walk_children(method="depth", with_self=False) + ] + assert children == ["b", "d", "c", "e", "f"] + + +def test_walk_children_with_self_depth(search): + children = [ + node.id for node in search.walk_children(method="depth", with_self=True) + ] + assert children == ["a", "b", "d", "c", "e", "f"] + + +def test_walk_children_breadth(search): + children = [ + node.id for node in search.walk_children(with_self=False, method="breadth") ] print(children) - assert children == ["grandchild1", "child1", "child2"] + assert children == ["b", "c", "d", "e", "f"] -def test_walk_children_with_self_depth(parent): +def test_walk_children_with_self_breadth(search): children = [ - node.id for node in parent.walk_children(with_self=True, method="depth") + node.id for node in search.walk_children(with_self=True, method="breadth") ] print(children) - assert children == ["grandchild1", "child1", "child2", "parent"] + assert children == ["a", "b", "c", "d", "e", "f"] + + children = [ + node.id + for node in search.walk_children(with_self=True, method="breadth", reverse=True) + ] + + assert children == ["f", "e", "d", "c", "b", "a"]