From 3749412d2fe38fca41ee633a93e5f69e368910d4 Mon Sep 17 00:00:00 2001 From: darrenburns Date: Fri, 19 Aug 2022 14:55:53 +0100 Subject: [PATCH] Animation callback (#683) * Add on_complete callback to Animations * Add tests for on_complete Animation callbacks * Update __textual_animation__ signature to include callback * Support awaitable callbacks for animator on_complete * Import tidying * Update animator tests --- sandbox/darren/just_a_box.py | 41 +++++++++-------------------- src/textual/_animator.py | 29 +++++++++++++++++--- src/textual/_types.py | 7 +++-- src/textual/css/scalar_animation.py | 7 ++--- src/textual/css/styles.py | 6 +++-- src/textual/messages.py | 6 ++--- src/textual/screen.py | 2 +- tests/test_animator.py | 38 ++++++++++++++++++++------ 8 files changed, 83 insertions(+), 53 deletions(-) diff --git a/sandbox/darren/just_a_box.py b/sandbox/darren/just_a_box.py index 781c66f67..ee8be631e 100644 --- a/sandbox/darren/just_a_box.py +++ b/sandbox/darren/just_a_box.py @@ -1,11 +1,11 @@ from __future__ import annotations +import asyncio + from rich.console import RenderableType -from rich.panel import Panel from textual import events from textual.app import App, ComposeResult -from textual.layout import Horizontal, Vertical from textual.widget import Widget @@ -18,37 +18,22 @@ class Box(Widget, can_focus=True): super().__init__(*children, id=id, classes=classes) def render(self) -> RenderableType: - return Panel("Box") + return "Box" class JustABox(App): def compose(self) -> ComposeResult: - yield Horizontal( - Vertical( - Box(id="box1", classes="box"), - Box(id="box2", classes="box"), - # Box(id="box3", classes="box"), - # Box(id="box4", classes="box"), - # Box(id="box5", classes="box"), - # Box(id="box6", classes="box"), - # Box(id="box7", classes="box"), - # Box(id="box8", classes="box"), - # Box(id="box9", classes="box"), - # Box(id="box10", classes="box"), - id="left_pane", - ), - Box(id="middle_pane"), - Vertical( - Box(id="boxa", classes="box"), - Box(id="boxb", classes="box"), - Box(id="boxc", classes="box"), - id="right_pane", - ), - id="horizontal", - ) + self.box = Box() + yield self.box - def key_p(self): - print(self.query("#horizontal").first().styles.layout) + def key_a(self): + self.animator.animate( + self.box.styles, + "opacity", + value=0.0, + duration=2.0, + on_complete=self.box.remove, + ) async def on_key(self, event: events.Key) -> None: await self.dispatch_key(event) diff --git a/src/textual/_animator.py b/src/textual/_animator.py index dbc5d9670..198c0e266 100644 --- a/src/textual/_animator.py +++ b/src/textual/_animator.py @@ -8,9 +8,10 @@ from typing import Any, Callable, TypeVar from dataclasses import dataclass from . import _clock +from ._callback import invoke from ._easing import DEFAULT_EASING, EASING from ._timer import Timer -from ._types import MessageTarget +from ._types import MessageTarget, CallbackType if sys.version_info >= (3, 8): from typing import Protocol, runtime_checkable @@ -30,8 +31,20 @@ class Animatable(Protocol): class Animation(ABC): + + on_complete: CallbackType | None = None + """Callback to run after animation completes""" + @abstractmethod def __call__(self, time: float) -> bool: # pragma: no cover + """Call the animation, return a boolean indicating whether animation is in-progress or complete. + + Args: + time (float): The current timestamp + + Returns: + bool: True if the animation has finished, otherwise False. + """ raise NotImplementedError("") def __eq__(self, other: object) -> bool: @@ -48,6 +61,7 @@ class SimpleAnimation(Animation): end_value: float | Animatable final_value: object easing: EasingFunction + on_complete: CallbackType | None = None def __call__(self, time: float) -> bool: @@ -109,6 +123,7 @@ class BoundAnimator: duration: float | None = None, speed: float | None = None, easing: EasingFunction | str = DEFAULT_EASING, + on_complete: CallbackType | None = None, ) -> None: easing_function = EASING[easing] if isinstance(easing, str) else easing return self._animator.animate( @@ -119,6 +134,7 @@ class BoundAnimator: duration=duration, speed=speed, easing=easing_function, + on_complete=on_complete, ) @@ -163,6 +179,7 @@ class Animator: duration: float | None = None, speed: float | None = None, easing: EasingFunction | str = DEFAULT_EASING, + on_complete: CallbackType | None = None, ) -> None: """Animate an attribute to a new value. @@ -201,6 +218,7 @@ class Animator: duration=duration, speed=speed, easing=easing_function, + on_complete=on_complete, ) if animation is None: start_value = getattr(obj, attribute) @@ -223,6 +241,7 @@ class Animator: end_value=value, final_value=final_value, easing=easing_function, + on_complete=on_complete, ) assert animation is not None, "animation expected to be non-None" @@ -233,7 +252,7 @@ class Animator: self._animations[animation_key] = animation self._timer.resume() - def __call__(self) -> None: + async def __call__(self) -> None: if not self._animations: self._timer.pause() else: @@ -241,7 +260,11 @@ class Animator: animation_keys = list(self._animations.keys()) for animation_key in animation_keys: animation = self._animations[animation_key] - if animation(animation_time): + animation_complete = animation(animation_time) + if animation_complete: + completion_callback = animation.on_complete + if completion_callback is not None: + await invoke(completion_callback) del self._animations[animation_key] def _get_time(self) -> float: diff --git a/src/textual/_types.py b/src/textual/_types.py index 180d02be4..7158e31a4 100644 --- a/src/textual/_types.py +++ b/src/textual/_types.py @@ -1,5 +1,6 @@ import sys -from typing import Awaitable, Callable, List, Optional, TYPE_CHECKING +from typing import Awaitable, Callable, List, TYPE_CHECKING, Union + from rich.segment import Segment if sys.version_info >= (3, 8): @@ -11,8 +12,6 @@ else: if TYPE_CHECKING: from .message import Message -Callback = Callable[[], None] - class MessageTarget(Protocol): async def post_message(self, message: "Message") -> bool: @@ -34,5 +33,5 @@ class EventTarget(Protocol): MessageHandler = Callable[["Message"], Awaitable] - Lines = List[List[Segment]] +CallbackType = Union[Callable[[], Awaitable[None]], Callable[[], None]] diff --git a/src/textual/css/scalar_animation.py b/src/textual/css/scalar_animation.py index ae37249f9..aa81ec805 100644 --- a/src/textual/css/scalar_animation.py +++ b/src/textual/css/scalar_animation.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable -from .. import events, log +from .._types import CallbackType from ..geometry import Offset from .._animator import Animation from .scalar import ScalarOffset @@ -25,6 +25,7 @@ class ScalarAnimation(Animation): duration: float | None, speed: float | None, easing: EasingFunction, + on_complete: CallbackType | None = None, ): assert ( speed is not None or duration is not None @@ -35,6 +36,7 @@ class ScalarAnimation(Animation): self.attribute = attribute self.final_value = value self.easing = easing + self.on_complete = on_complete size = widget.outer_size viewport = widget.app.size @@ -55,7 +57,6 @@ class ScalarAnimation(Animation): eased_factor = self.easing(factor) if eased_factor >= 1: - offset = self.final_value setattr(self.styles, self.attribute, self.final_value) return True diff --git a/src/textual/css/styles.py b/src/textual/css/styles.py index dd0b53e6f..b84a9c4a4 100644 --- a/src/textual/css/styles.py +++ b/src/textual/css/styles.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, cast import rich.repr from rich.style import Style +from textual._types import CallbackType from .._animator import Animation, EasingFunction from ..color import Color from ..geometry import Offset, Spacing @@ -579,9 +580,9 @@ class Styles(StylesBase): duration: float | None, speed: float | None, easing: EasingFunction, + on_complete: CallbackType | None = None, ) -> Animation | None: - from ..widget import Widget - + # from ..widget import Widget # node = self.node # assert isinstance(self.node, Widget) if isinstance(value, ScalarOffset): @@ -594,6 +595,7 @@ class Styles(StylesBase): duration=duration, speed=speed, easing=easing, + on_complete=on_complete, ) return None diff --git a/src/textual/messages.py b/src/textual/messages.py index 84bc8d14a..58e4cd2a6 100644 --- a/src/textual/messages.py +++ b/src/textual/messages.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Awaitable, Union +from typing import TYPE_CHECKING import rich.repr +from ._types import CallbackType from .message import Message @@ -11,9 +12,6 @@ if TYPE_CHECKING: from .widget import Widget -CallbackType = Union[Callable[[], Awaitable[None]], Callable[[], None]] - - @rich.repr.auto class Update(Message, verbosity=3): def __init__(self, sender: MessagePump, widget: Widget): diff --git a/src/textual/screen.py b/src/textual/screen.py index b8f6af01f..260eca752 100644 --- a/src/textual/screen.py +++ b/src/textual/screen.py @@ -12,7 +12,7 @@ from ._callback import invoke from .geometry import Offset, Region, Size from ._compositor import Compositor, MapGeometry -from .messages import CallbackType +from ._types import CallbackType from .reactive import Reactive from .renderables.blank import Blank from ._timer import Timer diff --git a/tests/test_animator.py b/tests/test_animator.py index fd3f7c038..926949b27 100644 --- a/tests/test_animator.py +++ b/tests/test_animator.py @@ -1,12 +1,10 @@ from __future__ import annotations - from dataclasses import dataclass from unittest.mock import Mock import pytest - from textual._animator import Animator, SimpleAnimation from textual._easing import EASING, DEFAULT_EASING @@ -184,8 +182,7 @@ class MockAnimator(Animator): return self._time -def test_animator(): - +async def test_animator(): target = Mock() animator = MockAnimator(target) animate_test = AnimateTest() @@ -206,11 +203,11 @@ def test_animator(): assert animator._animations[(id(animate_test), "foo")] == expected assert not animator._on_animation_frame_called - animator() + await animator() assert animate_test.foo == 0 animator._time = 5 - animator() + await animator() assert animate_test.foo == 50 # New animation in the middle of an existing one @@ -218,12 +215,11 @@ def test_animator(): assert animate_test.foo == 50 animator._time = 6 - animator() + await animator() assert animate_test.foo == 200 def test_bound_animator(): - target = Mock() animator = MockAnimator(target) animate_test = AnimateTest() @@ -245,3 +241,29 @@ def test_bound_animator(): easing=EASING[DEFAULT_EASING], ) assert animator._animations[(id(animate_test), "foo")] == expected + + +def test_animator_on_complete_callback_not_fired_before_duration_ends(): + callback = Mock() + animate_test = AnimateTest() + animator = MockAnimator(Mock()) + + animator.animate(animate_test, "foo", 200, duration=10, on_complete=callback) + + animator._time = 9 + animator() + + assert not callback.called + + +async def test_animator_on_complete_callback_fired_at_duration(): + callback = Mock() + animate_test = AnimateTest() + animator = MockAnimator(Mock()) + + animator.animate(animate_test, "foo", 200, duration=10, on_complete=callback) + + animator._time = 10 + await animator() + + callback.assert_called_once_with()