diff --git a/CHANGELOG.md b/CHANGELOG.md index dd2fdb331..c300711d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,9 +7,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [0.5.0] - Unreleased - ### Added +- Add get_child_by_id and get_widget_by_id, remove get_child https://github.com/Textualize/textual/pull/1146 - Add easing parameter to Widget.scroll_* methods https://github.com/Textualize/textual/pull/1144 - Added Widget.call_later which invokes a callback on idle. - `DOMNode.ancestors` no longer includes `self`. diff --git a/src/textual/_node_list.py b/src/textual/_node_list.py index 5a9a31486..d48597090 100644 --- a/src/textual/_node_list.py +++ b/src/textual/_node_list.py @@ -8,6 +8,10 @@ if TYPE_CHECKING: from .widget import Widget +class DuplicateIds(Exception): + pass + + @rich.repr.auto(angular=True) class NodeList(Sequence): """ @@ -21,6 +25,12 @@ class NodeList(Sequence): # The nodes in the list self._nodes: list[Widget] = [] self._nodes_set: set[Widget] = set() + + # We cache widgets by their IDs too for a quick lookup + # Note that only widgets with IDs are cached like this, so + # this cache will likely hold fewer values than self._nodes. + self._nodes_by_id: dict[str, Widget] = {} + # Increments when list is updated (used for caching) self._updates = 0 @@ -53,6 +63,10 @@ class NodeList(Sequence): """ return self._nodes.index(widget) + def _get_by_id(self, widget_id: str) -> Widget | None: + """Get the widget for the given widget_id, or None if there's no matches in this list""" + return self._nodes_by_id.get(widget_id) + def _append(self, widget: Widget) -> None: """Append a Widget. @@ -62,6 +76,10 @@ class NodeList(Sequence): if widget not in self._nodes_set: self._nodes.append(widget) self._nodes_set.add(widget) + widget_id = widget.id + if widget_id is not None: + self._ensure_unique_id(widget_id) + self._nodes_by_id[widget_id] = widget self._updates += 1 def _insert(self, index: int, widget: Widget) -> None: @@ -73,8 +91,20 @@ class NodeList(Sequence): if widget not in self._nodes_set: self._nodes.insert(index, widget) self._nodes_set.add(widget) + widget_id = widget.id + if widget_id is not None: + self._ensure_unique_id(widget_id) + self._nodes_by_id[widget_id] = widget self._updates += 1 + def _ensure_unique_id(self, widget_id: str) -> None: + if widget_id in self._nodes_by_id: + raise DuplicateIds( + f"Tried to insert a widget with ID {widget_id!r}, but a widget {self._nodes_by_id[widget_id]!r} " + f"already exists with that ID in this list of children. " + f"The children of a widget must have unique IDs." + ) + def _remove(self, widget: Widget) -> None: """Remove a widget from the list. @@ -86,6 +116,9 @@ class NodeList(Sequence): if widget in self._nodes_set: del self._nodes[self._nodes.index(widget)] self._nodes_set.remove(widget) + widget_id = widget.id + if widget_id in self._nodes_by_id: + del self._nodes_by_id[widget_id] self._updates += 1 def _clear(self) -> None: @@ -93,6 +126,7 @@ class NodeList(Sequence): if self._nodes: self._nodes.clear() self._nodes_set.clear() + self._nodes_by_id.clear() self._updates += 1 def __iter__(self) -> Iterator[Widget]: diff --git a/src/textual/app.py b/src/textual/app.py index 4328854ba..b5526f73c 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -60,7 +60,7 @@ from .messages import CallbackType from .reactive import Reactive from .renderables.blank import Blank from .screen import Screen -from .widget import AwaitMount, Widget +from .widget import AwaitMount, Widget, MountError if TYPE_CHECKING: from .devtools.client import DevtoolsClient @@ -873,7 +873,7 @@ class App(Generic[ReturnType], DOMNode): def render(self) -> RenderableType: return Blank(self.styles.background) - def get_child(self, id: str) -> DOMNode: + def get_child_by_id(self, id: str) -> Widget: """Shorthand for self.screen.get_child(id: str) Returns the first child (immediate descendent) of this DOMNode with the given ID. @@ -887,7 +887,26 @@ class App(Generic[ReturnType], DOMNode): Raises: NoMatches: if no children could be found for this ID """ - return self.screen.get_child(id) + return self.screen.get_child_by_id(id) + + def get_widget_by_id(self, id: str) -> Widget: + """Shorthand for self.screen.get_widget_by_id(id) + Return the first descendant widget with the given ID. + + Performs a breadth-first search rooted at the current screen. + It will not return the Screen if that matches the ID. + To get the screen, use `self.screen`. + + Args: + id (str): The ID to search for in the subtree + + Returns: + DOMNode: The first descendant encountered with this ID. + + Raises: + NoMatches: if no children could be found for this ID + """ + return self.screen.get_widget_by_id(id) def update_styles(self, node: DOMNode | None = None) -> None: """Request update of styles. diff --git a/src/textual/dom.py b/src/textual/dom.py index 0373f9e04..4d5509804 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -28,7 +28,6 @@ from .css._error_tools import friendly_list from .css.constants import VALID_DISPLAY, VALID_VISIBILITY from .css.errors import DeclarationError, StyleValueError from .css.parse import parse_declarations -from .css.query import NoMatches from .css.styles import RenderStyles, Styles from .css.tokenize import IDENTIFIER from .message_pump import MessagePump @@ -645,7 +644,6 @@ class DOMNode(MessagePump): list[DOMNode] | list[WalkType]: A list of nodes. """ - check_type = filter_type or DOMNode node_generator = ( @@ -661,23 +659,6 @@ class DOMNode(MessagePump): nodes.reverse() return cast("list[DOMNode]", nodes) - def get_child(self, id: str) -> DOMNode: - """Return the first child (immediate descendent) of this node with the given ID. - - Args: - id (str): The ID of the child. - - Returns: - DOMNode: The first child of this node with the ID. - - Raises: - NoMatches: if no children could be found for this ID - """ - for child in self.children: - if child.id == id: - return child - raise NoMatches(f"No child found with id={id!r}") - ExpectType = TypeVar("ExpectType", bound="Widget") @overload diff --git a/src/textual/walk.py b/src/textual/walk.py index 1126d9d30..aa2e3467e 100644 --- a/src/textual/walk.py +++ b/src/textual/walk.py @@ -4,7 +4,7 @@ from collections import deque from typing import Iterable, Iterator, TypeVar, overload, TYPE_CHECKING if TYPE_CHECKING: - from .dom import DOMNode + from textual.dom import DOMNode WalkType = TypeVar("WalkType", bound=DOMNode) @@ -51,6 +51,8 @@ def walk_depth_first( Iterable[DOMNode] | Iterable[WalkType]: An iterable of DOMNodes, or the type specified in ``filter_type``. """ + from textual.dom import DOMNode + stack: list[Iterator[DOMNode]] = [iter(root.children)] pop = stack.pop push = stack.append @@ -111,6 +113,8 @@ def walk_breadth_first( Iterable[DOMNode] | Iterable[WalkType]: An iterable of DOMNodes, or the type specified in ``filter_type``. """ + from textual.dom import DOMNode + queue: deque[DOMNode] = deque() popleft = queue.popleft extend = queue.extend diff --git a/src/textual/widget.py b/src/textual/widget.py index 2de261227..8d993453c 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import Counter from asyncio import Lock, wait, create_task, Event as AsyncEvent from fractions import Fraction from itertools import islice @@ -41,6 +42,7 @@ from ._styles_cache import StylesCache from ._types import Lines from .binding import NoBinding from .box_model import BoxModel, get_box_model +from .css.query import NoMatches from .css.scalar import ScalarOffset from .dom import DOMNode, NoScreen from .geometry import Offset, Region, Size, Spacing, clamp @@ -50,6 +52,7 @@ from .messages import CallbackType from .reactive import Reactive from .render import measure from .await_remove import AwaitRemove +from .walk import walk_depth_first if TYPE_CHECKING: from .app import App, ComposeResult @@ -334,6 +337,43 @@ class Widget(DOMNode): def offset(self, offset: Offset) -> None: self.styles.offset = ScalarOffset.from_offset(offset) + def get_child_by_id(self, id: str) -> Widget: + """Return the first child (immediate descendent) of this node with the given ID. + + Args: + id (str): The ID of the child. + + Returns: + DOMNode: The first child of this node with the ID. + + Raises: + NoMatches: if no children could be found for this ID + """ + child = self.children._get_by_id(id) + if child is not None: + return child + raise NoMatches(f"No child found with id={id!r}") + + def get_widget_by_id(self, id: str) -> Widget: + """Return the first descendant widget with the given ID. + Performs a depth-first search rooted at this widget. + + Args: + id (str): The ID to search for in the subtree + + Returns: + DOMNode: The first descendant encountered with this ID. + + Raises: + NoMatches: if no children could be found for this ID + """ + for child in walk_depth_first(self): + try: + return child.get_child_by_id(id) + except NoMatches: + pass + raise NoMatches(f"No descendant found with id={id!r}") + def get_component_rich_style(self, name: str) -> Style: """Get a *Rich* style for a component. @@ -461,6 +501,20 @@ class Widget(DOMNode): provided a ``MountError`` will be raised. """ + # Check for duplicate IDs in the incoming widgets + ids_to_mount = [widget.id for widget in widgets if widget.id is not None] + unique_ids = set(ids_to_mount) + num_unique_ids = len(unique_ids) + num_widgets_with_ids = len(ids_to_mount) + if num_unique_ids != num_widgets_with_ids: + counter = Counter(widget.id for widget in widgets) + for widget_id, count in counter.items(): + if count > 1: + raise MountError( + f"Tried to insert {count!r} widgets with the same ID {widget_id!r}. " + f"Widget IDs must be unique." + ) + # Saying you want to mount before *and* after something is an error. if before is not None and after is not None: raise MountError( diff --git a/tests/test_dom.py b/tests/test_dom.py index 5a713193a..e925c5c14 100644 --- a/tests/test_dom.py +++ b/tests/test_dom.py @@ -1,7 +1,6 @@ import pytest from textual.css.errors import StyleValueError -from textual.css.query import NoMatches from textual.dom import DOMNode, BadIdentifier @@ -26,37 +25,6 @@ def test_display_set_invalid_value(): node.display = "blah" -@pytest.fixture -def parent(): - parent = DOMNode(id="parent") - child1 = DOMNode(id="child1") - child2 = DOMNode(id="child2") - grandchild1 = DOMNode(id="grandchild1") - child1._add_child(grandchild1) - - parent._add_child(child1) - parent._add_child(child2) - - yield parent - - -def test_get_child_gets_first_child(parent): - child = parent.get_child(id="child1") - assert child.id == "child1" - assert child.get_child(id="grandchild1").id == "grandchild1" - assert parent.get_child(id="child2").id == "child2" - - -def test_get_child_no_matching_child(parent): - with pytest.raises(NoMatches): - parent.get_child(id="doesnt-exist") - - -def test_get_child_only_immediate_descendents(parent): - with pytest.raises(NoMatches): - parent.get_child(id="grandchild1") - - def test_validate(): with pytest.raises(BadIdentifier): DOMNode(id="23") diff --git a/tests/test_unmount.py b/tests/test_unmount.py index 3e6a5ed0f..4611ff7d9 100644 --- a/tests/test_unmount.py +++ b/tests/test_unmount.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from textual.app import App, ComposeResult from textual import events from textual.containers import Container diff --git a/tests/test_widget.py b/tests/test_widget.py index 9c81c3fe4..e3c0a618c 100644 --- a/tests/test_widget.py +++ b/tests/test_widget.py @@ -1,9 +1,13 @@ import pytest +import rich -from textual.app import App +from textual._node_list import DuplicateIds +from textual.app import App, ComposeResult from textual.css.errors import StyleValueError +from textual.css.query import NoMatches +from textual.dom import DOMNode from textual.geometry import Size -from textual.widget import Widget +from textual.widget import Widget, MountError @pytest.mark.parametrize( @@ -64,3 +68,92 @@ def test_widget_content_width(): height = widget3.get_content_height(Size(20, 20), Size(80, 24), width) assert width == 3 assert height == 3 + + +class GetByIdApp(App): + def compose(self) -> ComposeResult: + grandchild1 = Widget(id="grandchild1") + child1 = Widget(grandchild1, id="child1") + child2 = Widget(id="child2") + + yield Widget( + child1, + child2, + id="parent", + ) + + +@pytest.fixture +async def hierarchy_app(): + app = GetByIdApp() + async with app.run_test(): + yield app + + +@pytest.fixture +async def parent(hierarchy_app): + yield hierarchy_app.get_widget_by_id("parent") + + +def test_get_child_by_id_gets_first_child(parent): + child = parent.get_child_by_id(id="child1") + assert child.id == "child1" + assert child.get_child_by_id(id="grandchild1").id == "grandchild1" + assert parent.get_child_by_id(id="child2").id == "child2" + + +def test_get_child_by_id_no_matching_child(parent): + with pytest.raises(NoMatches): + parent.get_child_by_id(id="doesnt-exist") + + +def test_get_child_by_id_only_immediate_descendents(parent): + with pytest.raises(NoMatches): + parent.get_child_by_id(id="grandchild1") + + +def test_get_widget_by_id_no_matching_child(parent): + with pytest.raises(NoMatches): + parent.get_widget_by_id(id="i-dont-exist") + + +def test_get_widget_by_id_non_immediate_descendants(parent): + result = parent.get_widget_by_id("grandchild1") + assert result.id == "grandchild1" + + +def test_get_widget_by_id_immediate_descendants(parent): + result = parent.get_widget_by_id("child1") + assert result.id == "child1" + + +def test_get_widget_by_id_doesnt_return_self(parent): + with pytest.raises(NoMatches): + parent.get_widget_by_id("parent") + + +def test_get_widgets_app_delegated(hierarchy_app, parent): + # Check that get_child_by_id finds the parent, which is a child of the default Screen + queried_parent = hierarchy_app.get_child_by_id("parent") + assert queried_parent is parent + + # Check that the grandchild (descendant of the default screen) is found + grandchild = hierarchy_app.get_widget_by_id("grandchild1") + assert grandchild.id == "grandchild1" + + +def test_widget_mount_ids_must_be_unique_mounting_all_in_one_go(parent): + widget1 = Widget(id="hello") + widget2 = Widget(id="hello") + + with pytest.raises(MountError): + parent.mount(widget1, widget2) + + +def test_widget_mount_ids_must_be_unique_mounting_multiple_calls(parent): + widget1 = Widget(id="hello") + widget2 = Widget(id="hello") + + parent.mount(widget1) + with pytest.raises(DuplicateIds): + parent.mount(widget2) diff --git a/tests/test_widget_removing.py b/tests/test_widget_removing.py index 341866a9f..6371cf9d3 100644 --- a/tests/test_widget_removing.py +++ b/tests/test_widget_removing.py @@ -15,18 +15,18 @@ async def test_remove_single_widget(): async def test_many_remove_all_widgets(): """It should be possible to remove all widgets on a multi-widget screen.""" async with App().run_test() as pilot: - await pilot.app.mount(*[Static() for _ in range(1000)]) - assert len(pilot.app.screen.children) == 1000 + await pilot.app.mount(*[Static() for _ in range(10)]) + assert len(pilot.app.screen.children) == 10 await pilot.app.query(Static).remove() assert len(pilot.app.screen.children) == 0 async def test_many_remove_some_widgets(): """It should be possible to remove some widgets on a multi-widget screen.""" async with App().run_test() as pilot: - await pilot.app.mount(*[Static(id=f"is-{n%2}") for n in range(1000)]) - assert len(pilot.app.screen.children) == 1000 - await pilot.app.query("#is-0").remove() - assert len(pilot.app.screen.children) == 500 + await pilot.app.mount(*[Static(classes=f"is-{n%2}") for n in range(10)]) + assert len(pilot.app.screen.children) == 10 + await pilot.app.query(".is-0").remove() + assert len(pilot.app.screen.children) == 5 async def test_remove_branch(): """It should be possible to remove a whole branch in the DOM."""