real breadth first, and tests

This commit is contained in:
Will McGugan
2022-10-13 20:51:03 +01:00
parent 5a8e492294
commit 7eb5119fe0
4 changed files with 89 additions and 45 deletions

View File

@@ -37,7 +37,7 @@ from .css.stylesheet import Stylesheet
from .design import ColorSystem from .design import ColorSystem
from .devtools.client import DevtoolsClient, DevtoolsConnectionError, DevtoolsLog from .devtools.client import DevtoolsClient, DevtoolsConnectionError, DevtoolsLog
from .devtools.redirect_output import StdoutRedirector from .devtools.redirect_output import StdoutRedirector
from .dom import DOMNode from .dom import DOMNode, NoScreen
from .driver import Driver from .driver import Driver
from .drivers.headless_driver import HeadlessDriver from .drivers.headless_driver import HeadlessDriver
from .features import FeatureFlag, parse_features from .features import FeatureFlag, parse_features
@@ -1142,7 +1142,10 @@ class App(Generic[ReturnType], DOMNode):
Args: Args:
widget (Widget): A Widget to unregister widget (Widget): A Widget to unregister
""" """
widget.screen._reset_focus(widget) try:
widget.screen._reset_focus(widget)
except NoScreen:
pass
if isinstance(widget._parent, Widget): if isinstance(widget._parent, Widget):
widget._parent.children._remove(widget) widget._parent.children._remove(widget)
@@ -1394,8 +1397,8 @@ class App(Generic[ReturnType], DOMNode):
if parent is not None: if parent is not None:
parent.refresh(layout=True) parent.refresh(layout=True)
remove_widgets = list( remove_widgets = widget.walk_children(
widget.walk_children(Widget, with_self=True, method="depth") Widget, with_self=True, method="depth", reverse=True
) )
for child in remove_widgets: for child in remove_widgets:
self._unregister(child) self._unregister(child)

View File

@@ -573,7 +573,7 @@ class Styles(StylesBase):
if self.node is not None: if self.node is not None:
self.node.refresh(layout=layout) self.node.refresh(layout=layout)
if children: if children:
for child in self.node.walk_children(with_self=False): for child in self.node.walk_children(with_self=False, reverse=True):
child.refresh(layout=layout) child.refresh(layout=layout)
def reset(self) -> None: def reset(self) -> None:

View File

