added results method

This commit is contained in:
Will McGugan
2022-08-13 12:25:35 +01:00
parent 77747bac56
commit 4ac4f3cee6
2 changed files with 35 additions and 2 deletions

View File

@@ -224,6 +224,33 @@ class DOMQuery:
else: else:
raise NoMatchingNodesError(f"No nodes match {self!r}") raise NoMatchingNodesError(f"No nodes match {self!r}")
@overload
def results(self) -> Iterator[Widget]:
...
@overload
def results(self, filter_type: type[ExpectType]) -> Iterator[ExpectType]:
...
def results(
self, filter_type: type[ExpectType] | None = None
) -> Iterator[Widget | ExpectType]:
"""Get query results, optionally filtered by a given type.
Args:
filter_type (type[ExpectType] | None): A Widget class to filter results,
or None for no filter. Defaults to None.
Yields:
Iterator[Widget | ExpectType]: An iterator of Widget instances.
"""
if filter_type is None:
yield from self
else:
for node in self:
if isinstance(node, filter_type):
yield node
def add_class(self, *class_names: str) -> DOMQuery: def add_class(self, *class_names: str) -> DOMQuery:
"""Add the given class name(s) to nodes.""" """Add the given class name(s) to nodes."""
for node in self: for node in self:

View File

@@ -1,5 +1,3 @@
from textual.dom import DOMNode
from textual.widget import Widget from textual.widget import Widget
@@ -51,7 +49,15 @@ def test_query():
assert list(app.query("View#main")) == [main_view] assert list(app.query("View#main")) == [main_view]
assert list(app.query("#widget1")) == [widget1] assert list(app.query("#widget1")) == [widget1]
assert list(app.query("#widget2")) == [widget2] assert list(app.query("#widget2")) == [widget2]
assert list(app.query("Widget.float")) == [sidebar, tooltip, helpbar] assert list(app.query("Widget.float")) == [sidebar, tooltip, helpbar]
assert list(app.query("Widget.float").results(Widget)) == [
sidebar,
tooltip,
helpbar,
]
assert list(app.query("Widget.float").results(View)) == []
assert list(app.query("Widget.float.transient")) == [tooltip] assert list(app.query("Widget.float.transient")) == [tooltip]
assert list(app.query("App > View")) == [main_view, help_view] assert list(app.query("App > View")) == [main_view, help_view]