Merge branch 'mount-wait' into roadmap

This commit is contained in:
Will McGugan
2022-10-16 13:51:25 +01:00
15 changed files with 275 additions and 82 deletions

View File

@@ -183,7 +183,7 @@ Let's look at an example which looks up word definitions from an [api](https://d
=== "Output" === "Output"
```{.textual path="docs/examples/events/dictionary.py" press="tab,t,e,x,t,_,_,_,_,_,_,_,_,_,_,_"} ```{.textual path="docs/examples/events/dictionary.py" press="t,e,x,t,_,_,_,_,_,_,_,_,_,_,_"}
``` ```
Note the highlighted line in the above code which calls `asyncio.create_task` to run a coroutine in the background. Without this you would find typing in to the text box to be unresponsive. Note the highlighted line in the above code which calls `asyncio.create_task` to run a coroutine in the background. Without this you would find typing in to the text box to be unresponsive.

View File

@@ -70,7 +70,7 @@ Textual is a framework for building applications that run within your terminal.
``` ```
```{.textual path="docs/examples/events/dictionary.py" columns="100" lines="30" press="tab,_,t,e,x,t,_,_,_,_,_,_,_,_,_,_,_,_,_"} ```{.textual path="docs/examples/events/dictionary.py" columns="100" lines="30" press="_,t,e,x,t,_,_,_,_,_,_,_,_,_,_,_,_,_"}
``` ```

View File

@@ -30,7 +30,7 @@ class CodeBrowser(App):
def watch_show_tree(self, show_tree: bool) -> None: def watch_show_tree(self, show_tree: bool) -> None:
"""Called when show_tree is modified.""" """Called when show_tree is modified."""
self.set_class(show_tree, "-show-tree") self.set_class(show_tree, "-show-tree")
def compose(self) -> ComposeResult: def compose(self) -> ComposeResult:
"""Compose our UI.""" """Compose our UI."""

View File