@@ -632,13 +632,18 @@ class DOMNode(MessagePump):
filter_type: type[WalkType], filter_type: type[WalkType],
*, *,
with_self: bool = True, with_self: bool = True,
method: WalkMethod = "breadth", method: WalkMethod = "depth",
reverse: bool = False,
) -> Iterable[WalkType]: ) -> Iterable[WalkType]:
... ...
@overload @overload
def walk_children( def walk_children(
self, *, with_self: bool = True, method: WalkMethod = "breadth" self,
*,
with_self: bool = True,
method: WalkMethod = "depth",
reverse: bool = False,
) -> Iterable[DOMNode]: ) -> Iterable[DOMNode]:
... ...
@@ -647,7 +652,8 @@ class DOMNode(MessagePump):
filter_type: type[WalkType] | None = None, filter_type: type[WalkType] | None = None,
*, *,
with_self: bool = True, with_self: bool = True,
method: WalkMethod = "breadth", method: WalkMethod = "depth",
reverse: bool = False,
) -> Iterable[DOMNode | WalkType]: ) -> Iterable[DOMNode | WalkType]:
"""Generate descendant nodes. """Generate descendant nodes.
@@ -655,14 +661,15 @@ class DOMNode(MessagePump):
filter_type (type[WalkType] | None, optional): Filter only this type, or None for no filter. filter_type (type[WalkType] | None, optional): Filter only this type, or None for no filter.
Defaults to None. Defaults to None.
with_self (bool, optional): Also yield self in addition to descendants. Defaults to True. with_self (bool, optional): Also yield self in addition to descendants. Defaults to True.
method (Literal["breadth", "depth"], optional): One of "depth" or "breadth". Defaults to "breadth". method (Literal["breadth", "depth"], optional): One of "depth" or "breadth". Defaults to "depth".
reverse (bool, optional): Reverse the order (bottom up). Defaults to False
Returns: Returns:
Iterable[DOMNode | WalkType]: An iterable of nodes. Iterable[DOMNode | WalkType]: An iterable of nodes.
""" """
def walk_breadth_first() -> Iterable[DOMNode]: def walk_depth_first() -> Iterable[DOMNode]:
"""Walk the tree breadth first (parent's first).""" """Walk the tree breadth first (parent's first)."""
stack: list[Iterator[DOMNode]] = [iter(self.children)] stack: list[Iterator[DOMNode]] = [iter(self.children)]
pop = stack.pop pop = stack.pop
@@ -682,31 +689,29 @@ class DOMNode(MessagePump):
if node.children: if node.children:
push(iter(node.children)) push(iter(node.children))
def walk_depth_first() -> Iterable[DOMNode]: def walk_breadth_first() -> Iterable[DOMNode]:
"""Walk the tree depth first (children first).""" """Walk the tree depth first (children first)."""
depth_stack: list[tuple[DOMNode, Iterator[DOMNode]]] = ( queue: deque[DOMNode] = deque()
[(self, iter(self.children))] popleft = queue.popleft
if with_self extend = queue.extend
else [(node, iter(node.children)) for node in reversed(self.children)]
)
pop = depth_stack.pop
push = depth_stack.append
check_type = filter_type or DOMNode
while depth_stack: if with_self:
node, iter_nodes = pop() yield self
child_widget = next(iter_nodes, None) queue.extend(self.children)
if child_widget is None:
if isinstance(node, check_type):
yield node
else:
push((node, iter_nodes))
push((child_widget, iter(child_widget.children)))
if method == "depth": while queue:
yield from walk_depth_first() node = popleft()
yield node
extend(node.children)
node_generator = (
walk_depth_first() if method == "depth" else walk_breadth_first()
)
if reverse:
yield from reversed(list(node_generator))
else: else:
yield from walk_breadth_first() yield from node_generator
def get_child(self, id: str) -> DOMNode: def get_child(self, id: str) -> DOMNode:
"""Return the first child (immediate descendent) of this node with the given ID. """Return the first child (immediate descendent) of this node with the given ID.

View File

@@ -77,27 +77,63 @@ def test_validate():
node.toggle_class("1") node.toggle_class("1")
def test_walk_children(parent): @pytest.fixture
children = [node.id for node in parent.walk_children(with_self=False)] def search():
assert children == ["child1", "grandchild1", "child2"] """
a
/ \
b c
/ / \
d e f
"""
a = DOMNode(id="a")
b = DOMNode(id="b")
c = DOMNode(id="c")
d = DOMNode(id="d")
e = DOMNode(id="e")
f = DOMNode(id="f")
a._add_child(b)
a._add_child(c)
b._add_child(d)
c._add_child(e)
c._add_child(f)
yield a
def test_walk_children_with_self(parent): def test_walk_children_depth(search):
children = [node.id for node in parent.walk_children(with_self=True)]
assert children == ["parent", "child1", "grandchild1", "child2"]
def test_walk_children_depth(parent):
children = [ children = [
node.id for node in parent.walk_children(with_self=False, method="depth") node.id for node in search.walk_children(method="depth", with_self=False)
]
assert children == ["b", "d", "c", "e", "f"]
def test_walk_children_with_self_depth(search):
children = [
node.id for node in search.walk_children(method="depth", with_self=True)
]
assert children == ["a", "b", "d", "c", "e", "f"]
def test_walk_children_breadth(search):
children = [
node.id for node in search.walk_children(with_self=False, method="breadth")
] ]
print(children) print(children)
assert children == ["grandchild1", "child1", "child2"] assert children == ["b", "c", "d", "e", "f"]
def test_walk_children_with_self_depth(parent): def test_walk_children_with_self_breadth(search):
children = [ children = [
node.id for node in parent.walk_children(with_self=True, method="depth") node.id for node in search.walk_children(with_self=True, method="breadth")
] ]
print(children) print(children)
assert children == ["grandchild1", "child1", "child2", "parent"] assert children == ["a", "b", "c", "d", "e", "f"]
children = [
node.id
for node in search.walk_children(with_self=True, method="breadth", reverse=True)
]
assert children == ["f", "e", "d", "c", "b", "a"]