diff --git a/src/textual/_node_list.py b/src/textual/_node_list.py index d48597090..7032b0cc7 100644 --- a/src/textual/_node_list.py +++ b/src/textual/_node_list.py @@ -101,8 +101,8 @@ class NodeList(Sequence): if widget_id in self._nodes_by_id: raise DuplicateIds( f"Tried to insert a widget with ID {widget_id!r}, but a widget {self._nodes_by_id[widget_id]!r} " - f"already exists with that ID in this list of children. " - f"The children of a widget must have unique IDs." + "already exists with that ID in this list of children. " + "The children of a widget must have unique IDs." ) def _remove(self, widget: Widget) -> None: diff --git a/src/textual/app.py b/src/textual/app.py index 27fbc352a..369583427 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -26,6 +26,7 @@ from typing import ( TypeVar, Union, cast, + overload, ) from weakref import WeakSet, WeakValueDictionary @@ -892,23 +893,52 @@ class App(Generic[ReturnType], DOMNode): def render(self) -> RenderableType: return Blank(self.styles.background) + ExpectType = TypeVar("ExpectType", bound=Widget) + + @overload def get_child_by_id(self, id: str) -> Widget: + ... + + @overload + def get_child_by_id(self, id: str, expect_type: type[ExpectType]) -> ExpectType: + ... + + def get_child_by_id( + self, id: str, expect_type: type[ExpectType] | None = None + ) -> ExpectType | Widget: """Shorthand for self.screen.get_child(id: str) Returns the first child (immediate descendent) of this DOMNode with the given ID. Args: id (str): The ID of the node to search for. + expect_type (type | None, optional): Require the object be of the supplied type, or None for any type. + Defaults to None. Returns: - DOMNode: The first child of this node with the specified ID. + ExpectType | Widget: The first child of this node with the specified ID. Raises: NoMatches: if no children could be found for this ID + WrongType: if the wrong type was found. """ - return self.screen.get_child_by_id(id) + return ( + self.screen.get_child_by_id(id) + if expect_type is None + else self.screen.get_child_by_id(id, expect_type) + ) + @overload def get_widget_by_id(self, id: str) -> Widget: + ... + + @overload + def get_widget_by_id(self, id: str, expect_type: type[ExpectType]) -> ExpectType: + ... + + def get_widget_by_id( + self, id: str, expect_type: type[ExpectType] | None = None + ) -> ExpectType | Widget: """Shorthand for self.screen.get_widget_by_id(id) Return the first descendant widget with the given ID. @@ -918,14 +948,21 @@ class App(Generic[ReturnType], DOMNode): Args: id (str): The ID to search for in the subtree + expect_type (type | None, optional): Require the object be of the supplied type, or None for any type. + Defaults to None. Returns: - DOMNode: The first descendant encountered with this ID. + ExpectType | Widget: The first descendant encountered with this ID. Raises: NoMatches: if no children could be found for this ID + WrongType: if the wrong type was found. """ - return self.screen.get_widget_by_id(id) + return ( + self.screen.get_widget_by_id(id) + if expect_type is None + else self.screen.get_widget_by_id(id, expect_type) + ) def update_styles(self, node: DOMNode | None = None) -> None: """Request update of styles. @@ -1463,7 +1500,6 @@ class App(Generic[ReturnType], DOMNode): # If we don't already know about this widget... if child not in self._registry: - # Now to figure out where to place it. If we've got a `before`... if before is not None: # ...it's safe to NodeList._insert before that location. diff --git a/src/textual/widget.py b/src/textual/widget.py index 70a3b353a..5eb71a0b3 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -14,7 +14,9 @@ from typing import ( Iterable, NamedTuple, Sequence, + TypeVar, cast, + overload, ) import rich.repr @@ -44,7 +46,7 @@ from ._types import Lines from .await_remove import AwaitRemove from .binding import Binding from .box_model import BoxModel, get_box_model -from .css.query import NoMatches +from .css.query import NoMatches, WrongType from .css.scalar import ScalarOffset from .dom import DOMNode, NoScreen from .geometry import Offset, Region, Size, Spacing, clamp @@ -221,7 +223,6 @@ class Widget(DOMNode): id: str | None = None, classes: str | None = None, ) -> None: - self._size = Size(0, 0) self._container_size = Size(0, 0) self._layout_required = False @@ -349,41 +350,81 @@ class Widget(DOMNode): def offset(self, offset: Offset) -> None: self.styles.offset = ScalarOffset.from_offset(offset) + ExpectType = TypeVar("ExpectType", bound="Widget") + + @overload def get_child_by_id(self, id: str) -> Widget: + ... + + @overload + def get_child_by_id(self, id: str, expect_type: type[ExpectType]) -> ExpectType: + ... + + def get_child_by_id( + self, id: str, expect_type: type[ExpectType] | None = None + ) -> ExpectType | Widget: """Return the first child (immediate descendent) of this node with the given ID. Args: id (str): The ID of the child. + expect_type (type | None, optional): Require the object be of the supplied type, or None for any type. + Defaults to None. Returns: - DOMNode: The first child of this node with the ID. + ExpectType | Widget: The first child of this node with the ID. Raises: NoMatches: if no children could be found for this ID + WrongType: if the wrong type was found. """ child = self.children._get_by_id(id) - if child is not None: + if child is None: + raise NoMatches(f"No child found with id={id!r}") + if expect_type is None: return child - raise NoMatches(f"No child found with id={id!r}") + if not isinstance(child, expect_type): + raise WrongType( + f"Child with id={id!r} is wrong type; expected {expect_type}, got" + f" {type(child)}" + ) + return child + @overload def get_widget_by_id(self, id: str) -> Widget: + ... + + @overload + def get_widget_by_id(self, id: str, expect_type: type[ExpectType]) -> ExpectType: + ... + + def get_widget_by_id( + self, id: str, expect_type: type[ExpectType] | None = None + ) -> ExpectType | Widget: """Return the first descendant widget with the given ID. Performs a depth-first search rooted at this widget. Args: id (str): The ID to search for in the subtree + expect_type (type | None, optional): Require the object be of the supplied type, or None for any type. + Defaults to None. Returns: - DOMNode: The first descendant encountered with this ID. + ExpectType | Widget: The first descendant encountered with this ID. Raises: NoMatches: if no children could be found for this ID + WrongType: if the wrong type was found. """ for child in walk_depth_first(self): try: - return child.get_child_by_id(id) + return child.get_child_by_id(id, expect_type=expect_type) except NoMatches: pass + except WrongType as exc: + raise WrongType( + f"Descendant with id={id!r} is wrong type; expected {expect_type}," + f" got {type(child)}" + ) from exc raise NoMatches(f"No descendant found with id={id!r}") def get_component_rich_style(self, name: str, *, partial: bool = False) -> Style: @@ -530,7 +571,7 @@ class Widget(DOMNode): if count > 1: raise MountError( f"Tried to insert {count!r} widgets with the same ID {widget_id!r}. " - f"Widget IDs must be unique." + "Widget IDs must be unique." ) # Saying you want to mount before *and* after something is an error.