Animation callback (#683)

* Add on_complete callback to Animations

* Add tests for on_complete Animation callbacks

* Update __textual_animation__ signature to include callback

* Support awaitable callbacks for animator on_complete

* Import tidying

* Update animator tests
This commit is contained in:
darrenburns
2022-08-19 14:55:53 +01:00
committed by GitHub
parent 9eea01f5a1
commit 3749412d2f
8 changed files with 83 additions and 53 deletions

View File

@@ -1,11 +1,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from rich.console import RenderableType from rich.console import RenderableType
from rich.panel import Panel
from textual import events from textual import events
from textual.app import App, ComposeResult from textual.app import App, ComposeResult
from textual.layout import Horizontal, Vertical
from textual.widget import Widget from textual.widget import Widget
@@ -18,37 +18,22 @@ class Box(Widget, can_focus=True):
super().__init__(*children, id=id, classes=classes) super().__init__(*children, id=id, classes=classes)
def render(self) -> RenderableType: def render(self) -> RenderableType:
return Panel("Box") return "Box"
class JustABox(App): class JustABox(App):
def compose(self) -> ComposeResult: def compose(self) -> ComposeResult:
yield Horizontal( self.box = Box()
Vertical( yield self.box
Box(id="box1", classes="box"),
Box(id="box2", classes="box"),
# Box(id="box3", classes="box"),
# Box(id="box4", classes="box"),
# Box(id="box5", classes="box"),
# Box(id="box6", classes="box"),
# Box(id="box7", classes="box"),
# Box(id="box8", classes="box"),
# Box(id="box9", classes="box"),
# Box(id="box10", classes="box"),
id="left_pane",
),
Box(id="middle_pane"),
Vertical(
Box(id="boxa", classes="box"),
Box(id="boxb", classes="box"),
Box(id="boxc", classes="box"),
id="right_pane",
),
id="horizontal",
)
def key_p(self): def key_a(self):
print(self.query("#horizontal").first().styles.layout) self.animator.animate(
self.box.styles,
"opacity",
value=0.0,
duration=2.0,
on_complete=self.box.remove,
)
async def on_key(self, event: events.Key) -> None: async def on_key(self, event: events.Key) -> None:
await self.dispatch_key(event) await self.dispatch_key(event)

View File

@@ -8,9 +8,10 @@ from typing import Any, Callable, TypeVar
from dataclasses import dataclass from dataclasses import dataclass
from . import _clock from . import _clock
from ._callback import invoke
from ._easing import DEFAULT_EASING, EASING from ._easing import DEFAULT_EASING, EASING
from ._timer import Timer from ._timer import Timer
from ._types import MessageTarget from ._types import MessageTarget, CallbackType
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
from typing import Protocol, runtime_checkable from typing import Protocol, runtime_checkable
@@ -30,8 +31,20 @@ class Animatable(Protocol):
class Animation(ABC): class Animation(ABC):
on_complete: CallbackType | None = None
"""Callback to run after animation completes"""
@abstractmethod @abstractmethod
def __call__(self, time: float) -> bool: # pragma: no cover def __call__(self, time: float) -> bool: # pragma: no cover
"""Call the animation, return a boolean indicating whether animation is in-progress or complete.
Args:
time (float): The current timestamp
Returns:
bool: True if the animation has finished, otherwise False.
"""
raise NotImplementedError("") raise NotImplementedError("")
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
@@ -48,6 +61,7 @@ class SimpleAnimation(Animation):
end_value: float | Animatable end_value: float | Animatable
final_value: object final_value: object
easing: EasingFunction easing: EasingFunction
on_complete: CallbackType | None = None
def __call__(self, time: float) -> bool: def __call__(self, time: float) -> bool:
@@ -109,6 +123,7 @@ class BoundAnimator:
duration: float | None = None, duration: float | None = None,
speed: float | None = None, speed: float | None = None,
easing: EasingFunction | str = DEFAULT_EASING, easing: EasingFunction | str = DEFAULT_EASING,
on_complete: CallbackType | None = None,
) -> None: ) -> None:
easing_function = EASING[easing] if isinstance(easing, str) else easing easing_function = EASING[easing] if isinstance(easing, str) else easing
return self._animator.animate( return self._animator.animate(
@@ -119,6 +134,7 @@ class BoundAnimator:
duration=duration, duration=duration,
speed=speed, speed=speed,
easing=easing_function, easing=easing_function,
on_complete=on_complete,
) )
@@ -163,6 +179,7 @@ class Animator:
duration: float | None = None, duration: float | None = None,
speed: float | None = None, speed: float | None = None,
easing: EasingFunction | str = DEFAULT_EASING, easing: EasingFunction | str = DEFAULT_EASING,
on_complete: CallbackType | None = None,
) -> None: ) -> None:
"""Animate an attribute to a new value. """Animate an attribute to a new value.
@@ -201,6 +218,7 @@ class Animator:
duration=duration, duration=duration,
speed=speed, speed=speed,
easing=easing_function, easing=easing_function,
on_complete=on_complete,
) )
if animation is None: if animation is None:
start_value = getattr(obj, attribute) start_value = getattr(obj, attribute)
@@ -223,6 +241,7 @@ class Animator:
end_value=value, end_value=value,
final_value=final_value, final_value=final_value,
easing=easing_function, easing=easing_function,
on_complete=on_complete,
) )
assert animation is not None, "animation expected to be non-None" assert animation is not None, "animation expected to be non-None"
@@ -233,7 +252,7 @@ class Animator:
self._animations[animation_key] = animation self._animations[animation_key] = animation
self._timer.resume() self._timer.resume()
def __call__(self) -> None: async def __call__(self) -> None:
if not self._animations: if not self._animations:
self._timer.pause() self._timer.pause()
else: else:
@@ -241,7 +260,11 @@ class Animator:
animation_keys = list(self._animations.keys()) animation_keys = list(self._animations.keys())
for animation_key in animation_keys: for animation_key in animation_keys:
animation = self._animations[animation_key] animation = self._animations[animation_key]
if animation(animation_time): animation_complete = animation(animation_time)
if animation_complete:
completion_callback = animation.on_complete
if completion_callback is not None:
await invoke(completion_callback)
del self._animations[animation_key] del self._animations[animation_key]
def _get_time(self) -> float: def _get_time(self) -> float:

