From 2b304cc9f79a9a3835f34bc356055bec5497fdd9 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Tue, 22 Feb 2022 11:35:01 +0000 Subject: [PATCH] Get child by ID --- src/textual/app.py | 12 ++++++------ src/textual/dom.py | 19 ++++++++++--------- tests/test_animator.py | 6 +++--- tests/test_dom.py | 14 +++++--------- 4 files changed, 24 insertions(+), 27 deletions(-) diff --git a/src/textual/app.py b/src/textual/app.py index b23e111ea..f8623e11f 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -279,18 +279,18 @@ class App(DOMNode): return DOMQuery(self.view, selector) - def get_child(self, selector: str) -> DOMNode: - """Shorthand for self.view.get_child(selector: str) + def get_child(self, id: str) -> DOMNode: + """Shorthand for self.view.get_child(id: str) Returns the first child (immediate descendent) of this DOMNode - matching the selector. + with the given ID. Args: - selector (str): A CSS selector. + id (str): The ID of the node to search for. Returns: - DOMNode: The first child of this node which matches the selector. + DOMNode: The first child of this node with the specified ID. """ - return self.view.get_child(selector) + return self.view.get_child(id) def update_styles(self) -> None: """Request update of styles. diff --git a/src/textual/dom.py b/src/textual/dom.py index c3434567b..9e9a35341 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -12,8 +12,9 @@ from ._node_list import NodeList from .css._error_tools import friendly_list from .css.constants import VALID_DISPLAY, VALID_VISIBILITY from .css.errors import StyleValueError -from .css.styles import Styles, RenderStyles from .css.parse import parse_declarations +from .css.styles import Styles, RenderStyles +from .css.query import NoMatchingNodesError from .message_pump import MessagePump if TYPE_CHECKING: @@ -282,19 +283,19 @@ class DOMNode(MessagePump): if node.children: push(iter(node.children)) - def get_child(self, selector: str) -> DOMNode: - """Return the first child (immediate descendent) of this DOMNode matching a selector. + def get_child(self, id: str) -> DOMNode: + """Return the first child (immediate descendent) of this node with the given ID. Args: - selector (str): A CSS selector. + id (str): The ID of the child. Returns: - DOMNode: The first child of this node which matches the selector. + DOMNode: The first child of this node with the ID. """ - from .css.query import DOMQuery - - query = DOMQuery(selector=selector, nodes=list(self.children)) - return query.first() + for child in self.children: + if child.id == id: + return child + raise NoMatchingNodesError(f"No child found with id={id!r}") def query(self, selector: str | None = None) -> DOMQuery: """Get a DOM query. diff --git a/tests/test_animator.py b/tests/test_animator.py index 1b7113692..6f9e500df 100644 --- a/tests/test_animator.py +++ b/tests/test_animator.py @@ -169,7 +169,7 @@ def test_animatable(): assert animate_test.bar.value == 50.0 -class TestAnimator(Animator): +class MockAnimator(Animator): """A mock animator.""" def __init__(self, *args) -> None: @@ -187,7 +187,7 @@ class TestAnimator(Animator): def test_animator(): target = Mock() - animator = TestAnimator(target) + animator = MockAnimator(target) animate_test = AnimateTest() # Animate attribute "foo" on animate_test to 100.0 in 10 seconds @@ -226,7 +226,7 @@ def test_animator(): def test_bound_animator(): target = Mock() - animator = TestAnimator(target) + animator = MockAnimator(target) animate_test = AnimateTest() # Bind an animator so it animates attributes on the given object diff --git a/tests/test_dom.py b/tests/test_dom.py index b8eabfcba..985b63fae 100644 --- a/tests/test_dom.py +++ b/tests/test_dom.py @@ -29,12 +29,8 @@ def test_display_set_invalid_value(): @pytest.fixture def parent(): parent = DOMNode(id="parent") - child1 = DOMNode(id="child1") - child1.add_class("foo") child2 = DOMNode(id="child2") - child2.add_class("bar") - grandchild1 = DOMNode(id="grandchild1") child1.add_child(grandchild1) @@ -45,17 +41,17 @@ def parent(): def test_get_child_gets_first_child(parent): - child = parent.get_child(".foo") + child = parent.get_child(id="child1") assert child.id == "child1" - assert child.get_child("#grandchild1").id == "grandchild1" - assert parent.get_child(".bar").id == "child2" + assert child.get_child(id="grandchild1").id == "grandchild1" + assert parent.get_child(id="child2").id == "child2" def test_get_child_no_matching_child(parent): with pytest.raises(NoMatchingNodesError): - parent.get_child("#doesnt-exist") + parent.get_child(id="doesnt-exist") def test_get_child_only_immediate_descendents(parent): with pytest.raises(NoMatchingNodesError): - parent.get_child("#grandchild1") + parent.get_child(id="grandchild1")