Add get_child_by_id and get_widget_by_id (#1146)

* Add get_child_by_id and get_widget_by_id

* Remove redundant code

* Add unit tests for app-level get_child_by_id and get_widget_by_id

* Remove redundant test fixture injection

* Update CHANGELOG

* Enforce uniqueness of ID amongst widget children

* Enforce unique widget IDs amongst widgets mounted together

* Update CHANGELOG.md

* Ensuring unique IDs in a more logical place

* Add docstring to NodeList._get_by_id

* Dont use duplicate IDs in tests, dont mount 2000 widgets

* Mounting less widgets in a unit test

* Reword error message

* Use lower-level depth first search in get_widget_by_id to break out early
This commit is contained in:
darrenburns
2022-11-16 15:29:59 +00:00
committed by GitHub
parent a465f5c236
commit df37a9b90a
10 changed files with 219 additions and 64 deletions

View File

@@ -7,9 +7,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
## [0.5.0] - Unreleased
### Added
- Add get_child_by_id and get_widget_by_id, remove get_child https://github.com/Textualize/textual/pull/1146
- Add easing parameter to Widget.scroll_* methods https://github.com/Textualize/textual/pull/1144
- Added Widget.call_later which invokes a callback on idle.
- `DOMNode.ancestors` no longer includes `self`.

View File

@@ -8,6 +8,10 @@ if TYPE_CHECKING:
from .widget import Widget
class DuplicateIds(Exception):
pass
@rich.repr.auto(angular=True)
class NodeList(Sequence):
"""
@@ -21,6 +25,12 @@ class NodeList(Sequence):
# The nodes in the list
self._nodes: list[Widget] = []
self._nodes_set: set[Widget] = set()
# We cache widgets by their IDs too for a quick lookup
# Note that only widgets with IDs are cached like this, so
# this cache will likely hold fewer values than self._nodes.
self._nodes_by_id: dict[str, Widget] = {}
# Increments when list is updated (used for caching)
self._updates = 0
@@ -53,6 +63,10 @@ class NodeList(Sequence):
"""
return self._nodes.index(widget)
def _get_by_id(self, widget_id: str) -> Widget | None:
"""Get the widget for the given widget_id, or None if there's no matches in this list"""
return self._nodes_by_id.get(widget_id)
def _append(self, widget: Widget) -> None:
"""Append a Widget.
@@ -62,6 +76,10 @@ class NodeList(Sequence):
if widget not in self._nodes_set:
self._nodes.append(widget)
self._nodes_set.add(widget)
widget_id = widget.id
if widget_id is not None:
self._ensure_unique_id(widget_id)
self._nodes_by_id[widget_id] = widget
self._updates += 1
def _insert(self, index: int, widget: Widget) -> None:
@@ -73,8 +91,20 @@ class NodeList(Sequence):
if widget not in self._nodes_set:
self._nodes.insert(index, widget)
self._nodes_set.add(widget)
widget_id = widget.id
if widget_id is not None:
self._ensure_unique_id(widget_id)
self._nodes_by_id[widget_id] = widget
self._updates += 1
def _ensure_unique_id(self, widget_id: str) -> None:
if widget_id in self._nodes_by_id:
raise DuplicateIds(
f"Tried to insert a widget with ID {widget_id!r}, but a widget {self._nodes_by_id[widget_id]!r} "
f"already exists with that ID in this list of children. "
f"The children of a widget must have unique IDs."
)
def _remove(self, widget: Widget) -> None:
"""Remove a widget from the list.
@@ -86,6 +116,9 @@ class NodeList(Sequence):
if widget in self._nodes_set:
del self._nodes[self._nodes.index(widget)]
self._nodes_set.remove(widget)
widget_id = widget.id
if widget_id in self._nodes_by_id:
del self._nodes_by_id[widget_id]
self._updates += 1
def _clear(self) -> None:
@@ -93,6 +126,7 @@ class NodeList(Sequence):
if self._nodes:
self._nodes.clear()
self._nodes_set.clear()
self._nodes_by_id.clear()
self._updates += 1
def __iter__(self) -> Iterator[Widget]:

View File

@@ -60,7 +60,7 @@ from .messages import CallbackType
from .reactive import Reactive
from .renderables.blank import Blank
from .screen import Screen
from .widget import AwaitMount, Widget
from .widget import AwaitMount, Widget, MountError
if TYPE_CHECKING:
from .devtools.client import DevtoolsClient
@@ -873,7 +873,7 @@ class App(Generic[ReturnType], DOMNode):
def render(self) -> RenderableType:
return Blank(self.styles.background)
def get_child(self, id: str) -> DOMNode:
def get_child_by_id(self, id: str) -> Widget:
"""Shorthand for self.screen.get_child(id: str)
Returns the first child (immediate descendent) of this DOMNode
with the given ID.
@@ -887,7 +887,26 @@ class App(Generic[ReturnType], DOMNode):
Raises:
NoMatches: if no children could be found for this ID
"""
return self.screen.get_child(id)
return self.screen.get_child_by_id(id)
def get_widget_by_id(self, id: str) -> Widget:
"""Shorthand for self.screen.get_widget_by_id(id)
Return the first descendant widget with the given ID.
Performs a breadth-first search rooted at the current screen.
It will not return the Screen if that matches the ID.
To get the screen, use `self.screen`.
Args:
id (str): The ID to search for in the subtree
Returns:
DOMNode: The first descendant encountered with this ID.
Raises:
NoMatches: if no children could be found for this ID
"""
return self.screen.get_widget_by_id(id)
def update_styles(self, node: DOMNode | None = None) -> None:
"""Request update of styles.

View File

@@ -28,7 +28,6 @@ from .css._error_tools import friendly_list
from .css.constants import VALID_DISPLAY, VALID_VISIBILITY
from .css.errors import DeclarationError, StyleValueError
from .css.parse import parse_declarations
from .css.query import NoMatches
from .css.styles import RenderStyles, Styles
from .css.tokenize import IDENTIFIER
from .message_pump import MessagePump
@@ -645,7 +644,6 @@ class DOMNode(MessagePump):
list[DOMNode] | list[WalkType]: A list of nodes.
"""
check_type = filter_type or DOMNode
node_generator = (
@@ -661,23 +659,6 @@ class DOMNode(MessagePump):
nodes.reverse()
return cast("list[DOMNode]", nodes)
def get_child(self, id: str) -> DOMNode:
"""Return the first child (immediate descendent) of this node with the given ID.
Args:
id (str): The ID of the child.
Returns:
DOMNode: The first child of this node with the ID.
Raises:
NoMatches: if no children could be found for this ID
"""
for child in self.children:
if child.id == id:
return child
raise NoMatches(f"No child found with id={id!r}")
ExpectType = TypeVar("ExpectType", bound="Widget")
@overload

View File

@@ -4,7 +4,7 @@ from collections import deque
from typing import Iterable, Iterator, TypeVar, overload, TYPE_CHECKING
if TYPE_CHECKING:
from .dom import DOMNode
from textual.dom import DOMNode
WalkType = TypeVar("WalkType", bound=DOMNode)
@@ -51,6 +51,8 @@ def walk_depth_first(
Iterable[DOMNode] | Iterable[WalkType]: An iterable of DOMNodes, or the type specified in ``filter_type``.
"""
from textual.dom import DOMNode
stack: list[Iterator[DOMNode]] = [iter(root.children)]
pop = stack.pop
push = stack.append
@@ -111,6 +113,8 @@ def walk_breadth_first(
Iterable[DOMNode] | Iterable[WalkType]: An iterable of DOMNodes, or the type specified in ``filter_type``.
"""
from textual.dom import DOMNode
queue: deque[DOMNode] = deque()
popleft = queue.popleft
extend = queue.extend

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from collections import Counter
from asyncio import Lock, wait, create_task, Event as AsyncEvent
from fractions import Fraction
from itertools import islice
@@ -41,6 +42,7 @@ from ._styles_cache import StylesCache
from ._types import Lines
from .binding import NoBinding
from .box_model import BoxModel, get_box_model
from .css.query import NoMatches
from .css.scalar import ScalarOffset
from .dom import DOMNode, NoScreen
from .geometry import Offset, Region, Size, Spacing, clamp
@@ -50,6 +52,7 @@ from .messages import CallbackType
from .reactive import Reactive
from .render import measure
from .await_remove import AwaitRemove
from .walk import walk_depth_first
if TYPE_CHECKING:
from .app import App, ComposeResult
@@ -334,6 +337,43 @@ class Widget(DOMNode):
def offset(self, offset: Offset) -> None:
self.styles.offset = ScalarOffset.from_offset(offset)
def get_child_by_id(self, id: str) -> Widget:
"""Return the first child (immediate descendent) of this node with the given ID.
Args:
id (str): The ID of the child.
Returns:
DOMNode: The first child of this node with the ID.
Raises:
NoMatches: if no children could be found for this ID
"""
child = self.children._get_by_id(id)
if child is not None:
return child
raise NoMatches(f"No child found with id={id!r}")
def get_widget_by_id(self, id: str) -> Widget:
"""Return the first descendant widget with the given ID.
Performs a depth-first search rooted at this widget.
Args:
id (str): The ID to search for in the subtree
Returns:
DOMNode: The first descendant encountered with this ID.
Raises:
NoMatches: if no children could be found for this ID
"""
for child in walk_depth_first(self):
try:
return child.get_child_by_id(id)
except NoMatches:
pass
raise NoMatches(f"No descendant found with id={id!r}")
def get_component_rich_style(self, name: str) -> Style:
"""Get a *Rich* style for a component.
@@ -461,6 +501,20 @@ class Widget(DOMNode):
provided a ``MountError`` will be raised.
"""
# Check for duplicate IDs in the incoming widgets
ids_to_mount = [widget.id for widget in widgets if widget.id is not None]
unique_ids = set(ids_to_mount)
num_unique_ids = len(unique_ids)
num_widgets_with_ids = len(ids_to_mount)
if num_unique_ids != num_widgets_with_ids:
counter = Counter(widget.id for widget in widgets)
for widget_id, count in counter.items():
if count > 1:
raise MountError(
f"Tried to insert {count!r} widgets with the same ID {widget_id!r}. "
f"Widget IDs must be unique."
)
# Saying you want to mount before *and* after something is an error.
if before is not None and after is not None:
raise MountError(

View File

@@ -1,7 +1,6 @@
import pytest
from textual.css.errors import StyleValueError
from textual.css.query import NoMatches
from textual.dom import DOMNode, BadIdentifier
@@ -26,37 +25,6 @@ def test_display_set_invalid_value():
node.display = "blah"
@pytest.fixture
def parent():
parent = DOMNode(id="parent")
child1 = DOMNode(id="child1")
child2 = DOMNode(id="child2")
grandchild1 = DOMNode(id="grandchild1")
child1._add_child(grandchild1)
parent._add_child(child1)
parent._add_child(child2)
yield parent
def test_get_child_gets_first_child(parent):
child = parent.get_child(id="child1")
assert child.id == "child1"
assert child.get_child(id="grandchild1").id == "grandchild1"
assert parent.get_child(id="child2").id == "child2"
def test_get_child_no_matching_child(parent):
with pytest.raises(NoMatches):
parent.get_child(id="doesnt-exist")
def test_get_child_only_immediate_descendents(parent):
with pytest.raises(NoMatches):
parent.get_child(id="grandchild1")
def test_validate():
with pytest.raises(BadIdentifier):
DOMNode(id="23")

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from textual.app import App, ComposeResult
from textual import events
from textual.containers import Container

View File

@@ -1,9 +1,13 @@
import pytest
import rich
from textual.app import App
from textual._node_list import DuplicateIds
from textual.app import App, ComposeResult
from textual.css.errors import StyleValueError
from textual.css.query import NoMatches
from textual.dom import DOMNode
from textual.geometry import Size
from textual.widget import Widget
from textual.widget import Widget, MountError
@pytest.mark.parametrize(
@@ -64,3 +68,92 @@ def test_widget_content_width():
height = widget3.get_content_height(Size(20, 20), Size(80, 24), width)
assert width == 3
assert height == 3
class GetByIdApp(App):
def compose(self) -> ComposeResult:
grandchild1 = Widget(id="grandchild1")
child1 = Widget(grandchild1, id="child1")
child2 = Widget(id="child2")
yield Widget(
child1,
child2,
id="parent",
)
@pytest.fixture
async def hierarchy_app():
app = GetByIdApp()
async with app.run_test():
yield app
@pytest.fixture
async def parent(hierarchy_app):
yield hierarchy_app.get_widget_by_id("parent")
def test_get_child_by_id_gets_first_child(parent):
child = parent.get_child_by_id(id="child1")
assert child.id == "child1"
assert child.get_child_by_id(id="grandchild1").id == "grandchild1"
assert parent.get_child_by_id(id="child2").id == "child2"
def test_get_child_by_id_no_matching_child(parent):
with pytest.raises(NoMatches):
parent.get_child_by_id(id="doesnt-exist")
def test_get_child_by_id_only_immediate_descendents(parent):
with pytest.raises(NoMatches):
parent.get_child_by_id(id="grandchild1")
def test_get_widget_by_id_no_matching_child(parent):
with pytest.raises(NoMatches):
parent.get_widget_by_id(id="i-dont-exist")
def test_get_widget_by_id_non_immediate_descendants(parent):
result = parent.get_widget_by_id("grandchild1")
assert result.id == "grandchild1"
def test_get_widget_by_id_immediate_descendants(parent):
result = parent.get_widget_by_id("child1")
assert result.id == "child1"
def test_get_widget_by_id_doesnt_return_self(parent):
with pytest.raises(NoMatches):
parent.get_widget_by_id("parent")
def test_get_widgets_app_delegated(hierarchy_app, parent):
# Check that get_child_by_id finds the parent, which is a child of the default Screen
queried_parent = hierarchy_app.get_child_by_id("parent")
assert queried_parent is parent
# Check that the grandchild (descendant of the default screen) is found
grandchild = hierarchy_app.get_widget_by_id("grandchild1")
assert grandchild.id == "grandchild1"
def test_widget_mount_ids_must_be_unique_mounting_all_in_one_go(parent):
widget1 = Widget(id="hello")
widget2 = Widget(id="hello")
with pytest.raises(MountError):
parent.mount(widget1, widget2)
def test_widget_mount_ids_must_be_unique_mounting_multiple_calls(parent):
widget1 = Widget(id="hello")
widget2 = Widget(id="hello")
parent.mount(widget1)
with pytest.raises(DuplicateIds):
parent.mount(widget2)

View File

@@ -15,18 +15,18 @@ async def test_remove_single_widget():
async def test_many_remove_all_widgets():
"""It should be possible to remove all widgets on a multi-widget screen."""
async with App().run_test() as pilot:
await pilot.app.mount(*[Static() for _ in range(1000)])
assert len(pilot.app.screen.children) == 1000
await pilot.app.mount(*[Static() for _ in range(10)])
assert len(pilot.app.screen.children) == 10
await pilot.app.query(Static).remove()
assert len(pilot.app.screen.children) == 0
async def test_many_remove_some_widgets():
"""It should be possible to remove some widgets on a multi-widget screen."""
async with App().run_test() as pilot:
await pilot.app.mount(*[Static(id=f"is-{n%2}") for n in range(1000)])
assert len(pilot.app.screen.children) == 1000
await pilot.app.query("#is-0").remove()
assert len(pilot.app.screen.children) == 500
await pilot.app.mount(*[Static(classes=f"is-{n%2}") for n in range(10)])
assert len(pilot.app.screen.children) == 10
await pilot.app.query(".is-0").remove()
assert len(pilot.app.screen.children) == 5
async def test_remove_branch():
"""It should be possible to remove a whole branch in the DOM."""