View File

@@ -1,5 +1,6 @@
import sys import sys
from typing import Awaitable, Callable, List, Optional, TYPE_CHECKING from typing import Awaitable, Callable, List, TYPE_CHECKING, Union
from rich.segment import Segment from rich.segment import Segment
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
@@ -11,8 +12,6 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .message import Message from .message import Message
Callback = Callable[[], None]
class MessageTarget(Protocol): class MessageTarget(Protocol):
async def post_message(self, message: "Message") -> bool: async def post_message(self, message: "Message") -> bool:
@@ -34,5 +33,5 @@ class EventTarget(Protocol):
MessageHandler = Callable[["Message"], Awaitable] MessageHandler = Callable[["Message"], Awaitable]
Lines = List[List[Segment]] Lines = List[List[Segment]]
CallbackType = Union[Callable[[], Awaitable[None]], Callable[[], None]]

View File

@@ -1,8 +1,8 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Callable
from .. import events, log from .._types import CallbackType
from ..geometry import Offset from ..geometry import Offset
from .._animator import Animation from .._animator import Animation
from .scalar import ScalarOffset from .scalar import ScalarOffset
@@ -25,6 +25,7 @@ class ScalarAnimation(Animation):
duration: float | None, duration: float | None,
speed: float | None, speed: float | None,
easing: EasingFunction, easing: EasingFunction,
on_complete: CallbackType | None = None,
): ):
assert ( assert (
speed is not None or duration is not None speed is not None or duration is not None
@@ -35,6 +36,7 @@ class ScalarAnimation(Animation):
self.attribute = attribute self.attribute = attribute
self.final_value = value self.final_value = value
self.easing = easing self.easing = easing
self.on_complete = on_complete
size = widget.outer_size size = widget.outer_size
viewport = widget.app.size viewport = widget.app.size
@@ -55,7 +57,6 @@ class ScalarAnimation(Animation):
eased_factor = self.easing(factor) eased_factor = self.easing(factor)
if eased_factor >= 1: if eased_factor >= 1:
offset = self.final_value
setattr(self.styles, self.attribute, self.final_value) setattr(self.styles, self.attribute, self.final_value)
return True return True

View File

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, cast
import rich.repr import rich.repr
from rich.style import Style from rich.style import Style
from textual._types import CallbackType
from .._animator import Animation, EasingFunction from .._animator import Animation, EasingFunction
from ..color import Color from ..color import Color
from ..geometry import Offset, Spacing from ..geometry import Offset, Spacing
@@ -579,9 +580,9 @@ class Styles(StylesBase):
duration: float | None, duration: float | None,
speed: float | None, speed: float | None,
easing: EasingFunction, easing: EasingFunction,
on_complete: CallbackType | None = None,
) -> Animation | None: ) -> Animation | None:
from ..widget import Widget # from ..widget import Widget
# node = self.node # node = self.node
# assert isinstance(self.node, Widget) # assert isinstance(self.node, Widget)
if isinstance(value, ScalarOffset): if isinstance(value, ScalarOffset):
@@ -594,6 +595,7 @@ class Styles(StylesBase):
duration=duration, duration=duration,
speed=speed, speed=speed,
easing=easing, easing=easing,
on_complete=on_complete,
) )
return None return None

