diff --git a/src/textual/reactive.py b/src/textual/reactive.py index 2ee3ba5df..6e281e2df 100644 --- a/src/textual/reactive.py +++ b/src/textual/reactive.py @@ -2,7 +2,16 @@ from __future__ import annotations from functools import partial from inspect import isawaitable -from typing import TYPE_CHECKING, Any, Callable, Generic, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + Type, + TypeVar, + Union, +) from . import events from ._callback import count_parameters, invoke @@ -146,6 +155,7 @@ class Reactive(Generic[ReactiveType]): setattr(owner, f"_default_{name}", default) def __get__(self, obj: Reactable, obj_type: type[object]) -> ReactiveType: + _rich_traceback_omit = True value: _NotSet | ReactiveType = getattr(obj, self.internal_name, _NOT_SET) if isinstance(value, _NotSet): # No value present, we need to set the default @@ -160,6 +170,7 @@ class Reactive(Generic[ReactiveType]): return value def __set__(self, obj: Reactable, value: ReactiveType) -> None: + _rich_traceback_omit = True name = self.name current_value = getattr(obj, name) # Check for validate function @@ -193,55 +204,50 @@ class Reactive(Generic[ReactiveType]): old_value (Any): The old (previous) value of the attribute. first_set (bool, optional): True if this is the first time setting the value. Defaults to False. """ + _rich_traceback_omit = True # Get the current value. internal_name = f"_reactive_{name}" value = getattr(obj, internal_name) - async def update_watcher( - obj: Reactable, watch_function: Callable, old_value: Any, value: Any + async def await_watcher(awaitable: Awaitable) -> None: + """Coroutine to await an awaitable returned from a watcher""" + _rich_traceback_omit = True + await awaitable + # Watcher may have changed the state, so run compute again + obj.post_message_no_wait( + events.Callback(sender=obj, callback=partial(Reactive._compute, obj)) + ) + + def invoke_watcher( + watch_function: Callable, old_value: object, value: object ) -> None: - """Call watch function, and run compute. + """Invoke a watch function. Args: - obj (Reactable): Reactable object. - watch_function (Callable): Watch method. - old_value (Any): Old value. - value (Any): new value. + watch_function (Callable): A watch function, which may be sync or async. + old_value (object): The old value of the attribute. + value (object): The new value of the attribute. """ - _rich_traceback_guard = True - # Call watch with one or two parameters + _rich_traceback_omit = True if count_parameters(watch_function) == 2: watch_result = watch_function(old_value, value) else: watch_result = watch_function(value) - # Optionally await result if isawaitable(watch_result): - await watch_result - # Run computes - await Reactive._compute(obj) + # Result is awaitable, so we need to await it within an async context + obj.post_message_no_wait( + events.Callback( + sender=obj, callback=partial(await_watcher, watch_result) + ) + ) - # Check for watch method watch_function = getattr(obj, f"watch_{name}", None) if callable(watch_function): - # Post a callback message, so we can call the watch method in an orderly async manner - obj.post_message_no_wait( - events.Callback( - sender=obj, - callback=partial( - update_watcher, obj, watch_function, old_value, value - ), - ) - ) + invoke_watcher(watch_function, old_value, value) - # Check for watchers set via `watch` watchers: list[Callable] = getattr(obj, "__watchers", {}).get(name, []) for watcher in watchers: - obj.post_message_no_wait( - events.Callback( - sender=obj, - callback=partial(update_watcher, obj, watcher, old_value, value), - ) - ) + invoke_watcher(watcher, old_value, value) # Run computes obj.post_message_no_wait( @@ -301,10 +307,13 @@ class var(Reactive[ReactiveType]): Args: default (ReactiveType | Callable[[], ReactiveType]): A default value or callable that returns a default. + init (bool, optional): Call watchers on initialize (post mount). Defaults to True. """ - def __init__(self, default: ReactiveType | Callable[[], ReactiveType]) -> None: - super().__init__(default, layout=False, repaint=False, init=True) + def __init__( + self, default: ReactiveType | Callable[[], ReactiveType], init: bool = True + ) -> None: + super().__init__(default, layout=False, repaint=False, init=init) def watch( diff --git a/tests/test_reactive.py b/tests/test_reactive.py new file mode 100644 index 000000000..7158b46f2 --- /dev/null +++ b/tests/test_reactive.py @@ -0,0 +1,26 @@ +from textual.app import App, ComposeResult +from textual.reactive import reactive + + +class WatchApp(App): + + count = reactive(0, init=False) + + test_count = 0 + + def watch_count(self, value: int) -> None: + self.test_count = value + + +async def test_watch(): + """Test that changes to a watched reactive attribute happen immediately.""" + app = WatchApp() + async with app.run_test(): + app.count += 1 + assert app.test_count == 1 + app.count += 1 + assert app.test_count == 2 + app.count -= 1 + assert app.test_count == 1 + app.count -= 1 + assert app.test_count == 0