diff --git a/src/textual/reactive.py b/src/textual/reactive.py index 4c7bd1d98..0553d076c 100644 --- a/src/textual/reactive.py +++ b/src/textual/reactive.py @@ -143,24 +143,25 @@ class Reactive(Generic[ReactiveType]): self.name = name # The internal name where the attribute's value is stored self.internal_name = f"_reactive_{name}" + self.compute_name = f"compute_{name}" default = self._default setattr(owner, f"_default_{name}", default) def __get__(self, obj: Reactable, obj_type: type[object]) -> ReactiveType: - _rich_traceback_omit = True + internal_name = self.internal_name + if not hasattr(obj, internal_name): + self._initialize_reactive(obj, self.name) - self._initialize_reactive(obj, self.name) - - value: ReactiveType - compute_method = getattr(self, f"compute_{self.name}", None) - if compute_method is not None: - old_value = getattr(obj, self.internal_name) - value = getattr(obj, f"compute_{self.name}")() - setattr(obj, self.internal_name, value) + if hasattr(obj, self.compute_name): + value: ReactiveType + old_value = getattr(obj, internal_name) + _rich_traceback_omit = True + value = getattr(obj, self.compute_name)() + setattr(obj, internal_name, value) self._check_watchers(obj, self.name, old_value) + return value else: - value = getattr(obj, self.internal_name) - return value + return getattr(obj, internal_name) def __set__(self, obj: Reactable, value: ReactiveType) -> None: _rich_traceback_omit = True diff --git a/tests/test_reactive.py b/tests/test_reactive.py index da8be66ae..9c824645e 100644 --- a/tests/test_reactive.py +++ b/tests/test_reactive.py @@ -328,6 +328,33 @@ async def test_reactive_inheritance(): assert tertiary.baz == "baz" +async def test_compute(): + """Check compute method is called.""" + + class ComputeApp(App): + count = var(0) + count_double = var(0) + + def __init__(self) -> None: + self.start = 0 + super().__init__() + + def compute_count_double(self) -> int: + return self.start + self.count * 2 + + app = ComputeApp() + + async with app.run_test(): + assert app.count_double == 0 + app.count = 1 + assert app.count_double == 2 + assert app.count_double == 2 + app.count = 2 + assert app.count_double == 4 + app.start = 10 + assert app.count_double == 14 + + async def test_watch_compute(): """Check that watching a computed attribute works.""" @@ -347,7 +374,9 @@ async def test_watch_compute(): app = Calculator() - async with app.run_test() as pilot: + # Referencing the value calls compute + # Setting any reactive values calls compute + async with app.run_test(): assert app.show_ac is True app.value = "1" assert app.show_ac is False @@ -356,4 +385,4 @@ async def test_watch_compute(): app.numbers = "123" assert app.show_ac is False - assert watch_called == [True, False, True, False] + assert watch_called == [True, True, False, False, True, True, False, False]