Query dom - get child

This commit is contained in:
Darren Burns
2022-02-22 11:19:00 +00:00
parent d0b1ca5283
commit 154ada427f
6 changed files with 76 additions and 22 deletions

View File

@@ -1,7 +1,7 @@
from textual.app import App
from textual import events from textual import events
from textual.widgets import Placeholder from textual.app import App
from textual.widget import Widget from textual.widget import Widget
from textual.widgets import Placeholder
class BasicApp(App): class BasicApp(App):
@@ -20,14 +20,17 @@ class BasicApp(App):
await self.dispatch_key(event) await self.dispatch_key(event)
def key_a(self) -> None: def key_a(self) -> None:
self.query("#footer").set_styles(text="on magenta") footer = self.get_child("#footer")
footer.set_styles(text="on magenta")
def key_b(self) -> None: def key_b(self) -> None:
self["#footer"].set_styles("text: on green") footer = self.get_child("#footer")
footer.set_styles("text: on green")
def key_c(self) -> None: def key_c(self) -> None:
self["#header"].toggle_class("-highlight") header = self.get_child("#header")
self.log(self["#header"].styles) header.toggle_class("-highlight")
self.log(header.styles)
BasicApp.run(css_file="local_styles.css", log="textual.log") BasicApp.run(css_file="local_styles.css", log="textual.log")

View File

@@ -36,7 +36,7 @@ from .reactive import Reactive
from .view import View from .view import View
from .widget import Widget from .widget import Widget
from .css.query import EmptyQueryError from .css.query import NoMatchingNodesError
if TYPE_CHECKING: if TYPE_CHECKING:
from .css.query import DOMQuery from .css.query import DOMQuery
@@ -279,13 +279,18 @@ class App(DOMNode):
return DOMQuery(self.view, selector) return DOMQuery(self.view, selector)
def __getitem__(self, selector: str) -> DOMNode: def get_child(self, selector: str) -> DOMNode:
from .css.query import DOMQuery """Shorthand for self.view.get_child(selector: str)
Returns the first child (immediate descendent) of this DOMNode
matching the selector.
try: Args:
return DOMQuery(self.view, selector).first() selector (str): A CSS selector.
except EmptyQueryError:
raise KeyError(selector) Returns:
DOMNode: The first child of this node which matches the selector.
"""
return self.view.get_child(selector)
def update_styles(self) -> None: def update_styles(self) -> None:
"""Request update of styles. """Request update of styles.

View File

@@ -26,7 +26,7 @@ if TYPE_CHECKING:
from ..dom import DOMNode from ..dom import DOMNode
class EmptyQueryError(Exception): class NoMatchingNodesError(Exception):
pass pass
@@ -38,7 +38,7 @@ class DOMQuery:
selector: str | None = None, selector: str | None = None,
nodes: list[DOMNode] | None = None, nodes: list[DOMNode] | None = None,
) -> None: ) -> None:
self._selector = selector
self._nodes: list[DOMNode] = [] self._nodes: list[DOMNode] = []
if nodes is not None: if nodes is not None:
self._nodes = nodes self._nodes = nodes
@@ -103,7 +103,9 @@ class DOMQuery:
if self._nodes: if self._nodes:
return self._nodes[0] return self._nodes[0]
else: else:
raise EmptyQueryError("Query is empty") raise NoMatchingNodesError(
f"No nodes match the selector {self._selector!r}"
)
def add_class(self, *class_names: str) -> DOMQuery: def add_class(self, *class_names: str) -> DOMQuery:
"""Add the given class name(s) to nodes.""" """Add the given class name(s) to nodes."""

View File

@@ -282,6 +282,20 @@ class DOMNode(MessagePump):
if node.children: if node.children:
push(iter(node.children)) push(iter(node.children))
def get_child(self, selector: str) -> DOMNode:
"""Return the first child (immediate descendent) of this DOMNode matching a selector.
Args:
selector (str): A CSS selector.
Returns:
DOMNode: The first child of this node which matches the selector.
"""
from .css.query import DOMQuery
query = DOMQuery(selector=selector, nodes=list(self.children))
return query.first()
def query(self, selector: str | None = None) -> DOMQuery: def query(self, selector: str | None = None) -> DOMQuery:
"""Get a DOM query. """Get a DOM query.

View File

@@ -69,12 +69,6 @@ class View(Widget):
def __rich_repr__(self) -> rich.repr.Result: def __rich_repr__(self) -> rich.repr.Result:
yield "name", self.name yield "name", self.name
def __getitem__(self, widget_id: str) -> Widget:
try:
return self.get_child_by_id(widget_id)
except errors.MissingWidget as error:
raise KeyError(str(error))
@property @property
def is_visual(self) -> bool: def is_visual(self) -> bool:
return False return False

View File

@@ -1,6 +1,7 @@
import pytest import pytest
from textual.css.errors import StyleValueError from textual.css.errors import StyleValueError
from textual.css.query import NoMatchingNodesError
from textual.dom import DOMNode from textual.dom import DOMNode
@@ -23,3 +24,38 @@ def test_display_set_invalid_value():
node = DOMNode() node = DOMNode()
with pytest.raises(StyleValueError): with pytest.raises(StyleValueError):
node.display = "blah" node.display = "blah"
@pytest.fixture
def parent():
parent = DOMNode(id="parent")
child1 = DOMNode(id="child1")
child1.add_class("foo")
child2 = DOMNode(id="child2")
child2.add_class("bar")
grandchild1 = DOMNode(id="grandchild1")
child1.add_child(grandchild1)
parent.add_child(child1)
parent.add_child(child2)
yield parent
def test_get_child_gets_first_child(parent):
child = parent.get_child(".foo")
assert child.id == "child1"
assert child.get_child("#grandchild1").id == "grandchild1"
assert parent.get_child(".bar").id == "child2"
def test_get_child_no_matching_child(parent):
with pytest.raises(NoMatchingNodesError):
parent.get_child("#doesnt-exist")
def test_get_child_only_immediate_descendents(parent):
with pytest.raises(NoMatchingNodesError):
parent.get_child("#grandchild1")