View File

@@ -1,8 +1,9 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Awaitable, Union from typing import TYPE_CHECKING
import rich.repr import rich.repr
from ._types import CallbackType
from .message import Message from .message import Message
@@ -11,9 +12,6 @@ if TYPE_CHECKING:
from .widget import Widget from .widget import Widget
CallbackType = Union[Callable[[], Awaitable[None]], Callable[[], None]]
@rich.repr.auto @rich.repr.auto
class Update(Message, verbosity=3): class Update(Message, verbosity=3):
def __init__(self, sender: MessagePump, widget: Widget): def __init__(self, sender: MessagePump, widget: Widget):

View File

@@ -12,7 +12,7 @@ from ._callback import invoke
from .geometry import Offset, Region, Size from .geometry import Offset, Region, Size
from ._compositor import Compositor, MapGeometry from ._compositor import Compositor, MapGeometry
from .messages import CallbackType from ._types import CallbackType
from .reactive import Reactive from .reactive import Reactive
from .renderables.blank import Blank from .renderables.blank import Blank
from ._timer import Timer from ._timer import Timer

View File

@@ -1,12 +1,10 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from textual._animator import Animator, SimpleAnimation from textual._animator import Animator, SimpleAnimation
from textual._easing import EASING, DEFAULT_EASING from textual._easing import EASING, DEFAULT_EASING
@@ -184,8 +182,7 @@ class MockAnimator(Animator):
return self._time return self._time
def test_animator(): async def test_animator():
target = Mock() target = Mock()
animator = MockAnimator(target) animator = MockAnimator(target)
animate_test = AnimateTest() animate_test = AnimateTest()
@@ -206,11 +203,11 @@ def test_animator():
assert animator._animations[(id(animate_test), "foo")] == expected assert animator._animations[(id(animate_test), "foo")] == expected
assert not animator._on_animation_frame_called assert not animator._on_animation_frame_called
animator() await animator()
assert animate_test.foo == 0 assert animate_test.foo == 0
animator._time = 5 animator._time = 5
animator() await animator()
assert animate_test.foo == 50 assert animate_test.foo == 50
# New animation in the middle of an existing one # New animation in the middle of an existing one
@@ -218,12 +215,11 @@ def test_animator():
assert animate_test.foo == 50 assert animate_test.foo == 50
animator._time = 6 animator._time = 6
animator() await animator()
assert animate_test.foo == 200 assert animate_test.foo == 200
def test_bound_animator(): def test_bound_animator():
target = Mock() target = Mock()
animator = MockAnimator(target) animator = MockAnimator(target)
animate_test = AnimateTest() animate_test = AnimateTest()
@@ -245,3 +241,29 @@ def test_bound_animator():
easing=EASING[DEFAULT_EASING], easing=EASING[DEFAULT_EASING],
) )
assert animator._animations[(id(animate_test), "foo")] == expected assert animator._animations[(id(animate_test), "foo")] == expected
def test_animator_on_complete_callback_not_fired_before_duration_ends():
callback = Mock()
animate_test = AnimateTest()
animator = MockAnimator(Mock())
animator.animate(animate_test, "foo", 200, duration=10, on_complete=callback)
animator._time = 9
animator()
assert not callback.called
async def test_animator_on_complete_callback_fired_at_duration():
callback = Mock()
animate_test = AnimateTest()
animator = MockAnimator(Mock())
animator.animate(animate_test, "foo", 200, duration=10, on_complete=callback)
animator._time = 10
await animator()
callback.assert_called_once_with()