@@ -15,7 +15,7 @@ Input {
#results-container { #results-container {
background: $background 50%; background: $background 50%;
margin: 0; margin: 0 0 1 0;
height: 100%; height: 100%;
overflow: hidden auto; overflow: hidden auto;
border: tall $background; border: tall $background;

View File

@@ -23,6 +23,11 @@ class DictionaryApp(App):
yield Input(placeholder="Search for a word") yield Input(placeholder="Search for a word")
yield Content(Static(id="results"), id="results-container") yield Content(Static(id="results"), id="results-container")
def on_mount(self) -> None:
"""Called when app starts."""
# Give the input focus, so we can start typing straight away
self.query_one(Input).focus()
async def on_input_changed(self, message: Input.Changed) -> None: async def on_input_changed(self, message: Input.Changed) -> None:
"""A coroutine to handle a text changed message.""" """A coroutine to handle a text changed message."""
if message.value: if message.value:

35
sandbox/will/mount.py Normal file
View File

@@ -0,0 +1,35 @@
from textual.app import App, ComposeResult
from textual.containers import Container
from textual.widget import Widget
from textual.widgets import Static
class MountWidget(Widget):
def on_mount(self) -> None:
print("Widget mounted")
class MountContainer(Container):
def compose(self) -> ComposeResult:
yield Container(MountWidget(id="bar"))
def on_mount(self) -> None:
bar = self.query_one("#bar")
print("MountContainer got", bar)
class MountApp(App):
def compose(self) -> ComposeResult:
yield MountContainer(id="foo")
def on_mount(self) -> None:
foo = self.query_one("#foo")
print("foo is", foo)
static = self.query_one("#bar")
print("App got", static)
if __name__ == "__main__":
app = MountApp()
app.run()

View File

@@ -7,6 +7,9 @@ class ModalScreen(Screen):
yield Pretty(self.app.screen_stack) yield Pretty(self.app.screen_stack)
yield Footer() yield Footer()
def on_mount(self) -> None:
pretty = self.query_one("Pretty")
def on_screen_resume(self): def on_screen_resume(self):
self.query_one(Pretty).update(self.app.screen_stack) self.query_one(Pretty).update(self.app.screen_stack)

View File

@@ -37,7 +37,7 @@ from .css.stylesheet import Stylesheet
from .design import ColorSystem from .design import ColorSystem
from .devtools.client import DevtoolsClient, DevtoolsConnectionError, DevtoolsLog from .devtools.client import DevtoolsClient, DevtoolsConnectionError, DevtoolsLog
from .devtools.redirect_output import StdoutRedirector from .devtools.redirect_output import StdoutRedirector
from .dom import DOMNode, NoScreen from .dom import DOMNode
from .driver import Driver from .driver import Driver
from .drivers.headless_driver import HeadlessDriver from .drivers.headless_driver import HeadlessDriver
from .features import FeatureFlag, parse_features from .features import FeatureFlag, parse_features
@@ -47,7 +47,7 @@ from .messages import CallbackType
from .reactive import Reactive from .reactive import Reactive
from .renderables.blank import Blank from .renderables.blank import Blank
from .screen import Screen from .screen import Screen
from .widget import Widget from .widget import AwaitMount, Widget
PLATFORM = platform.system() PLATFORM = platform.system()
WINDOWS = PLATFORM == "Windows" WINDOWS = PLATFORM == "Windows"
@@ -144,6 +144,10 @@ class App(Generic[ReturnType], DOMNode):
_BASE_PATH: str | None = None _BASE_PATH: str | None = None
CSS_PATH: CSSPathType = None CSS_PATH: CSSPathType = None
title: Reactive[str] = Reactive("Textual")
sub_title: Reactive[str] = Reactive("")
dark: Reactive[bool] = Reactive(True)
def __init__( def __init__(
self, self,
driver_class: Type[Driver] | None = None, driver_class: Type[Driver] | None = None,
@@ -228,10 +232,6 @@ class App(Generic[ReturnType], DOMNode):
) )
self._screenshot: str | None = None self._screenshot: str | None = None
title: Reactive[str] = Reactive("Textual")
sub_title: Reactive[str] = Reactive("")
dark: Reactive[bool] = Reactive(True)
def animate( def animate(
self, self,
attribute: str, attribute: str,
@@ -696,21 +696,24 @@ class App(Generic[ReturnType], DOMNode):
self._require_stylesheet_update.add(self.screen if node is None else node) self._require_stylesheet_update.add(self.screen if node is None else node)
self.check_idle() self.check_idle()
def mount(self, *anon_widgets: Widget, **widgets: Widget) -> None: def mount(self, *anon_widgets: Widget, **widgets: Widget) -> AwaitMount:
"""Mount widgets. Widgets specified as positional args, or keywords args. If supplied """Mount widgets. Widgets specified as positional args, or keywords args. If supplied
as keyword args they will be assigned an id of the key. as keyword args they will be assigned an id of the key.
""" """
self._register(self.screen, *anon_widgets, **widgets) mounted_widgets = self._register(self.screen, *anon_widgets, **widgets)
return AwaitMount(mounted_widgets)
def mount_all(self, widgets: Iterable[Widget]) -> None: def mount_all(self, widgets: Iterable[Widget]) -> AwaitMount:
"""Mount widgets from an iterable. """Mount widgets from an iterable.
Args: Args:
widgets (Iterable[Widget]): An iterable of widgets. widgets (Iterable[Widget]): An iterable of widgets.
""" """
for widget in widgets: mounted_widgets = list(widgets)
for widget in mounted_widgets:
self._register(self.screen, widget) self._register(self.screen, widget)
return AwaitMount(mounted_widgets)
def is_screen_installed(self, screen: Screen | str) -> bool: def is_screen_installed(self, screen: Screen | str) -> bool:
"""Check if a given screen has been installed. """Check if a given screen has been installed.
@@ -1008,23 +1011,36 @@ class App(Generic[ReturnType], DOMNode):
self.set_interval(0.25, self.css_monitor, name="css monitor") self.set_interval(0.25, self.css_monitor, name="css monitor")
self.log.system("[b green]STARTED[/]", self.css_monitor) self.log.system("[b green]STARTED[/]", self.css_monitor)
process_messages = super()._process_messages
async def run_process_messages(): async def run_process_messages():
compose_event = events.Compose(sender=self)
await self._dispatch_message(compose_event)
mount_event = events.Mount(sender=self)
await self._dispatch_message(mount_event)
Reactive.initialize_object(self) try:
await self._dispatch_message(events.Compose(sender=self))
await self._dispatch_message(events.Mount(sender=self))
finally:
self._mounted_event.set()
Reactive._initialize_object(self)
self.title = self._title self.title = self._title
self.stylesheet.update(self) self.stylesheet.update(self)
self.refresh() self.refresh()
await self.animator.start() await self.animator.start()
await self._ready() await self._ready()
if ready_callback is not None: if ready_callback is not None:
await ready_callback() await ready_callback()
await process_messages()
self._running = True
try:
await self._process_messages_loop()
except asyncio.CancelledError:
pass
finally:
self._running = False
for timer in list(self._timers):
await timer.stop()
await self.animator.stop() await self.animator.stop()
await self._close_all() await self._close_all()
@@ -1059,6 +1075,9 @@ class App(Generic[ReturnType], DOMNode):
if self.devtools.is_connected: if self.devtools.is_connected:
await self._disconnect_devtools() await self._disconnect_devtools()
async def _pre_process(self) -> None:
pass
async def _ready(self) -> None: async def _ready(self) -> None:
"""Called immediately prior to processing messages. """Called immediately prior to processing messages.
@@ -1083,9 +1102,9 @@ class App(Generic[ReturnType], DOMNode):
self.set_timer(screenshot_timer, on_screenshot, name="screenshot timer") self.set_timer(screenshot_timer, on_screenshot, name="screenshot timer")
def _on_compose(self) -> None: async def _on_compose(self) -> None:
widgets = self.compose() widgets = list(self.compose())
self.mount_all(widgets) await self.mount_all(widgets)
def _on_idle(self) -> None: def _on_idle(self) -> None:
"""Perform actions when there are no messages in the queue.""" """Perform actions when there are no messages in the queue."""
@@ -1110,15 +1129,15 @@ class App(Generic[ReturnType], DOMNode):
def _register( def _register(
self, parent: DOMNode, *anon_widgets: Widget, **widgets: Widget self, parent: DOMNode, *anon_widgets: Widget, **widgets: Widget
) -> None: ) -> list[Widget]:
"""Mount widget(s) so they may receive events. """Mount widget(s) so they may receive events.
Args: Args:
parent (Widget): Parent Widget parent (Widget): Parent Widget
""" """
if not anon_widgets and not widgets: if not anon_widgets and not widgets:
return return []
name_widgets: Iterable[tuple[str | None, Widget]] name_widgets: list[tuple[str | None, Widget]]
name_widgets = [*((None, widget) for widget in anon_widgets), *widgets.items()] name_widgets = [*((None, widget) for widget in anon_widgets), *widgets.items()]
apply_stylesheet = self.stylesheet.apply apply_stylesheet = self.stylesheet.apply
@@ -1133,8 +1152,8 @@ class App(Generic[ReturnType], DOMNode):
self._register(widget, *widget.children) self._register(widget, *widget.children)
apply_stylesheet(widget) apply_stylesheet(widget)
for _widget_id, widget in name_widgets: registered_widgets = [widget for _, widget in name_widgets]
widget.post_message_no_wait(events.Mount(sender=parent)) return registered_widgets
def _unregister(self, widget: Widget) -> None: def _unregister(self, widget: Widget) -> None:
"""Unregister a widget. """Unregister a widget.
@@ -1142,11 +1161,7 @@ class App(Generic[ReturnType], DOMNode):
Args: Args:
widget (Widget): A Widget to unregister widget (Widget): A Widget to unregister
""" """
try: widget.reset_focus()
widget.screen._reset_focus(widget)
except NoScreen:
pass
if isinstance(widget._parent, Widget): if isinstance(widget._parent, Widget):
widget._parent.children._remove(widget) widget._parent.children._remove(widget)
widget._detach() widget._detach()
@@ -1164,9 +1179,16 @@ class App(Generic[ReturnType], DOMNode):
""" """
widget._attach(parent) widget._attach(parent)
widget._start_messages() widget._start_messages()
widget.post_message_no_wait(events.Mount(sender=parent))
def is_mounted(self, widget: Widget) -> bool: def is_mounted(self, widget: Widget) -> bool:
"""Check if a widget is mounted.
Args:
widget (Widget): A widget.
Returns:
bool: True of the widget is mounted.
"""
return widget in self._registry return widget in self._registry
async def _close_all(self) -> None: async def _close_all(self) -> None:
@@ -1388,22 +1410,22 @@ class App(Generic[ReturnType], DOMNode):
async def _on_resize(self, event: events.Resize) -> None: async def _on_resize(self, event: events.Resize) -> None:
event.stop() event.stop()
self.screen._screen_resized(event.size)
await self.screen.post_message(event) await self.screen.post_message(event)
async def _on_remove(self, event: events.Remove) -> None: async def _on_remove(self, event: events.Remove) -> None:
widget = event.widget widget = event.widget
parent = widget.parent parent = widget.parent
if parent is not None:
parent.refresh(layout=True) widget.reset_focus()
remove_widgets = widget.walk_children( remove_widgets = widget.walk_children(
Widget, with_self=True, method="depth", reverse=True Widget, with_self=True, method="depth", reverse=True
) )
for child in remove_widgets:
self._unregister(child)
for child in remove_widgets: for child in remove_widgets:
await child._close_messages() await child._close_messages()
self._unregister(child)
if parent is not None:
parent.refresh(layout=True)
async def action_press(self, key: str) -> None: async def action_press(self, key: str) -> None:
await self.press(key) await self.press(key)

View File

@@ -68,9 +68,16 @@ class ColorsApp(App):
BINDINGS = [("d", "toggle_dark", "Toggle dark mode")] BINDINGS = [("d", "toggle_dark", "Toggle dark mode")]
def compose(self) -> ComposeResult: def compose(self) -> ComposeResult:
yield Content(ColorButtons(), ColorsView()) yield Content(ColorButtons())
yield Footer() yield Footer()
def on_mount(self) -> None:
self.call_later(self.update_view)
def update_view(self) -> None:
content = self.query_one("Content", Content)
content.mount(ColorsView())
def on_button_pressed(self, event: Button.Pressed) -> None: def on_button_pressed(self, event: Button.Pressed) -> None:
self.query(ColorGroup).remove_class("-active") self.query(ColorGroup).remove_class("-active")
group = self.query_one(f"#group-{event.button.id}", ColorGroup) group = self.query_one(f"#group-{event.button.id}", ColorGroup)

View File

@@ -249,7 +249,7 @@ class Stylesheet:
css = css_file.read() css = css_file.read()
path = os.path.abspath(filename) path = os.path.abspath(filename)
except Exception as error: except Exception as error:
raise StylesheetError(f"unable to read {filename!r}; {error}") raise StylesheetError(f"unable to read CSS file {filename!r}") from None
self.source[str(path)] = CssSource(css, False, 0) self.source[str(path)] = CssSource(css, False, 0)
self._require_parse = True self._require_parse = True

View File

@@ -662,7 +662,7 @@ class DOMNode(MessagePump):
Defaults to None. Defaults to None.
with_self (bool, optional): Also yield self in addition to descendants. Defaults to True. with_self (bool, optional): Also yield self in addition to descendants. Defaults to True.
method (Literal["breadth", "depth"], optional): One of "depth" or "breadth". Defaults to "depth". method (Literal["breadth", "depth"], optional): One of "depth" or "breadth". Defaults to "depth".
reverse (bool, optional): Reverse the order (bottom up). Defaults to False reverse (bool, optional): Reverse the order (bottom up). Defaults to False.
Returns: Returns:
Iterable[DOMNode | WalkType]: An iterable of nodes. Iterable[DOMNode | WalkType]: An iterable of nodes.

View File

@@ -73,6 +73,7 @@ class MessagePump(metaclass=MessagePumpMeta):
self._timers: WeakSet[Timer] = WeakSet() self._timers: WeakSet[Timer] = WeakSet()
self._last_idle: float = time() self._last_idle: float = time()
self._max_idle: float | None = None self._max_idle: float | None = None
self._mounted_event = asyncio.Event()
@property @property
def task(self) -> Task: def task(self) -> Task:
@@ -278,18 +279,19 @@ class MessagePump(metaclass=MessagePumpMeta):
await timer.stop() await timer.stop()
self._timers.clear() self._timers.clear()
await self._message_queue.put(None) await self._message_queue.put(None)
if self._task is not None and asyncio.current_task() != self._task: if self._task is not None and asyncio.current_task() != self._task:
# Ensure everything is closed before returning # Ensure everything is closed before returning
await self._task await self._task
def _start_messages(self) -> None: def _start_messages(self) -> None:
"""Start messages task.""" """Start messages task."""
Reactive.initialize_object(self)
self._task = asyncio.create_task(self._process_messages()) self._task = asyncio.create_task(self._process_messages())
async def _process_messages(self) -> None: async def _process_messages(self) -> None:
self._running = True self._running = True
await self._pre_process()
try: try:
await self._process_messages_loop() await self._process_messages_loop()
except CancelledError: except CancelledError:
@@ -299,6 +301,18 @@ class MessagePump(metaclass=MessagePumpMeta):
for timer in list(self._timers): for timer in list(self._timers):
await timer.stop() await timer.stop()
async def _pre_process(self) -> None:
"""Procedure to run before processing messages."""
# Dispatch compose and mount messages without going through loop
# These events must occur in this order, and at the start.
try:
await self._dispatch_message(events.Compose(sender=self))
await self._dispatch_message(events.Mount(sender=self))
finally:
# This is critical, mount may be waiting
self._mounted_event.set()
Reactive._initialize_object(self)
async def _process_messages_loop(self) -> None: async def _process_messages_loop(self) -> None:
"""Process messages until the queue is closed.""" """Process messages until the queue is closed."""
_rich_traceback_guard = True _rich_traceback_guard = True
@@ -331,11 +345,15 @@ class MessagePump(metaclass=MessagePumpMeta):
except CancelledError: except CancelledError:
raise raise
except Exception as error: except Exception as error:
self._mounted_event.set()
self.app._handle_exception(error) self.app._handle_exception(error)
break break
finally: finally:
self._message_queue.task_done() self._message_queue.task_done()
current_time = time() current_time = time()
# Insert idle events
if self._message_queue.empty() or ( if self._message_queue.empty() or (
self._max_idle is not None self._max_idle is not None
and current_time - self._last_idle > self._max_idle and current_time - self._last_idle > self._max_idle

View File

@@ -20,6 +20,12 @@ if TYPE_CHECKING:
ReactiveType = TypeVar("ReactiveType") ReactiveType = TypeVar("ReactiveType")
class _NotSet:
pass
_NOT_SET = _NotSet()
T = TypeVar("T") T = TypeVar("T")
@@ -83,24 +89,31 @@ class Reactive(Generic[ReactiveType]):
return cls(default, layout=False, repaint=False, init=True) return cls(default, layout=False, repaint=False, init=True)
@classmethod @classmethod
def initialize_object(cls, obj: object) -> None: def _initialize_object(cls, obj: object) -> None:
"""Call any watchers / computes for the first time. """Set defaults and call any watchers / computes for the first time.
Args: Args:
obj (Reactable): An object with Reactive descriptors obj (Reactable): An object with Reactive descriptors
""" """
if not hasattr(obj, "__reactive_initialized"):
startswith = str.startswith startswith = str.startswith
for key in obj.__class__.__dict__.keys(): for key in obj.__class__.__dict__:
if startswith(key, "_init_"): if startswith(key, "_default_"):
name = key[6:] name = key[9:]
if not hasattr(obj, name): # Check defaults
default = getattr(obj, key) if not hasattr(obj, name):
setattr(obj, name, default() if callable(default) else default) # Attribute has no value yet
default = getattr(obj, key)
default_value = default() if callable(default) else default
# Set the default vale (calls `__set__`)
setattr(obj, name, default_value)
setattr(obj, "__reactive_initialized", True)
def __set_name__(self, owner: Type[MessageTarget], name: str) -> None: def __set_name__(self, owner: Type[MessageTarget], name: str) -> None:
# Check for compute method
if hasattr(owner, f"compute_{name}"): if hasattr(owner, f"compute_{name}"):
# Compute methods are stored in a list called `__computes`
try: try:
computes = getattr(owner, "__computes") computes = getattr(owner, "__computes")
except AttributeError: except AttributeError:
@@ -108,31 +121,46 @@ class Reactive(Generic[ReactiveType]):
setattr(owner, "__computes", computes) setattr(owner, "__computes", computes)
computes.append(name) computes.append(name)
# The name of the attribute
self.name = name self.name = name
# The internal name where the attribute's value is stored
self.internal_name = f"_reactive_{name}" self.internal_name = f"_reactive_{name}"
default = self._default default = self._default
setattr(owner, f"_default_{name}", default)
if self._init:
setattr(owner, f"_init_{name}", default)
else:
setattr(
owner, self.internal_name, default() if callable(default) else default
)
def __get__(self, obj: Reactable, obj_type: type[object]) -> ReactiveType: def __get__(self, obj: Reactable, obj_type: type[object]) -> ReactiveType:
return getattr(obj, self.internal_name) value: _NotSet | ReactiveType = getattr(obj, self.internal_name, _NOT_SET)
if isinstance(value, _NotSet):
# No value present, we need to set the default
init_name = f"_default_{self.name}"
default = getattr(obj, init_name)
default_value = default() if callable(default) else default
# Set and return the value
setattr(obj, self.internal_name, default_value)
if self._init:
self._check_watchers(obj, self.name, default_value, first_set=True)
return default_value
return value
def __set__(self, obj: Reactable, value: ReactiveType) -> None: def __set__(self, obj: Reactable, value: ReactiveType) -> None:
name = self.name name = self.name
current_value = getattr(obj, self.internal_name, None) current_value = getattr(obj, name)
# Check for validate function
validate_function = getattr(obj, f"validate_{name}", None) validate_function = getattr(obj, f"validate_{name}", None)
# Check if this is the first time setting the value
first_set = getattr(obj, f"__first_set_{self.internal_name}", True) first_set = getattr(obj, f"__first_set_{self.internal_name}", True)
# Call validate, but not on first set.
if callable(validate_function) and not first_set: if callable(validate_function) and not first_set:
value = validate_function(value) value = validate_function(value)
# If the value has changed, or this is the first time setting the value
if current_value != value or first_set: if current_value != value or first_set:
# Set the first set flag to False
setattr(obj, f"__first_set_{self.internal_name}", False) setattr(obj, f"__first_set_{self.internal_name}", False)
# Store the internal value
setattr(obj, self.internal_name, value) setattr(obj, self.internal_name, value)
# Check all watchers
self._check_watchers(obj, name, current_value, first_set=first_set) self._check_watchers(obj, name, current_value, first_set=first_set)
# Refresh according to descriptor flags
if self._layout or self._repaint: if self._layout or self._repaint:
obj.refresh(repaint=self._repaint, layout=self._layout) obj.refresh(repaint=self._repaint, layout=self._layout)
@@ -140,50 +168,77 @@ class Reactive(Generic[ReactiveType]):
def _check_watchers( def _check_watchers(
cls, obj: Reactable, name: str, old_value: Any, first_set: bool = False cls, obj: Reactable, name: str, old_value: Any, first_set: bool = False
) -> None: ) -> None:
"""Check watchers, and call watch methods / computes
Args:
obj (Reactable): The reactable object.
name (str): Attribute name.
old_value (Any): The old (previous) value of the attribute.
first_set (bool, optional): True if this is the first time setting the value. Defaults to False.
"""
# Get the current value.
internal_name = f"_reactive_{name}" internal_name = f"_reactive_{name}"
value = getattr(obj, internal_name) value = getattr(obj, internal_name)
async def update_watcher( async def update_watcher(
obj: Reactable, watch_function: Callable, old_value: Any, value: Any obj: Reactable, watch_function: Callable, old_value: Any, value: Any
) -> None: ) -> None:
"""Call watch function, and run compute.
Args:
obj (Reactable): Reactable object.
watch_function (Callable): Watch method.
old_value (Any): Old value.
value (Any): new value.
"""
_rich_traceback_guard = True _rich_traceback_guard = True
# Call watch with one or two parameters
if count_parameters(watch_function) == 2: if count_parameters(watch_function) == 2:
watch_result = watch_function(old_value, value) watch_result = watch_function(old_value, value)
else: else:
watch_result = watch_function(value) watch_result = watch_function(value)
# Optionally await result
if isawaitable(watch_result): if isawaitable(watch_result):
await watch_result await watch_result
# Run computes
await Reactive._compute(obj) await Reactive._compute(obj)
# Check for watch method
watch_function = getattr(obj, f"watch_{name}", None) watch_function = getattr(obj, f"watch_{name}", None)
if callable(watch_function): if callable(watch_function):
# Post a callback message, so we can call the watch method in an orderly async manner
obj.post_message_no_wait( obj.post_message_no_wait(
events.Callback( events.Callback(
obj, sender=obj,
callback=partial( callback=partial(
update_watcher, obj, watch_function, old_value, value update_watcher, obj, watch_function, old_value, value
), ),
) )
) )
# Check for watchers set via `watch`
watcher_name = f"__{name}_watchers" watcher_name = f"__{name}_watchers"
watchers = getattr(obj, watcher_name, ()) watchers = getattr(obj, watcher_name, ())
for watcher in watchers: for watcher in watchers:
obj.post_message_no_wait( obj.post_message_no_wait(
events.Callback( events.Callback(
obj, sender=obj,
callback=partial(update_watcher, obj, watcher, old_value, value), callback=partial(update_watcher, obj, watcher, old_value, value),
) )
) )
if not first_set: # Run computes
obj.post_message_no_wait( obj.post_message_no_wait(
events.Callback(obj, callback=partial(Reactive._compute, obj)) events.Callback(sender=obj, callback=partial(Reactive._compute, obj))
) )
@classmethod @classmethod
async def _compute(cls, obj: Reactable) -> None: async def _compute(cls, obj: Reactable) -> None:
"""Invoke all computes.
Args:
obj (Reactable): Reactable object.
"""
_rich_traceback_guard = True _rich_traceback_guard = True
computes = getattr(obj, "__computes", []) computes = getattr(obj, "__computes", [])
for compute in computes: for compute in computes:

View File

@@ -387,12 +387,12 @@ class Screen(Widget):
def _on_screen_resume(self) -> None: def _on_screen_resume(self) -> None:
"""Called by the App""" """Called by the App"""
size = self.app.size size = self.app.size
self._refresh_layout(size, full=True) self._refresh_layout(size, full=True)
async def _on_resize(self, event: events.Resize) -> None: async def _on_resize(self, event: events.Resize) -> None:
event.stop() event.stop()
self._screen_resized(event.size)
async def _handle_mouse_move(self, event: events.MouseMove) -> None: async def _handle_mouse_move(self, event: events.MouseMove) -> None:
try: try:

View File

@@ -1,10 +1,20 @@
from __future__ import annotations from __future__ import annotations
from asyncio import Lock from asyncio import Lock, wait, create_task
from fractions import Fraction from fractions import Fraction
from itertools import islice from itertools import islice
from operator import attrgetter from operator import attrgetter
from typing import TYPE_CHECKING, ClassVar, Collection, Iterable, NamedTuple, cast from typing import (
Awaitable,
Generator,
TYPE_CHECKING,
ClassVar,
Collection,
Iterable,
NamedTuple,
Sequence,
cast,
)
import rich.repr import rich.repr
from rich.console import ( from rich.console import (
@@ -59,6 +69,28 @@ _JUSTIFY_MAP: dict[str, JustifyMethod] = {
} }
class AwaitMount:
"""An awaitable returned by mount() and mount_all().
Example:
await self.mount(Static("foo"))
"""
def __init__(self, widgets: Sequence[Widget]) -> None:
self._widgets = widgets
def __await__(self) -> Generator[None, None, None]:
async def await_mount() -> None:
aws = [
create_task(widget._mounted_event.wait()) for widget in self._widgets
]
if aws:
await wait(aws)
return await_mount().__await__()
class _Styled: class _Styled:
"""Apply a style to a renderable. """Apply a style to a renderable.
@@ -316,7 +348,7 @@ class Widget(DOMNode):
"""Clear arrangement cache, forcing a new arrange operation.""" """Clear arrangement cache, forcing a new arrange operation."""
self._arrangement = None self._arrangement = None
def mount(self, *anon_widgets: Widget, **widgets: Widget) -> None: def mount(self, *anon_widgets: Widget, **widgets: Widget) -> AwaitMount:
"""Mount child widgets (making this widget a container). """Mount child widgets (making this widget a container).
Widgets may be passed as positional arguments or keyword arguments. If keyword arguments, Widgets may be passed as positional arguments or keyword arguments. If keyword arguments,
@@ -327,9 +359,12 @@ class Widget(DOMNode):
self.mount(Static("hello"), header=Header()) self.mount(Static("hello"), header=Header())
``` ```
Returns:
AwaitMount: An awaitable that waits for widgets to be mounted.
""" """
self.app._register(self, *anon_widgets, **widgets) mounted_widgets = self.app._register(self, *anon_widgets, **widgets)
self.app.screen.refresh(layout=True) return AwaitMount(mounted_widgets)
def compose(self) -> ComposeResult: def compose(self) -> ComposeResult:
"""Called by Textual to create child widgets. """Called by Textual to create child widgets.
@@ -1812,7 +1847,19 @@ class Widget(DOMNode):
scroll_visible (bool, optional): Scroll parent to make this widget scroll_visible (bool, optional): Scroll parent to make this widget
visible. Defaults to True. visible. Defaults to True.
""" """
self.screen.set_focus(self, scroll_visible=scroll_visible)
def set_focus(widget: Widget):
"""Callback to set the focus."""
widget.screen.set_focus(self, scroll_visible=scroll_visible)
self.app.call_later(set_focus, self)
def reset_focus(self) -> None:
"""Reset the focus (move it to the next available widget)."""
try:
self.screen._reset_focus(self)
except NoScreen:
pass
def capture_mouse(self, capture: bool = True) -> None: def capture_mouse(self, capture: bool = True) -> None:
"""Capture (or release) the mouse. """Capture (or release) the mouse.
@@ -1857,10 +1904,11 @@ class Widget(DOMNode):
await self.action(binding.action) await self.action(binding.action)
return True return True
async def _on_compose(self, event: events.Compose) -> None:
widgets = list(self.compose())
await self.mount(*widgets)
def _on_mount(self, event: events.Mount) -> None: def _on_mount(self, event: events.Mount) -> None:
widgets = self.compose()
self.mount(*widgets)
# Preset scrollbars if not automatic
if self.styles.overflow_y == "scroll": if self.styles.overflow_y == "scroll":
self.show_vertical_scrollbar = True self.show_vertical_scrollbar = True
if self.styles.overflow_x == "scroll": if self.styles.overflow_x == "scroll":
@@ -1932,7 +1980,7 @@ class Widget(DOMNode):
def _on_hide(self, event: events.Hide) -> None: def _on_hide(self, event: events.Hide) -> None:
if self.has_focus: if self.has_focus:
self.screen._reset_focus(self) self.reset_focus()
def _on_scroll_to_region(self, message: messages.ScrollToRegion) -> None: def _on_scroll_to_region(self, message: messages.ScrollToRegion) -> None:
self.scroll_to_region(message.region, animate=True) self.scroll_to_region(message.region, animate=True)