From c3129c8331b8a1b341079ed1596ce4969dc72d2b Mon Sep 17 00:00:00 2001 From: Will McGugan Date: Fri, 13 Jan 2023 17:22:52 +0000 Subject: [PATCH] fix inheritance --- src/textual/dom.py | 13 +++++++++ src/textual/reactive.py | 59 ++++++++++++++++----------------------- tests/test_reactive.py | 62 ++++++++++++++++++++++++++++++++++------- 3 files changed, 89 insertions(+), 45 deletions(-) diff --git a/src/textual/dom.py b/src/textual/dom.py index c1152e8b0..33469be24 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -108,6 +108,8 @@ class DOMNode(MessagePump): # Generated list of bindings _merged_bindings: ClassVar[Bindings] | None = None + _reactives: ClassVar[dict[str, Reactive]] + def __init__( self, *, @@ -164,6 +166,17 @@ class DOMNode(MessagePump): cls, inherit_css: bool = True, inherit_bindings: bool = True ) -> None: super().__init_subclass__() + + reactives = cls._reactives = {} + for base in reversed(cls.__mro__): + reactives.update( + { + name: reactive + for name, reactive in base.__dict__.items() + if isinstance(reactive, Reactive) + } + ) + cls._inherit_css = inherit_css cls._inherit_bindings = inherit_bindings css_type_names: set[str] = set() diff --git a/src/textual/reactive.py b/src/textual/reactive.py index 686de95b8..4a65201d0 100644 --- a/src/textual/reactive.py +++ b/src/textual/reactive.py @@ -13,6 +13,8 @@ from typing import ( Union, ) +import rich.repr + from . import events from ._callback import count_parameters, invoke from ._types import MessageTarget @@ -35,6 +37,7 @@ _NOT_SET = _NotSet() T = TypeVar("T") +@rich.repr.auto class Reactive(Generic[ReactiveType]): """Reactive descriptor. @@ -47,6 +50,8 @@ class Reactive(Generic[ReactiveType]): no_compute (bool, optional): Don't run compute methods when attribute is changed. Defaults to False. """ + _reactives: TypeVar[dict[str, object]] = {} + def __init__( self, default: ReactiveType | Callable[[], ReactiveType], @@ -65,6 +70,14 @@ class Reactive(Generic[ReactiveType]): self._no_compute = no_compute self._is_compute = False + def __rich_repr__(self) -> rich.repr.Result: + yield self._default + yield "layout", self._layout + yield "repaint", self._repaint + yield "init", self._init + yield "always_update", self._always_update + yield "no_compute", self._no_compute + @classmethod def init( cls, @@ -111,12 +124,13 @@ class Reactive(Generic[ReactiveType]): """ return cls(default, layout=False, repaint=False, init=True) - def _initialize_reactive(self, obj: Reactable, name: str) -> None: + def _initialize_reactive(self, obj: Reactable, name: str) -> bool: internal_name = f"_reactive_{name}" if hasattr(obj, internal_name): # Attribute already has a value return - if self._is_compute: + compute_method = getattr(obj, f"compute_{name}", None) + if compute_method is not None and self._init: default = getattr(obj, f"compute_{name}")() else: default_or_callable = self._default @@ -136,32 +150,10 @@ class Reactive(Generic[ReactiveType]): Args: obj (Reactable): An object with Reactive descriptors """ - reactives = getattr(obj, "__reactives", {}) - for name, reactive in reactives.items(): + + for name, reactive in obj._reactives.items(): reactive._initialize_reactive(obj, name) - # startswith = str.startswith - # watchers = [] - # reactives = getattr(obj, "__reactives", []) - - # print(reactives) - # for name in reactives.keys(): - # internal_name = f"_reactive_{name}" - # # Check defaults - # if internal_name not in obj.__dict__: - # # Attribute has no value yet - - # for k in obj.__dict__: - # if k.startswith("_default"): - # print(k) - - # default = getattr(obj, f"_default_{name}") - # default_value = default() if callable(default) else default - # # Set the default vale (calls `__set__`) - # obj.__dict__[internal_name] = None - # setattr(obj, name, default_value) - # # watchers.append((name, default_value)) - @classmethod def _reset_object(cls, obj: object) -> None: """Reset reactive structures on object (to avoid reference cycles). @@ -173,9 +165,6 @@ class Reactive(Generic[ReactiveType]): getattr(obj, "__computes", []).clear() def __set_name__(self, owner: Type[MessageTarget], name: str) -> None: - reactives = getattr(owner, "__reactives", {}) - reactives[name] = self - setattr(owner, "__reactives", reactives) # Check for compute method if hasattr(owner, f"compute_{name}"): @@ -198,16 +187,18 @@ class Reactive(Generic[ReactiveType]): def __get__(self, obj: Reactable, obj_type: type[object]) -> ReactiveType: _rich_traceback_omit = True + # Reactive._initialize_object(obj) self._initialize_reactive(obj, self.name) value: _NotSet | ReactiveType - if self._is_compute: + compute_method = getattr(self, f"compute_{self.name}", None) + if compute_method is not None: value = getattr(obj, f"compute_{self.name}")() else: - value = getattr(obj, self.internal_name, _NOT_SET) + value = getattr(obj, self.internal_name) - if not self._no_compute: - self._compute(obj) + # if not self._no_compute: + # self._compute(obj) # if isinstance(value, _NotSet): # # No value present, we need to set the default @@ -307,8 +298,6 @@ class Reactive(Generic[ReactiveType]): else: return True - # Compute is only required if a watcher runs immediately, not if they were posted. - require_compute = False watch_function = getattr(obj, f"watch_{name}", None) if callable(watch_function): invoke_watcher(watch_function, old_value, value) diff --git a/tests/test_reactive.py b/tests/test_reactive.py index 3d03c6d49..eef5a9423 100644 --- a/tests/test_reactive.py +++ b/tests/test_reactive.py @@ -3,7 +3,7 @@ import asyncio import pytest from textual.app import App, ComposeResult -from textual.reactive import reactive, var +from textual.reactive import Reactive, reactive, var from textual.widget import Widget OLD_VALUE = 5_000 @@ -157,15 +157,15 @@ async def test_reactive_always_update(): async def test_reactive_with_callable_default(): """A callable can be supplied as the default value for a reactive. Textual will call it in order to retrieve the default value.""" - called_with_app = None + called_with_app = False def set_called() -> int: nonlocal called_with_app - called_with_app = app + called_with_app = True return OLD_VALUE class ReactiveCallable(App): - value = reactive(set_called) + value = reactive(lambda: 123) watcher_called_with = None def watch_value(self, new_value): @@ -174,12 +174,8 @@ async def test_reactive_with_callable_default(): app = ReactiveCallable() async with app.run_test(): assert ( - app.value == OLD_VALUE + app.value == 123 ) # The value should be set to the return val of the callable - assert ( - called_with_app is app - ) # Ensure the App is passed into the reactive default callable - assert app.watcher_called_with == OLD_VALUE async def test_validate_init_true(): @@ -240,7 +236,7 @@ async def test_reactive_compute_first_time_set(): async def test_reactive_method_call_order(): class CallOrder(App): count = reactive(OLD_VALUE, init=False) - count_times_ten = reactive(OLD_VALUE * 10) + count_times_ten = reactive(OLD_VALUE * 10, init=False) calls = [] def validate_count(self, value: int) -> int: @@ -295,3 +291,49 @@ async def test_premature_reactive_call(): async with app.run_test() as pilot: assert watcher_called app.exit() + + +async def test_reactive_inheritance(): + """Check that inheritance works as expected for reactives.""" + + class Primary(App): + foo = reactive(1) + bar = reactive("bar") + + class Secondary(Primary): + foo = reactive(2) + egg = reactive("egg") + + class Tertiary(Secondary): + baz = reactive("baz") + + from rich import print + + primary = Primary() + secondary = Secondary() + tertiary = Tertiary() + + primary_reactive_count = len(primary._reactives) + + # Secondary adds one new reactive + assert len(secondary._reactives) == primary_reactive_count + 1 + + Reactive._initialize_object(primary) + Reactive._initialize_object(secondary) + Reactive._initialize_object(tertiary) + + # Primary doesn't have egg + with pytest.raises(AttributeError): + assert primary.egg + + # primary has foo of 1 + assert primary.foo == 1 + # secondary has different reactive + assert secondary.foo == 2 + # foo is accessible through tertiary + assert tertiary.foo == 2 + + with pytest.raises(AttributeError): + secondary.baz + + assert tertiary.baz == "baz"