Merge branch 'main' into widget-doc-sweep

This commit is contained in:
Dave Pearson
2023-03-01 16:07:36 +00:00
committed by GitHub
33 changed files with 203 additions and 142 deletions

View File

@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
### Changed
- Widget scrolling methods (such as `Widget.scroll_home` and `Widget.scroll_end`) now perform the scroll after the next refresh https://github.com/Textualize/textual/issues/1774
- Buttons no longer accept arbitrary renderables https://github.com/Textualize/textual/issues/1870
### Fixed

View File

@@ -67,6 +67,9 @@ plugins:
default_handler: python
handlers:
python:
import:
- https://docs.python.org/3/objects.inv
- https://rich.readthedocs.io/en/stable/objects.inv
options:
show_root_heading: true
show_root_full_path: false

View File

@@ -14,3 +14,7 @@ ignore_missing_imports = True
[mypy-ipywidgets.*]
ignore_missing_imports = True
[mypy-uvloop.*]
# Ignore missing imports for optional library that isn't listed as a dependency.
ignore_missing_imports = True

View File

@@ -64,7 +64,10 @@ class Logger:
if app.devtools is None or not app.devtools.is_connected:
return
previous_frame = inspect.currentframe().f_back
current_frame = inspect.currentframe()
assert current_frame is not None
previous_frame = current_frame.f_back
assert previous_frame is not None
caller = inspect.getframeinfo(previous_frame)
_log = self._log or app._log

View File

@@ -58,6 +58,7 @@ async def invoke(callback: Callable, *params: object) -> Any:
# In debug mode we will warn about callbacks that may be stuck
def log_slow() -> None:
"""Log a message regarding a slow callback."""
assert app is not None
app.log.warning(
f"Callback {callback} is still pending after {INVOKE_TIMEOUT_WARNING} seconds"
)

View File

