From d18c794e69c27c48211d761a0641c80844b55b9b Mon Sep 17 00:00:00 2001 From: Will McGugan Date: Sun, 19 Feb 2023 22:24:28 +0000 Subject: [PATCH] call compute on demand --- src/textual/reactive.py | 18 ++++++++++-------- tests/test_reactive.py | 27 +++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/textual/reactive.py b/src/textual/reactive.py index 4c7bd1d98..794d723b0 100644 --- a/src/textual/reactive.py +++ b/src/textual/reactive.py @@ -143,23 +143,25 @@ class Reactive(Generic[ReactiveType]): self.name = name # The internal name where the attribute's value is stored self.internal_name = f"_reactive_{name}" + self.compute_name = f"compute_{name}" default = self._default setattr(owner, f"_default_{name}", default) def __get__(self, obj: Reactable, obj_type: type[object]) -> ReactiveType: - _rich_traceback_omit = True - - self._initialize_reactive(obj, self.name) + internal_name = self.internal_name + if not hasattr(obj, internal_name): + self._initialize_reactive(obj, self.name) value: ReactiveType - compute_method = getattr(self, f"compute_{self.name}", None) + compute_method = getattr(obj, self.compute_name, None) if compute_method is not None: - old_value = getattr(obj, self.internal_name) - value = getattr(obj, f"compute_{self.name}")() - setattr(obj, self.internal_name, value) + old_value = getattr(obj, internal_name) + _rich_traceback_omit = True + value = compute_method() + setattr(obj, internal_name, value) self._check_watchers(obj, self.name, old_value) else: - value = getattr(obj, self.internal_name) + value = getattr(obj, internal_name) return value def __set__(self, obj: Reactable, value: ReactiveType) -> None: diff --git a/tests/test_reactive.py b/tests/test_reactive.py index da8be66ae..2a86bdba2 100644 --- a/tests/test_reactive.py +++ b/tests/test_reactive.py @@ -328,6 +328,27 @@ async def test_reactive_inheritance(): assert tertiary.baz == "baz" +async def test_compute(): + """Check compute method is called.""" + + class ComputeApp(App): + count = var(0) + count_double = var(0) + + def compute_count_double(self) -> int: + return self.count * 2 + + app = ComputeApp() + + async with app.run_test(): + assert app.count_double == 0 + app.count = 1 + assert app.count_double == 2 + assert app.count_double == 2 + app.count = 2 + assert app.count_double == 4 + + async def test_watch_compute(): """Check that watching a computed attribute works.""" @@ -347,7 +368,9 @@ async def test_watch_compute(): app = Calculator() - async with app.run_test() as pilot: + # Referencing the value calls compute + # Setting any reactive values calls compute + async with app.run_test(): assert app.show_ac is True app.value = "1" assert app.show_ac is False @@ -356,4 +379,4 @@ async def test_watch_compute(): app.numbers = "123" assert app.show_ac is False - assert watch_called == [True, False, True, False] + assert watch_called == [True, True, False, False, True, True, False, False]