diff --git a/docs/guide/events.md b/docs/guide/events.md index 9cdfb3259..51821aeca 100644 --- a/docs/guide/events.md +++ b/docs/guide/events.md @@ -183,7 +183,7 @@ Let's look at an example which looks up word definitions from an [api](https://d === "Output" - ```{.textual path="docs/examples/events/dictionary.py" press="tab,t,e,x,t,_,_,_,_,_,_,_,_,_,_,_"} + ```{.textual path="docs/examples/events/dictionary.py" press="t,e,x,t,_,_,_,_,_,_,_,_,_,_,_"} ``` Note the highlighted line in the above code which calls `asyncio.create_task` to run a coroutine in the background. Without this you would find typing in to the text box to be unresponsive. diff --git a/docs/index.md b/docs/index.md index 66e83adb7..eac77d1b7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -70,7 +70,7 @@ Textual is a framework for building applications that run within your terminal. ``` -```{.textual path="docs/examples/events/dictionary.py" columns="100" lines="30" press="tab,_,t,e,x,t,_,_,_,_,_,_,_,_,_,_,_,_,_"} +```{.textual path="docs/examples/events/dictionary.py" columns="100" lines="30" press="_,t,e,x,t,_,_,_,_,_,_,_,_,_,_,_,_,_"} ``` diff --git a/examples/code_browser.py b/examples/code_browser.py index 86b54665a..90303a5d4 100644 --- a/examples/code_browser.py +++ b/examples/code_browser.py @@ -30,7 +30,7 @@ class CodeBrowser(App): def watch_show_tree(self, show_tree: bool) -> None: """Called when show_tree is modified.""" - self.set_class(show_tree, "-show-tree") + self.set_class(show_tree, "-show-tree") def compose(self) -> ComposeResult: """Compose our UI.""" diff --git a/examples/dictionary.css b/examples/dictionary.css index 8850249c4..6bca8b9f5 100644 --- a/examples/dictionary.css +++ b/examples/dictionary.css @@ -15,7 +15,7 @@ Input { #results-container { background: $background 50%; - margin: 0; + margin: 0 0 1 0; height: 100%; overflow: hidden auto; border: tall $background; diff --git a/examples/dictionary.py b/examples/dictionary.py index 56c986371..737bcb283 100644 --- a/examples/dictionary.py +++ b/examples/dictionary.py @@ -23,6 +23,11 @@ class DictionaryApp(App): yield Input(placeholder="Search for a word") yield Content(Static(id="results"), id="results-container") + def on_mount(self) -> None: + """Called when app starts.""" + # Give the input focus, so we can start typing straight away + self.query_one(Input).focus() + async def on_input_changed(self, message: Input.Changed) -> None: """A coroutine to handle a text changed message.""" if message.value: diff --git a/sandbox/will/mount.py b/sandbox/will/mount.py new file mode 100644 index 000000000..e64b7c746 --- /dev/null +++ b/sandbox/will/mount.py @@ -0,0 +1,35 @@ +from textual.app import App, ComposeResult + +from textual.containers import Container +from textual.widget import Widget +from textual.widgets import Static + + +class MountWidget(Widget): + def on_mount(self) -> None: + print("Widget mounted") + + +class MountContainer(Container): + def compose(self) -> ComposeResult: + yield Container(MountWidget(id="bar")) + + def on_mount(self) -> None: + bar = self.query_one("#bar") + print("MountContainer got", bar) + + +class MountApp(App): + def compose(self) -> ComposeResult: + yield MountContainer(id="foo") + + def on_mount(self) -> None: + foo = self.query_one("#foo") + print("foo is", foo) + static = self.query_one("#bar") + print("App got", static) + + +if __name__ == "__main__": + app = MountApp() + app.run() diff --git a/sandbox/will/screens.py b/sandbox/will/screens.py index 3a9f34dc2..8f5cfd2fc 100644 --- a/sandbox/will/screens.py +++ b/sandbox/will/screens.py @@ -7,6 +7,9 @@ class ModalScreen(Screen): yield Pretty(self.app.screen_stack) yield Footer() + def on_mount(self) -> None: + pretty = self.query_one("Pretty") + def on_screen_resume(self): self.query_one(Pretty).update(self.app.screen_stack) diff --git a/src/textual/app.py b/src/textual/app.py index 4b618f0e1..ef5d05584 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -37,7 +37,7 @@ from .css.stylesheet import Stylesheet from .design import ColorSystem from .devtools.client import DevtoolsClient, DevtoolsConnectionError, DevtoolsLog from .devtools.redirect_output import StdoutRedirector -from .dom import DOMNode, NoScreen +from .dom import DOMNode from .driver import Driver from .drivers.headless_driver import HeadlessDriver from .features import FeatureFlag, parse_features @@ -47,7 +47,7 @@ from .messages import CallbackType from .reactive import Reactive from .renderables.blank import Blank from .screen import Screen -from .widget import Widget +from .widget import AwaitMount, Widget PLATFORM = platform.system() WINDOWS = PLATFORM == "Windows" @@ -144,6 +144,10 @@ class App(Generic[ReturnType], DOMNode): _BASE_PATH: str | None = None CSS_PATH: CSSPathType = None + title: Reactive[str] = Reactive("Textual") + sub_title: Reactive[str] = Reactive("") + dark: Reactive[bool] = Reactive(True) + def __init__( self, driver_class: Type[Driver] | None = None, @@ -228,10 +232,6 @@ class App(Generic[ReturnType], DOMNode): ) self._screenshot: str | None = None - title: Reactive[str] = Reactive("Textual") - sub_title: Reactive[str] = Reactive("") - dark: Reactive[bool] = Reactive(True) - def animate( self, attribute: str, @@ -696,21 +696,24 @@ class App(Generic[ReturnType], DOMNode): self._require_stylesheet_update.add(self.screen if node is None else node) self.check_idle() - def mount(self, *anon_widgets: Widget, **widgets: Widget) -> None: + def mount(self, *anon_widgets: Widget, **widgets: Widget) -> AwaitMount: """Mount widgets. Widgets specified as positional args, or keywords args. If supplied as keyword args they will be assigned an id of the key. """ - self._register(self.screen, *anon_widgets, **widgets) + mounted_widgets = self._register(self.screen, *anon_widgets, **widgets) + return AwaitMount(mounted_widgets) - def mount_all(self, widgets: Iterable[Widget]) -> None: + def mount_all(self, widgets: Iterable[Widget]) -> AwaitMount: """Mount widgets from an iterable. Args: widgets (Iterable[Widget]): An iterable of widgets. """ - for widget in widgets: + mounted_widgets = list(widgets) + for widget in mounted_widgets: self._register(self.screen, widget) + return AwaitMount(mounted_widgets) def is_screen_installed(self, screen: Screen | str) -> bool: """Check if a given screen has been installed. @@ -1008,23 +1011,36 @@ class App(Generic[ReturnType], DOMNode): self.set_interval(0.25, self.css_monitor, name="css monitor") self.log.system("[b green]STARTED[/]", self.css_monitor) - process_messages = super()._process_messages - async def run_process_messages(): - compose_event = events.Compose(sender=self) - await self._dispatch_message(compose_event) - mount_event = events.Mount(sender=self) - await self._dispatch_message(mount_event) - Reactive.initialize_object(self) + try: + await self._dispatch_message(events.Compose(sender=self)) + await self._dispatch_message(events.Mount(sender=self)) + finally: + self._mounted_event.set() + + Reactive._initialize_object(self) + self.title = self._title self.stylesheet.update(self) self.refresh() + await self.animator.start() await self._ready() if ready_callback is not None: await ready_callback() - await process_messages() + + self._running = True + + try: + await self._process_messages_loop() + except asyncio.CancelledError: + pass + finally: + self._running = False + for timer in list(self._timers): + await timer.stop() + await self.animator.stop() await self._close_all() @@ -1059,6 +1075,9 @@ class App(Generic[ReturnType], DOMNode): if self.devtools.is_connected: await self._disconnect_devtools() + async def _pre_process(self) -> None: + pass + async def _ready(self) -> None: """Called immediately prior to processing messages. @@ -1083,9 +1102,9 @@ class App(Generic[ReturnType], DOMNode): self.set_timer(screenshot_timer, on_screenshot, name="screenshot timer") - def _on_compose(self) -> None: - widgets = self.compose() - self.mount_all(widgets) + async def _on_compose(self) -> None: + widgets = list(self.compose()) + await self.mount_all(widgets) def _on_idle(self) -> None: """Perform actions when there are no messages in the queue.""" @@ -1110,15 +1129,15 @@ class App(Generic[ReturnType], DOMNode): def _register( self, parent: DOMNode, *anon_widgets: Widget, **widgets: Widget - ) -> None: + ) -> list[Widget]: """Mount widget(s) so they may receive events. Args: parent (Widget): Parent Widget """ if not anon_widgets and not widgets: - return - name_widgets: Iterable[tuple[str | None, Widget]] + return [] + name_widgets: list[tuple[str | None, Widget]] name_widgets = [*((None, widget) for widget in anon_widgets), *widgets.items()] apply_stylesheet = self.stylesheet.apply @@ -1133,8 +1152,8 @@ class App(Generic[ReturnType], DOMNode): self._register(widget, *widget.children) apply_stylesheet(widget) - for _widget_id, widget in name_widgets: - widget.post_message_no_wait(events.Mount(sender=parent)) + registered_widgets = [widget for _, widget in name_widgets] + return registered_widgets def _unregister(self, widget: Widget) -> None: """Unregister a widget. @@ -1142,11 +1161,7 @@ class App(Generic[ReturnType], DOMNode): Args: widget (Widget): A Widget to unregister """ - try: - widget.screen._reset_focus(widget) - except NoScreen: - pass - + widget.reset_focus() if isinstance(widget._parent, Widget): widget._parent.children._remove(widget) widget._detach() @@ -1164,9 +1179,16 @@ class App(Generic[ReturnType], DOMNode): """ widget._attach(parent) widget._start_messages() - widget.post_message_no_wait(events.Mount(sender=parent)) def is_mounted(self, widget: Widget) -> bool: + """Check if a widget is mounted. + + Args: + widget (Widget): A widget. + + Returns: + bool: True of the widget is mounted. + """ return widget in self._registry async def _close_all(self) -> None: @@ -1388,22 +1410,22 @@ class App(Generic[ReturnType], DOMNode): async def _on_resize(self, event: events.Resize) -> None: event.stop() - self.screen._screen_resized(event.size) await self.screen.post_message(event) async def _on_remove(self, event: events.Remove) -> None: widget = event.widget parent = widget.parent - if parent is not None: - parent.refresh(layout=True) + + widget.reset_focus() remove_widgets = widget.walk_children( Widget, with_self=True, method="depth", reverse=True ) - for child in remove_widgets: - self._unregister(child) for child in remove_widgets: await child._close_messages() + self._unregister(child) + if parent is not None: + parent.refresh(layout=True) async def action_press(self, key: str) -> None: await self.press(key) diff --git a/src/textual/cli/previews/colors.py b/src/textual/cli/previews/colors.py index af3e26262..20e2fa250 100644 --- a/src/textual/cli/previews/colors.py +++ b/src/textual/cli/previews/colors.py @@ -68,9 +68,16 @@ class ColorsApp(App): BINDINGS = [("d", "toggle_dark", "Toggle dark mode")] def compose(self) -> ComposeResult: - yield Content(ColorButtons(), ColorsView()) + yield Content(ColorButtons()) yield Footer() + def on_mount(self) -> None: + self.call_later(self.update_view) + + def update_view(self) -> None: + content = self.query_one("Content", Content) + content.mount(ColorsView()) + def on_button_pressed(self, event: Button.Pressed) -> None: self.query(ColorGroup).remove_class("-active") group = self.query_one(f"#group-{event.button.id}", ColorGroup) diff --git a/src/textual/css/stylesheet.py b/src/textual/css/stylesheet.py index 0045f8a67..611539358 100644 --- a/src/textual/css/stylesheet.py +++ b/src/textual/css/stylesheet.py @@ -249,7 +249,7 @@ class Stylesheet: css = css_file.read() path = os.path.abspath(filename) except Exception as error: - raise StylesheetError(f"unable to read {filename!r}; {error}") + raise StylesheetError(f"unable to read CSS file {filename!r}") from None self.source[str(path)] = CssSource(css, False, 0) self._require_parse = True diff --git a/src/textual/dom.py b/src/textual/dom.py index cc3f0866e..773c1cc9f 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -662,7 +662,7 @@ class DOMNode(MessagePump): Defaults to None. with_self (bool, optional): Also yield self in addition to descendants. Defaults to True. method (Literal["breadth", "depth"], optional): One of "depth" or "breadth". Defaults to "depth". - reverse (bool, optional): Reverse the order (bottom up). Defaults to False + reverse (bool, optional): Reverse the order (bottom up). Defaults to False. Returns: Iterable[DOMNode | WalkType]: An iterable of nodes. diff --git a/src/textual/message_pump.py b/src/textual/message_pump.py index 23bd8a873..6470b7e78 100644 --- a/src/textual/message_pump.py +++ b/src/textual/message_pump.py @@ -73,6 +73,7 @@ class MessagePump(metaclass=MessagePumpMeta): self._timers: WeakSet[Timer] = WeakSet() self._last_idle: float = time() self._max_idle: float | None = None + self._mounted_event = asyncio.Event() @property def task(self) -> Task: @@ -278,18 +279,19 @@ class MessagePump(metaclass=MessagePumpMeta): await timer.stop() self._timers.clear() await self._message_queue.put(None) - if self._task is not None and asyncio.current_task() != self._task: # Ensure everything is closed before returning await self._task def _start_messages(self) -> None: """Start messages task.""" - Reactive.initialize_object(self) self._task = asyncio.create_task(self._process_messages()) async def _process_messages(self) -> None: self._running = True + + await self._pre_process() + try: await self._process_messages_loop() except CancelledError: @@ -299,6 +301,18 @@ class MessagePump(metaclass=MessagePumpMeta): for timer in list(self._timers): await timer.stop() + async def _pre_process(self) -> None: + """Procedure to run before processing messages.""" + # Dispatch compose and mount messages without going through loop + # These events must occur in this order, and at the start. + try: + await self._dispatch_message(events.Compose(sender=self)) + await self._dispatch_message(events.Mount(sender=self)) + finally: + # This is critical, mount may be waiting + self._mounted_event.set() + Reactive._initialize_object(self) + async def _process_messages_loop(self) -> None: """Process messages until the queue is closed.""" _rich_traceback_guard = True @@ -331,11 +345,15 @@ class MessagePump(metaclass=MessagePumpMeta): except CancelledError: raise except Exception as error: + self._mounted_event.set() self.app._handle_exception(error) break finally: + self._message_queue.task_done() current_time = time() + + # Insert idle events if self._message_queue.empty() or ( self._max_idle is not None and current_time - self._last_idle > self._max_idle diff --git a/src/textual/reactive.py b/src/textual/reactive.py index 74c962b45..c237955b4 100644 --- a/src/textual/reactive.py +++ b/src/textual/reactive.py @@ -20,6 +20,12 @@ if TYPE_CHECKING: ReactiveType = TypeVar("ReactiveType") +class _NotSet: + pass + + +_NOT_SET = _NotSet() + T = TypeVar("T") @@ -83,24 +89,31 @@ class Reactive(Generic[ReactiveType]): return cls(default, layout=False, repaint=False, init=True) @classmethod - def initialize_object(cls, obj: object) -> None: - """Call any watchers / computes for the first time. + def _initialize_object(cls, obj: object) -> None: + """Set defaults and call any watchers / computes for the first time. Args: obj (Reactable): An object with Reactive descriptors """ - - startswith = str.startswith - for key in obj.__class__.__dict__.keys(): - if startswith(key, "_init_"): - name = key[6:] - if not hasattr(obj, name): - default = getattr(obj, key) - setattr(obj, name, default() if callable(default) else default) + if not hasattr(obj, "__reactive_initialized"): + startswith = str.startswith + for key in obj.__class__.__dict__: + if startswith(key, "_default_"): + name = key[9:] + # Check defaults + if not hasattr(obj, name): + # Attribute has no value yet + default = getattr(obj, key) + default_value = default() if callable(default) else default + # Set the default vale (calls `__set__`) + setattr(obj, name, default_value) + setattr(obj, "__reactive_initialized", True) def __set_name__(self, owner: Type[MessageTarget], name: str) -> None: + # Check for compute method if hasattr(owner, f"compute_{name}"): + # Compute methods are stored in a list called `__computes` try: computes = getattr(owner, "__computes") except AttributeError: @@ -108,31 +121,46 @@ class Reactive(Generic[ReactiveType]): setattr(owner, "__computes", computes) computes.append(name) + # The name of the attribute self.name = name + # The internal name where the attribute's value is stored self.internal_name = f"_reactive_{name}" default = self._default - - if self._init: - setattr(owner, f"_init_{name}", default) - else: - setattr( - owner, self.internal_name, default() if callable(default) else default - ) + setattr(owner, f"_default_{name}", default) def __get__(self, obj: Reactable, obj_type: type[object]) -> ReactiveType: - return getattr(obj, self.internal_name) + value: _NotSet | ReactiveType = getattr(obj, self.internal_name, _NOT_SET) + if isinstance(value, _NotSet): + # No value present, we need to set the default + init_name = f"_default_{self.name}" + default = getattr(obj, init_name) + default_value = default() if callable(default) else default + # Set and return the value + setattr(obj, self.internal_name, default_value) + if self._init: + self._check_watchers(obj, self.name, default_value, first_set=True) + return default_value + return value def __set__(self, obj: Reactable, value: ReactiveType) -> None: name = self.name - current_value = getattr(obj, self.internal_name, None) + current_value = getattr(obj, name) + # Check for validate function validate_function = getattr(obj, f"validate_{name}", None) + # Check if this is the first time setting the value first_set = getattr(obj, f"__first_set_{self.internal_name}", True) + # Call validate, but not on first set. if callable(validate_function) and not first_set: value = validate_function(value) + # If the value has changed, or this is the first time setting the value if current_value != value or first_set: + # Set the first set flag to False setattr(obj, f"__first_set_{self.internal_name}", False) + # Store the internal value setattr(obj, self.internal_name, value) + # Check all watchers self._check_watchers(obj, name, current_value, first_set=first_set) + # Refresh according to descriptor flags if self._layout or self._repaint: obj.refresh(repaint=self._repaint, layout=self._layout) @@ -140,50 +168,77 @@ class Reactive(Generic[ReactiveType]): def _check_watchers( cls, obj: Reactable, name: str, old_value: Any, first_set: bool = False ) -> None: + """Check watchers, and call watch methods / computes + Args: + obj (Reactable): The reactable object. + name (str): Attribute name. + old_value (Any): The old (previous) value of the attribute. + first_set (bool, optional): True if this is the first time setting the value. Defaults to False. + """ + # Get the current value. internal_name = f"_reactive_{name}" value = getattr(obj, internal_name) async def update_watcher( obj: Reactable, watch_function: Callable, old_value: Any, value: Any ) -> None: + """Call watch function, and run compute. + + Args: + obj (Reactable): Reactable object. + watch_function (Callable): Watch method. + old_value (Any): Old value. + value (Any): new value. + """ _rich_traceback_guard = True + # Call watch with one or two parameters if count_parameters(watch_function) == 2: watch_result = watch_function(old_value, value) else: watch_result = watch_function(value) + # Optionally await result if isawaitable(watch_result): await watch_result + # Run computes await Reactive._compute(obj) + # Check for watch method watch_function = getattr(obj, f"watch_{name}", None) if callable(watch_function): + # Post a callback message, so we can call the watch method in an orderly async manner obj.post_message_no_wait( events.Callback( - obj, + sender=obj, callback=partial( update_watcher, obj, watch_function, old_value, value ), ) ) + # Check for watchers set via `watch` watcher_name = f"__{name}_watchers" watchers = getattr(obj, watcher_name, ()) for watcher in watchers: obj.post_message_no_wait( events.Callback( - obj, + sender=obj, callback=partial(update_watcher, obj, watcher, old_value, value), ) ) - if not first_set: - obj.post_message_no_wait( - events.Callback(obj, callback=partial(Reactive._compute, obj)) - ) + # Run computes + obj.post_message_no_wait( + events.Callback(sender=obj, callback=partial(Reactive._compute, obj)) + ) @classmethod async def _compute(cls, obj: Reactable) -> None: + """Invoke all computes. + + Args: + obj (Reactable): Reactable object. + """ _rich_traceback_guard = True computes = getattr(obj, "__computes", []) for compute in computes: diff --git a/src/textual/screen.py b/src/textual/screen.py index c3a06d6c0..d32b2dda5 100644 --- a/src/textual/screen.py +++ b/src/textual/screen.py @@ -387,12 +387,12 @@ class Screen(Widget): def _on_screen_resume(self) -> None: """Called by the App""" - size = self.app.size self._refresh_layout(size, full=True) async def _on_resize(self, event: events.Resize) -> None: event.stop() + self._screen_resized(event.size) async def _handle_mouse_move(self, event: events.MouseMove) -> None: try: diff --git a/src/textual/widget.py b/src/textual/widget.py index 23858b64e..31f6209f7 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -1,10 +1,20 @@ from __future__ import annotations -from asyncio import Lock +from asyncio import Lock, wait, create_task from fractions import Fraction from itertools import islice from operator import attrgetter -from typing import TYPE_CHECKING, ClassVar, Collection, Iterable, NamedTuple, cast +from typing import ( + Awaitable, + Generator, + TYPE_CHECKING, + ClassVar, + Collection, + Iterable, + NamedTuple, + Sequence, + cast, +) import rich.repr from rich.console import ( @@ -59,6 +69,28 @@ _JUSTIFY_MAP: dict[str, JustifyMethod] = { } +class AwaitMount: + """An awaitable returned by mount() and mount_all(). + + Example: + await self.mount(Static("foo")) + + """ + + def __init__(self, widgets: Sequence[Widget]) -> None: + self._widgets = widgets + + def __await__(self) -> Generator[None, None, None]: + async def await_mount() -> None: + aws = [ + create_task(widget._mounted_event.wait()) for widget in self._widgets + ] + if aws: + await wait(aws) + + return await_mount().__await__() + + class _Styled: """Apply a style to a renderable. @@ -316,7 +348,7 @@ class Widget(DOMNode): """Clear arrangement cache, forcing a new arrange operation.""" self._arrangement = None - def mount(self, *anon_widgets: Widget, **widgets: Widget) -> None: + def mount(self, *anon_widgets: Widget, **widgets: Widget) -> AwaitMount: """Mount child widgets (making this widget a container). Widgets may be passed as positional arguments or keyword arguments. If keyword arguments, @@ -327,9 +359,12 @@ class Widget(DOMNode): self.mount(Static("hello"), header=Header()) ``` + Returns: + AwaitMount: An awaitable that waits for widgets to be mounted. + """ - self.app._register(self, *anon_widgets, **widgets) - self.app.screen.refresh(layout=True) + mounted_widgets = self.app._register(self, *anon_widgets, **widgets) + return AwaitMount(mounted_widgets) def compose(self) -> ComposeResult: """Called by Textual to create child widgets. @@ -1812,7 +1847,19 @@ class Widget(DOMNode): scroll_visible (bool, optional): Scroll parent to make this widget visible. Defaults to True. """ - self.screen.set_focus(self, scroll_visible=scroll_visible) + + def set_focus(widget: Widget): + """Callback to set the focus.""" + widget.screen.set_focus(self, scroll_visible=scroll_visible) + + self.app.call_later(set_focus, self) + + def reset_focus(self) -> None: + """Reset the focus (move it to the next available widget).""" + try: + self.screen._reset_focus(self) + except NoScreen: + pass def capture_mouse(self, capture: bool = True) -> None: """Capture (or release) the mouse. @@ -1857,10 +1904,11 @@ class Widget(DOMNode): await self.action(binding.action) return True + async def _on_compose(self, event: events.Compose) -> None: + widgets = list(self.compose()) + await self.mount(*widgets) + def _on_mount(self, event: events.Mount) -> None: - widgets = self.compose() - self.mount(*widgets) - # Preset scrollbars if not automatic if self.styles.overflow_y == "scroll": self.show_vertical_scrollbar = True if self.styles.overflow_x == "scroll": @@ -1932,7 +1980,7 @@ class Widget(DOMNode): def _on_hide(self, event: events.Hide) -> None: if self.has_focus: - self.screen._reset_focus(self) + self.reset_focus() def _on_scroll_to_region(self, message: messages.ScrollToRegion) -> None: self.scroll_to_region(message.region, animate=True)