@@ -14,7 +14,7 @@ without having to render the entire screen.
from __future__ import annotations
from operator import itemgetter
from typing import TYPE_CHECKING, Iterable, NamedTuple, cast
from typing import TYPE_CHECKING, Callable, Iterable, NamedTuple, cast
import rich.repr
from rich.console import Console, ConsoleOptions, RenderableType, RenderResult
@@ -45,12 +45,23 @@ class ReflowResult(NamedTuple):
class MapGeometry(NamedTuple):
"""Defines the absolute location of a Widget."""
region: Region # The (screen) region occupied by the widget
order: tuple[tuple[int, ...], ...] # A tuple of ints defining the painting order
clip: Region # A region to clip the widget by (if a Widget is within a container)
virtual_size: Size # The virtual size (scrollable region) of a widget if it is a container
container_size: Size # The container size (area not occupied by scrollbars)
virtual_region: Region # The region relative to the container (but not necessarily visible)
region: Region
"""The (screen) region occupied by the widget."""
order: tuple[tuple[int, int, int], ...]
"""Tuple of tuples defining the painting order of the widget.
Each successive triple represents painting order information with regards to
ancestors in the DOM hierarchy and the last triple provides painting order
information for this specific widget.
"""
clip: Region
"""A region to clip the widget by (if a Widget is within a container)."""
virtual_size: Size
"""The virtual size (scrollable region) of a widget if it is a container."""
container_size: Size
"""The container size (area not occupied by scrollbars)."""
virtual_region: Region
"""The region relative to the container (but not necessarily visible)."""
@property
def visible_region(self) -> Region:
@@ -419,19 +430,23 @@ class Compositor:
widget: Widget,
virtual_region: Region,
region: Region,
order: tuple[tuple[int, ...], ...],
order: tuple[tuple[int, int, int], ...],
layer_order: int,
clip: Region,
visible: bool,
_MapGeometry=MapGeometry,
_MapGeometry: type[MapGeometry] = MapGeometry,
) -> None:
"""Called recursively to place a widget and its children in the map.
Args:
widget: The widget to add.
virtual_region: The Widget region relative to it's container.
region: The region the widget will occupy.
order: A tuple of ints to define the order.
order: Painting order information.
layer_order: The order of the widget in its layer.
clip: The clipping region (i.e. the viewport which contains it).
visible: Whether the widget should be visible by default.
This may be overriden by the CSS rule `visibility`.
"""
visibility = widget.styles.get_rule("visibility")
if visibility is not None:
@@ -501,11 +516,12 @@ class Compositor:
)
widget_region = sub_region + placement_scroll_offset
widget_order = (
*order,
get_layer_index(sub_widget.layer, 0),
z,
layer_order,
widget_order = order + (
(
get_layer_index(sub_widget.layer, 0),
z,
layer_order,
),
)
add_widget(
@@ -560,7 +576,7 @@ class Compositor:
root,
size.region,
size.region,
((0,),),
((0, 0, 0),),
layer_order,
size.region,
True,
@@ -818,11 +834,8 @@ class Compositor:
# Maps each cut on to a list of segments
cuts = self.cuts
# dict.fromkeys is a callable which takes a list of ints returns a dict which maps ints on to a list of Segments or None.
fromkeys = cast(
"Callable[[list[int]], dict[int, list[Segment] | None]]", dict.fromkeys
)
# A mapping of cut index to a list of segments for each line
# dict.fromkeys is a callable which takes a list of ints returns a dict which maps ints onto a Segment or None.
fromkeys = cast("Callable[[list[int]], dict[int, Strip | None]]", dict.fromkeys)
chops: list[dict[int, Strip | None]]
chops = [fromkeys(cut_set[:-1]) for cut_set in cuts]

View File

@@ -4,7 +4,7 @@ import hashlib
import os
import shlex
from pathlib import Path
from typing import Iterable
from typing import Iterable, cast
from textual._import_app import import_app
from textual.app import App
@@ -45,6 +45,7 @@ def format_svg(source, language, css_class, options, md, attrs, **kwargs) -> str
import traceback
traceback.print_exception(error)
return ""
def take_svg_screenshot(
@@ -82,6 +83,7 @@ def take_svg_screenshot(
hash = hashlib.md5()
file_paths = [app_path] + app.css_path
for path in file_paths:
assert path is not None
with open(path, "rb") as source_file:
hash.update(source_file.read())
hash.update(f"{press}-{title}-{terminal_size}".encode("utf-8"))
@@ -105,10 +107,13 @@ def take_svg_screenshot(
app.exit(svg)
svg = app.run(
headless=True,
auto_pilot=auto_pilot,
size=terminal_size,
svg = cast(
str,
app.run(
headless=True,
auto_pilot=auto_pilot,
size=terminal_size,
),
)
if app_path is not None:

View File

@@ -18,7 +18,7 @@ from .renderables.tint import Tint
from .strip import Strip
if TYPE_CHECKING:
from typing import TypeAlias
from typing_extensions import TypeAlias
from .css.styles import StylesBase
from .widget import Widget

View File

@@ -200,8 +200,8 @@ class XTermParser(Parser[events.Event]):
if not bracketed_paste:
# Was it a pressed key event that we received?
key_events = list(sequence_to_key_events(sequence))
for event in key_events:
on_token(event)
for key_event in key_events:
on_token(key_event)
if key_events:
break
# Or a mouse event?

View File

@@ -3,6 +3,11 @@ from __future__ import annotations
import ast
import re
from typing_extensions import Any, TypeAlias
ActionParseResult: TypeAlias = "tuple[str, tuple[Any, ...]]"
"""An action is its name and the arbitrary tuple of its parameters."""
class SkipAction(Exception):
"""Raise in an action to skip the action (and allow any parent bindings to run)."""
@@ -15,7 +20,7 @@ class ActionError(Exception):
re_action_params = re.compile(r"([\w\.]+)(\(.*?\))")
def parse(action: str) -> tuple[str, tuple[object, ...]]:
def parse(action: str) -> ActionParseResult:
"""Parses an action string.
Args:

View File

@@ -57,7 +57,7 @@ from ._context import active_app
from ._event_broker import NoHandler, extract_handler_actions
from ._path import _make_path_object_relative
from ._wait import wait_for_idle
from .actions import SkipAction
from .actions import ActionParseResult, SkipAction
from .await_remove import AwaitRemove
from .binding import Binding, Bindings
from .css.query import NoMatches
@@ -645,7 +645,7 @@ class App(Generic[ReturnType], DOMNode):
self,
group: LogGroup,
verbosity: LogVerbosity,
_textual_calling_frame: inspect.FrameInfo,
_textual_calling_frame: inspect.Traceback,
*objects: Any,
**kwargs,
) -> None:
@@ -1605,9 +1605,8 @@ class App(Generic[ReturnType], DOMNode):
with redirect_stdout(redirector): # type: ignore
await run_process_messages()
else:
null_file = _NullFile()
with redirect_stderr(null_file):
with redirect_stdout(null_file):
with redirect_stderr(None):
with redirect_stdout(None):
await run_process_messages()
finally:
@@ -1732,16 +1731,17 @@ class App(Generic[ReturnType], DOMNode):
if not widgets:
return []
new_widgets = list(widgets)
widget_list: Iterable[Widget]
if before is not None or after is not None:
# There's a before or after, which means there's going to be an
# insertion, so make it easier to get the new things in the
# correct order.
new_widgets = reversed(new_widgets)
widget_list = reversed(widgets)
else:
widget_list = widgets
apply_stylesheet = self.stylesheet.apply
for widget in new_widgets:
for widget in widget_list:
if not isinstance(widget, Widget):
raise AppError(f"Can't register {widget!r}; expected a Widget instance")
if widget not in self._registry:
@@ -1798,14 +1798,14 @@ class App(Generic[ReturnType], DOMNode):
async def _close_all(self) -> None:
"""Close all message pumps."""
# Close all screens on the stack
for screen in reversed(self._screen_stack):
if screen._running:
await self._prune_node(screen)
# Close all screens on the stack.
for stack_screen in reversed(self._screen_stack):
if stack_screen._running:
await self._prune_node(stack_screen)
self._screen_stack.clear()
# Close pre-defined screens
# Close pre-defined screens.
for screen in self.SCREENS.values():
if isinstance(screen, Screen) and screen._running:
await self._prune_node(screen)
@@ -1971,7 +1971,7 @@ class App(Generic[ReturnType], DOMNode):
async def action(
self,
action: str | tuple[str, tuple[str, ...]],
action: str | ActionParseResult,
default_namespace: object | None = None,
) -> bool:
"""Perform an action.
@@ -2069,7 +2069,7 @@ class App(Generic[ReturnType], DOMNode):
else:
event.stop()
if isinstance(action, (str, tuple)):
await self.action(action, default_namespace=default_namespace)
await self.action(action, default_namespace=default_namespace) # type: ignore[arg-type]
elif callable(action):
await action()
else:
@@ -2339,9 +2339,12 @@ _uvloop_init_done: bool = False
def _init_uvloop() -> None:
"""
Import and install the `uvloop` asyncio policy, if available.
"""Import and install the `uvloop` asyncio policy, if available.
This is done only once, even if the function is called multiple times.
This is provided as a nicety for users that have `uvloop` installed independently
of Textual, as `uvloop` is not listed as a Textual dependency.
"""
global _uvloop_init_done
@@ -2349,10 +2352,10 @@ def _init_uvloop() -> None:
return
try:
import uvloop
import uvloop # type: ignore[reportMissingImports]
except ImportError:
pass
else:
uvloop.install()
uvloop.install() # type: ignore[reportUnknownMemberType]
_uvloop_init_done = True

View File

@@ -92,6 +92,7 @@ class EasingApp(App):
target_position = (
END_POSITION if self.position == START_POSITION else START_POSITION
)
assert event.button.id is not None # Should be set to an easing function str.
self.animate(
"position",
value=target_position,
@@ -106,7 +107,7 @@ class EasingApp(App):
self.opacity_widget.styles.opacity = 1 - value / END_POSITION
def on_input_changed(self, event: Input.Changed):
if event.sender.id == "duration-input":
if event.input.id == "duration-input":
new_duration = _try_float(event.value)
if new_duration is not None:
self.duration = new_duration

View File

@@ -528,8 +528,8 @@ class SpacingProperty:
string (e.g. ``"blue on #f0f0f0"``).
Raises:
ValueError: When the value is malformed, e.g. a ``tuple`` with a length that is
not 1, 2, or 4.
ValueError: When the value is malformed,
e.g. a ``tuple`` with a length that is not 1, 2, or 4.
"""
_rich_traceback_omit = True
if spacing is None:
@@ -543,7 +543,9 @@ class SpacingProperty:
str(error),
help_text=spacing_wrong_number_of_values_help_text(
property_name=self.name,
num_values_supplied=len(spacing),
num_values_supplied=(
1 if isinstance(spacing, int) else len(spacing)
),
context="inline",
),
)

View File

@@ -264,7 +264,7 @@ def substitute_references(
iter_tokens = iter(tokens)
while tokens:
while True:
token = next(iter_tokens, None)
if token is None:
break
@@ -274,8 +274,7 @@ def substitute_references(
while True:
token = next(iter_tokens, None)
# TODO: Mypy error looks legit
if token.name == "whitespace":
if token is not None and token.name == "whitespace":
yield token
else:
break

View File

@@ -7,14 +7,14 @@ from .._types import CallbackType
from .scalar import Scalar, ScalarOffset
if TYPE_CHECKING:
from ..dom import DOMNode
from ..widget import Widget
from .styles import StylesBase
class ScalarAnimation(Animation):
def __init__(
self,
widget: DOMNode,
widget: Widget,
styles: StylesBase,
start_time: float,
attribute: str,

View File

@@ -335,6 +335,9 @@ class StylesBase(ABC):
if not isinstance(value, (Scalar, ScalarOffset)):
return None
from ..widget import Widget
assert isinstance(self.node, Widget)
return ScalarAnimation(
self.node,
self,
@@ -581,7 +584,9 @@ class StylesBase(ABC):
@dataclass
class Styles(StylesBase):
node: DOMNode | None = None
_rules: RulesMap = field(default_factory=dict)
_rules: RulesMap = field(
default_factory=lambda: RulesMap()
) # mypy won't be happy with `default_factory=RulesMap`
_updates: int = 0
important: set[str] = field(default_factory=set)
@@ -648,14 +653,14 @@ class Styles(StylesBase):
self._updates += 1
self._rules.clear() # type: ignore
def merge(self, other: Styles) -> None:
def merge(self, other: StylesBase) -> None:
"""Merge values from another Styles.
Args:
other: A Styles object.
"""
self._updates += 1
self._rules.update(other._rules)
self._rules.update(other.get_rules())
def merge_rules(self, rules: RulesMap) -> None:
self._updates += 1
@@ -1066,7 +1071,7 @@ class RenderStyles(StylesBase):
def refresh(self, *, layout: bool = False, children: bool = False) -> None:
self._inline_styles.refresh(layout=layout, children=children)
def merge(self, other: Styles) -> None:
def merge(self, other: StylesBase) -> None:
"""Merge values from another Styles.
Args:

View File

@@ -78,7 +78,7 @@ class StylesheetErrors:
f"{path.absolute() if path else filename}:{line_no}:{col_no}"
)
link_style = Style(
link=f"file://{path.absolute()}",
link=f"file://{path.absolute()}" if path else None,
color="red",
bold=True,
italic=True,

View File

@@ -1,4 +1,4 @@
* {
* {
transition: background 500ms in_out_cubic, color 500ms in_out_cubic;
}
@@ -125,7 +125,7 @@ DarkSwitch Switch {
}
Screen > Container {
Screen>Container {
height: 100%;
overflow: hidden;
}
@@ -222,7 +222,7 @@ LoginForm {
border: wide $background;
}
LoginForm Button{
LoginForm Button {
margin: 0 1;
width: 100%;
}
@@ -250,7 +250,7 @@ Window {
max-height: 16;
}
Window > Static {
Window>Static {
width: auto;
}

View File

@@ -205,7 +205,7 @@ class DarkSwitch(Horizontal):
def on_mount(self) -> None:
self.watch(self.app, "dark", self.on_dark_change, init=False)
def on_dark_change(self, dark: bool) -> None:
def on_dark_change(self) -> None:
self.query_one(Switch).value = self.app.dark
def on_switch_changed(self, event: Switch.Changed) -> None:
@@ -302,7 +302,7 @@ class Notification(Static):
self.remove()
class DemoApp(App):
class DemoApp(App[None]):
CSS_PATH = "demo.css"
TITLE = "Textual Demo"
BINDINGS = [

View File

@@ -35,7 +35,7 @@ class DevtoolsLog(NamedTuple):
"""
objects_or_string: tuple[Any, ...] | str
caller: inspect.FrameInfo
caller: inspect.Traceback
class DevtoolsConsole(Console):

View File

@@ -39,7 +39,10 @@ class StdoutRedirector:
if not self.devtools.is_connected:
return
previous_frame = inspect.currentframe().f_back
current_frame = inspect.currentframe()
assert current_frame is not None
previous_frame = current_frame.f_back
assert previous_frame is not None
caller = inspect.getframeinfo(previous_frame)
self._buffer.append(DevtoolsLog(string, caller=caller))

View File

@@ -5,10 +5,10 @@ import asyncio
import json
import pickle
from json import JSONDecodeError
from typing import Any, cast
from typing import Any
import msgpack
from aiohttp import WSMessage, WSMsgType
from aiohttp import WSMsgType
from aiohttp.abc import Request
from aiohttp.web_ws import WebSocketResponse
from rich.console import Console

View File

@@ -24,7 +24,7 @@ from rich.tree import Tree
from ._context import NoActiveAppError
from ._node_list import NodeList
from ._types import CallbackType
from .binding import Bindings, BindingType
from .binding import Binding, Bindings, BindingType
from .color import BLACK, WHITE, Color
from .css._error_tools import friendly_list
from .css.constants import VALID_DISPLAY, VALID_VISIBILITY
@@ -39,7 +39,7 @@ from .walk import walk_breadth_first, walk_depth_first
if TYPE_CHECKING:
from .app import App
from .css.query import DOMQuery
from .css.query import DOMQuery, QueryType
from .screen import Screen
from .widget import Widget
from typing_extensions import TypeAlias
@@ -276,7 +276,7 @@ class DOMNode(MessagePump):
base.__dict__.get("BINDINGS", []),
)
)
keys = {}
keys: dict[str, Binding] = {}
for bindings_ in bindings:
keys.update(bindings_.keys)
return Bindings(keys.values())
@@ -357,7 +357,7 @@ class DOMNode(MessagePump):
# Note that self.screen may not be the same as self.app.screen
from .screen import Screen
node = self
node: MessagePump | None = self
while node is not None and not isinstance(node, Screen):
node = node._parent
if not isinstance(node, Screen):
@@ -771,19 +771,17 @@ class DOMNode(MessagePump):
nodes.reverse()
return cast("list[DOMNode]", nodes)
ExpectType = TypeVar("ExpectType", bound="Widget")
@overload
def query(self, selector: str | None) -> DOMQuery[Widget]:
...
@overload
def query(self, selector: type[ExpectType]) -> DOMQuery[ExpectType]:
def query(self, selector: type[QueryType]) -> DOMQuery[QueryType]:
...
def query(
self, selector: str | type[ExpectType] | None = None
) -> DOMQuery[Widget] | DOMQuery[ExpectType]:
self, selector: str | type[QueryType] | None = None
) -> DOMQuery[Widget] | DOMQuery[QueryType]:
"""Get a DOM query matching a selector.
Args:
@@ -792,33 +790,31 @@ class DOMNode(MessagePump):
Returns:
A query object.
"""
from .css.query import DOMQuery
from .css.query import DOMQuery, QueryType
from .widget import Widget
query: str | None
if isinstance(selector, str) or selector is None:
query = selector
return DOMQuery[Widget](self, filter=selector)
else:
query = selector.__name__
return DOMQuery(self, filter=query)
return DOMQuery[QueryType](self, filter=selector.__name__)
@overload
def query_one(self, selector: str) -> Widget:
...
@overload
def query_one(self, selector: type[ExpectType]) -> ExpectType:
def query_one(self, selector: type[QueryType]) -> QueryType:
...
@overload
def query_one(self, selector: str, expect_type: type[ExpectType]) -> ExpectType:
def query_one(self, selector: str, expect_type: type[QueryType]) -> QueryType:
...
def query_one(
self,
selector: str | type[ExpectType],
expect_type: type[ExpectType] | None = None,
) -> ExpectType | Widget:
selector: str | type[QueryType],
expect_type: type[QueryType] | None = None,
) -> QueryType | Widget:
"""Get a single Widget matching the given selector or selector type.
Args:

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Type, TypeVar
import rich.repr
from rich.style import Style
from ._types import MessageTarget
from ._types import CallbackType, MessageTarget
from .geometry import Offset, Size
from .keys import _get_key_aliases
from .message import Message
@@ -28,11 +28,7 @@ class Event(Message):
@rich.repr.auto
class Callback(Event, bubble=False, verbose=True):
def __init__(
self,
sender: MessageTarget,
callback: Callable[[], Awaitable[None]],
) -> None:
def __init__(self, sender: MessageTarget, callback: CallbackType) -> None:
self.callback = callback
super().__init__(sender)

View File

@@ -5,7 +5,7 @@ from enum import Enum
# Adapted from prompt toolkit https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/prompt_toolkit/keys.py
class Keys(str, Enum):
class Keys(str, Enum): # type: ignore[no-redef]
"""
List of keys for use in key bindings.
@@ -13,7 +13,9 @@ class Keys(str, Enum):
strings.
"""
value: str
@property
def value(self) -> str:
return super().value
Escape = "escape" # Also Control-[
ShiftEscape = "shift+escape"

View File

@@ -12,7 +12,7 @@ from ._compositor import Compositor, MapGeometry
from ._types import CallbackType
from .css.match import match
from .css.parse import parse_selectors
from .dom import DOMNode
from .css.query import QueryType
from .geometry import Offset, Region, Size
from .reactive import Reactive
from .renderables.blank import Blank
@@ -169,7 +169,7 @@ class Screen(Widget):
return widgets
def _move_focus(
self, direction: int = 0, selector: str | type[DOMNode.ExpectType] = "*"
self, direction: int = 0, selector: str | type[QueryType] = "*"
) -> Widget | None:
"""Move the focus in the given direction.
@@ -230,9 +230,7 @@ class Screen(Widget):
return self.focused
def focus_next(
self, selector: str | type[DOMNode.ExpectType] = "*"
) -> Widget | None:
def focus_next(self, selector: str | type[QueryType] = "*") -> Widget | None:
"""Focus the next widget, optionally filtered by a CSS selector.
If no widget is currently focused, this will focus the first focusable widget.
@@ -249,9 +247,7 @@ class Screen(Widget):
"""
return self._move_focus(1, selector)
def focus_previous(
self, selector: str | type[DOMNode.ExpectType] = "*"
) -> Widget | None:
def focus_previous(self, selector: str | type[QueryType] = "*") -> Widget | None:
"""Focus the previous widget, optionally filtered by a CSS selector.
If no widget is currently focused, this will focus the first focusable widget.

View File

@@ -300,13 +300,13 @@ class ScrollBar(Widget):
def _on_leave(self, event: events.Leave) -> None:
self.mouse_over = False
async def action_scroll_down(self) -> None:
await self.post_message(
def action_scroll_down(self) -> None:
self.post_message_no_wait(
ScrollDown(self) if self.vertical else ScrollRight(self)
)
async def action_scroll_up(self) -> None:
await self.post_message(ScrollUp(self) if self.vertical else ScrollLeft(self))
def action_scroll_up(self) -> None:
self.post_message_no_wait(ScrollUp(self) if self.vertical else ScrollLeft(self))
def action_grab(self) -> None:
self.capture_mouse()

View File

@@ -65,6 +65,7 @@ from .walk import walk_depth_first
if TYPE_CHECKING:
from .app import App, ComposeResult
from .message_pump import MessagePump
from .scrollbar import (
ScrollBar,
ScrollBarCorner,
@@ -443,23 +444,27 @@ class Widget(DOMNode):
self, id: str, expect_type: type[ExpectType] | None = None
) -> ExpectType | Widget:
"""Return the first descendant widget with the given ID.
Performs a depth-first search rooted at this widget.
Args:
id: The ID to search for in the subtree
id: The ID to search for in the subtree.
expect_type: Require the object be of the supplied type, or None for any type.
Defaults to None.
Returns:
The first descendant encountered with this ID.
Raises:
NoMatches: if no children could be found for this ID
NoMatches: if no children could be found for this ID.
WrongType: if the wrong type was found.
"""
for child in walk_depth_first(self):
# We use Widget as a filter_type so that the inferred type of child is Widget.
for child in walk_depth_first(self, filter_type=Widget):
try:
return child.get_child_by_id(id, expect_type=expect_type)
if expect_type is None:
return child.get_child_by_id(id)
else:
return child.get_child_by_id(id, expect_type=expect_type)
except NoMatches:
pass
except WrongType as exc:
@@ -729,7 +734,9 @@ class Widget(DOMNode):
# Ensure the child and target are widgets.
child = _to_widget(child, "move")
target = _to_widget(before if after is None else after, "move towards")
target = _to_widget(
cast("int | Widget", before if after is None else after), "move towards"
)
# At this point we should know what we're moving, and it should be a
# child; where we're moving it to, which should be within the child
@@ -2255,7 +2262,7 @@ class Widget(DOMNode):
Names of the pseudo classes.
"""
node = self
node: MessagePump | None = self
while isinstance(node, Widget):
if node.disabled:
yield "disabled"
@@ -2302,7 +2309,9 @@ class Widget(DOMNode):
renderable.justify = text_justify
renderable = _Styled(
renderable, self.rich_style, self.link_style if self.auto_links else None
cast(ConsoleRenderable, renderable),
self.rich_style,
self.link_style if self.auto_links else None,
)
return renderable
@@ -2504,7 +2513,7 @@ class Widget(DOMNode):
self.check_idle()
def remove(self) -> AwaitRemove:
"""Remove the Widget from the DOM (effectively deleting it)
"""Remove the Widget from the DOM (effectively deleting it).
Returns:
An awaitable object that waits for the widget to be removed.
@@ -2517,16 +2526,16 @@ class Widget(DOMNode):
"""Get renderable for widget.
Returns:
Any renderable
Any renderable.
"""
render = "" if self.is_container else self.css_identifier_styled
render: Text | str = "" if self.is_container else self.css_identifier_styled
return render
def _render(self) -> ConsoleRenderable | RichCast:
"""Get renderable, promoting str to text as required.
Returns:
A renderable
A renderable.
"""
renderable = self.render()
if isinstance(renderable, str):

View File

@@ -4,7 +4,6 @@ from functools import partial
from typing import cast
import rich.repr
from rich.console import RenderableType
from rich.text import Text, TextType
from typing_extensions import Literal
@@ -145,7 +144,7 @@ class Button(Static, can_focus=True):
ACTIVE_EFFECT_DURATION = 0.3
"""When buttons are clicked they get the `-active` class for this duration (in seconds)"""
label: reactive[RenderableType] = reactive[RenderableType]("")
label: reactive[TextType] = reactive[TextType]("")
"""The text label that appears within the button."""
variant = reactive("default")
@@ -209,15 +208,14 @@ class Button(Static, can_focus=True):
self.remove_class(f"-{old_variant}")
self.add_class(f"-{variant}")
def validate_label(self, label: RenderableType) -> RenderableType:
def validate_label(self, label: TextType) -> TextType:
"""Parse markup for self.label"""
if isinstance(label, str):
return Text.from_markup(label)
return label
def render(self) -> RenderableType:
label = self.label.copy()
label = Text.assemble(" ", label, " ")
def render(self) -> TextType:
label = Text.assemble(" ", self.label, " ")
label.stylize(self.text_style)
return label

View File

@@ -66,7 +66,7 @@ class Footer(Widget):
self.refresh()
def on_mount(self) -> None:
self.watch(self.screen, "focused", self._focus_changed)
self.watch(self.screen, "focused", self._focus_changed) # type: ignore[arg-type]
def _focus_changed(self, focused: Widget | None) -> None:
self._key_text = None

View File

@@ -166,5 +166,5 @@ class Header(Widget):
def set_sub_title(sub_title: str) -> None:
self.query_one(HeaderTitle).sub_text = sub_title
self.watch(self.app, "title", set_title)
self.watch(self.app, "sub_title", set_sub_title)
self.watch(self.app, "title", set_title) # type: ignore[arg-type]
self.watch(self.app, "sub_title", set_sub_title) # type: ignore[arg-type]

View File

@@ -40,6 +40,8 @@ class ListItem(Widget, can_focus=False):
class _ChildClicked(Message):
"""For informing with the parent ListView that we were clicked"""
sender: "ListItem"
def on_click(self, event: events.Click) -> None:
self.post_message_no_wait(self._ChildClicked(self))

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import ClassVar
from typing import ClassVar, Optional
from textual.await_remove import AwaitRemove
from textual.binding import Binding, BindingType
@@ -8,7 +8,7 @@ from textual.containers import Vertical
from textual.geometry import clamp
from textual.message import Message
from textual.reactive import reactive
from textual.widget import AwaitMount
from textual.widget import AwaitMount, Widget
from textual.widgets._list_item import ListItem
@@ -35,7 +35,7 @@ class ListView(Vertical, can_focus=True, can_focus_children=False):
| down | Move the cursor down. |
"""
index = reactive(0, always_update=True)
index = reactive[Optional[int]](0, always_update=True)
class Highlighted(Message, bubble=True):
"""Posted when the highlighted item changes.
@@ -96,10 +96,12 @@ class ListView(Vertical, can_focus=True, can_focus_children=False):
@property
def highlighted_child(self) -> ListItem | None:
"""The currently highlighted ListItem, or None if nothing is highlighted."""
if self.index is None:
if self.index is not None and 0 <= self.index < len(self._nodes):
list_item = self._nodes[self.index]
assert isinstance(list_item, ListItem)
return list_item
else:
return None
elif 0 <= self.index < len(self._nodes):
return self._nodes[self.index]
def validate_index(self, index: int | None) -> int | None:
"""Clamp the index to the valid range, or set to None if there's nothing to highlight.
@@ -129,9 +131,13 @@ class ListView(Vertical, can_focus=True, can_focus_children=False):
"""Updates the highlighting when the index changes."""
if self._is_valid_index(old_index):
old_child = self._nodes[old_index]
assert isinstance(old_child, ListItem)
old_child.highlighted = False
new_child: Widget | None
if self._is_valid_index(new_index):
new_child = self._nodes[new_index]
assert isinstance(new_child, ListItem)
new_child.highlighted = True
else:
new_child = None
@@ -168,14 +174,22 @@ class ListView(Vertical, can_focus=True, can_focus_children=False):
def action_select_cursor(self) -> None:
"""Select the current item in the list."""
selected_child = self.highlighted_child
if selected_child is None:
return
self.post_message_no_wait(self.Selected(self, selected_child))
def action_cursor_down(self) -> None:
"""Highlight the next item in the list."""
if self.index is None:
self.index = 0
return
self.index += 1
def action_cursor_up(self) -> None:
"""Highlight the previous item in the list."""
if self.index is None:
self.index = 0
return
self.index -= 1
def on_list_item__child_clicked(self, event: ListItem._ChildClicked) -> None: