[App] Finally, time mocking in tests seems to be working! 😅

I had to add a flag in the `_timer` module that allows us to completely disable the  "skip" feature of Timers, though - but it shouldn't cause too much trouble 🤞
This commit is contained in:
Olivier Philippon
2022-05-13 10:35:06 +01:00
parent 74ad6f73fa
commit 15df759197
6 changed files with 211 additions and 138 deletions

View File

@@ -179,10 +179,13 @@ class Animator:
raise AttributeError( raise AttributeError(
f"Can't animate attribute {attribute!r} on {obj!r}; attribute does not exist" f"Can't animate attribute {attribute!r} on {obj!r}; attribute does not exist"
) )
assert not all(
(duration, speed)
), "An Animation should have a duration OR a speed, received both"
if final_value is ...: if final_value is ...:
final_value = value final_value = value
start_time = self._timer.get_time() start_time = self._get_time()
animation_key = (id(obj), attribute) animation_key = (id(obj), attribute)
@@ -233,9 +236,15 @@ class Animator:
if not self._animations: if not self._animations:
self._timer.pause() self._timer.pause()
else: else:
animation_time = self._timer.get_time() animation_time = self._get_time()
animation_keys = list(self._animations.keys()) animation_keys = list(self._animations.keys())
for animation_key in animation_keys: for animation_key in animation_keys:
animation = self._animations[animation_key] animation = self._animations[animation_key]
if animation(animation_time): if animation(animation_time):
del self._animations[animation_key] del self._animations[animation_key]
def _get_time(self) -> float:
"""Get the current wall clock time, via the internal Timer."""
# N.B. We could remove this method and always call `self._timer.get_time()` internally,
# but it's handy to have in mocking situations
return self._timer.get_time()

View File

@@ -19,6 +19,9 @@ from ._types import MessageTarget
TimerCallback = Union[Callable[[], Awaitable[None]], Callable[[], None]] TimerCallback = Union[Callable[[], Awaitable[None]], Callable[[], None]]
# /!\ This should only be changed in an "integration tests" context, in which we mock time
_TIMERS_CAN_SKIP: bool = True
class EventTargetGone(Exception): class EventTargetGone(Exception):
pass pass
@@ -27,8 +30,6 @@ class EventTargetGone(Exception):
@rich_repr @rich_repr
class Timer: class Timer:
_timer_count: int = 1 _timer_count: int = 1
# Used to mock Timers' behaviour in a Textual app's integration test:
_instances: weakref.WeakSet[Timer] = weakref.WeakSet()
def __init__( def __init__(
self, self,
@@ -64,7 +65,6 @@ class Timer:
self._repeat = repeat self._repeat = repeat
self._skip = skip self._skip = skip
self._active = Event() self._active = Event()
Timer._instances.add(self)
if not pause: if not pause:
self._active.set() self._active.set()
@@ -126,11 +126,10 @@ class Timer:
try: try:
while _repeat is None or count <= _repeat: while _repeat is None or count <= _repeat:
next_timer = start + ((count + 1) * _interval) next_timer = start + ((count + 1) * _interval)
now = self.get_time() if self._skip and _TIMERS_CAN_SKIP and next_timer < self.get_time():
if self._skip and next_timer < now:
count += 1 count += 1
continue continue
wait_time = max(0, next_timer - now) wait_time = max(0, next_timer - self.get_time())
if wait_time: if wait_time:
await self._sleep(wait_time) await self._sleep(wait_time)
count += 1 count += 1

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import sys
from rich.console import RenderableType from rich.console import RenderableType
import rich.repr import rich.repr
from rich.style import Style from rich.style import Style
@@ -12,6 +14,14 @@ from ._compositor import Compositor, MapGeometry
from .reactive import Reactive from .reactive import Reactive
from .widget import Widget from .widget import Widget
if sys.version_info >= (3, 8):
from typing import Final
else:
from typing_extensions import Final
# Screen updates will be batched so that they don't happen more often than 20 times per second:
UPDATE_PERIOD: Final = 1 / 20
@rich.repr.auto @rich.repr.auto
class Screen(Widget): class Screen(Widget):
@@ -158,7 +168,9 @@ class Screen(Widget):
self.check_idle() self.check_idle()
def on_mount(self, event: events.Mount) -> None: def on_mount(self, event: events.Mount) -> None:
self._update_timer = self.set_interval(1 / 20, self._on_update, pause=True) self._update_timer = self.set_interval(
UPDATE_PERIOD, self._on_update, pause=True
)
async def on_resize(self, event: events.Resize) -> None: async def on_resize(self, event: events.Resize) -> None:
self.size_updated(event.size, event.virtual_size, event.container_size) self.size_updated(event.size, event.virtual_size, event.container_size)

View File

@@ -177,12 +177,12 @@ class MockAnimator(Animator):
self._time = 0.0 self._time = 0.0
self._on_animation_frame_called = False self._on_animation_frame_called = False
def get_time(self):
return self._time
def on_animation_frame(self): def on_animation_frame(self):
self._on_animation_frame_called = True self._on_animation_frame_called = True
def _get_time(self):
return self._time
def test_animator(): def test_animator():

View File

@@ -1,17 +1,8 @@
from __future__ import annotations from __future__ import annotations
import sys
from typing import Sequence, cast from typing import Sequence, cast
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal # pragma: no cover
import pytest import pytest
from sandbox.vertical_container import VerticalContainer
from tests.utilities.test_app import AppTest from tests.utilities.test_app import AppTest
from textual.app import ComposeResult from textual.app import ComposeResult
from textual.geometry import Size from textual.geometry import Size
@@ -31,23 +22,19 @@ SCREEN_SIZE = Size(100, 30)
"scroll_to_animate", "scroll_to_animate",
"waiting_duration", "waiting_duration",
"last_screen_expected_placeholder_ids", "last_screen_expected_placeholder_ids",
"last_screen_expected_out_of_viewport_placeholder_ids",
), ),
( (
[SCREEN_SIZE, 10, None, None, 0.01, (0, 1, 2, 3, 4), "others"], [SCREEN_SIZE, 10, None, None, 0.01, (0, 1, 2, 3, 4)],
[SCREEN_SIZE, 10, "placeholder_3", False, 0.01, (0, 1, 2, 3, 4), "others"], [SCREEN_SIZE, 10, "placeholder_3", False, 0.01, (0, 1, 2, 3, 4)],
[SCREEN_SIZE, 10, "placeholder_5", False, 0.01, (1, 2, 3, 4, 5), "others"], [SCREEN_SIZE, 10, "placeholder_5", False, 0.01, (1, 2, 3, 4, 5)],
[SCREEN_SIZE, 10, "placeholder_7", False, 0.01, (3, 4, 5, 6, 7), "others"], [SCREEN_SIZE, 10, "placeholder_7", False, 0.01, (3, 4, 5, 6, 7)],
[SCREEN_SIZE, 10, "placeholder_9", False, 0.01, (5, 6, 7, 8, 9), "others"], [SCREEN_SIZE, 10, "placeholder_9", False, 0.01, (5, 6, 7, 8, 9)],
# N.B. Scroll duration is hard-coded to 0.2 in the `scroll_to_widget` method atm # N.B. Scroll duration is hard-coded to 0.2 in the `scroll_to_widget` method atm
# Waiting for this duration should allow us to see the scroll finished: # Waiting for this duration should allow us to see the scroll finished:
[SCREEN_SIZE, 10, "placeholder_9", True, 0.21, (5, 6, 7, 8, 9), "others"], [SCREEN_SIZE, 10, "placeholder_9", True, 0.21, (5, 6, 7, 8, 9)],
# After having waited for approximately half of the scrolling duration, we should # After having waited for approximately half of the scrolling duration, we should
# see the middle Placeholders as we're scrolling towards the last of them. # see the middle Placeholders as we're scrolling towards the last of them.
# The state of the screen at this "halfway there" timing looks to not be deterministic though, [SCREEN_SIZE, 10, "placeholder_9", True, 0.1, (4, 5, 6, 7, 8)],
# depending on the environment - so let's only assert stuff for the middle placeholders
# and the first and last ones, but without being too specific about the others:
[SCREEN_SIZE, 10, "placeholder_9", True, 0.1, (6, 7, 8), (1, 2, 5, 9)],
), ),
) )
async def test_scroll_to_widget( async def test_scroll_to_widget(
@@ -57,9 +44,19 @@ async def test_scroll_to_widget(
scroll_to_placeholder_id: str | None, scroll_to_placeholder_id: str | None,
waiting_duration: float | None, waiting_duration: float | None,
last_screen_expected_placeholder_ids: Sequence[int], last_screen_expected_placeholder_ids: Sequence[int],
last_screen_expected_out_of_viewport_placeholder_ids: Sequence[int]
| Literal["others"],
): ):
class VerticalContainer(Widget):
CSS = """
VerticalContainer {
layout: vertical;
overflow: hidden auto;
}
VerticalContainer Placeholder {
margin: 1 0;
height: 5;
}
"""
class MyTestApp(AppTest): class MyTestApp(AppTest):
CSS = """ CSS = """
Placeholder { Placeholder {
@@ -77,7 +74,7 @@ async def test_scroll_to_widget(
app = MyTestApp(size=screen_size, test_name="scroll_to_widget") app = MyTestApp(size=screen_size, test_name="scroll_to_widget")
async with app.in_running_state(waiting_duration_post_yield=waiting_duration or 0): async with app.in_running_state(waiting_duration_after_yield=waiting_duration or 0):
if scroll_to_placeholder_id: if scroll_to_placeholder_id:
target_widget_container = cast(Widget, app.query("#root").first()) target_widget_container = cast(Widget, app.query("#root").first())
target_widget = cast( target_widget = cast(
@@ -93,24 +90,24 @@ async def test_scroll_to_widget(
id_: f"placeholder_{id_}" in last_display_capture id_: f"placeholder_{id_}" in last_display_capture
for id_ in range(placeholders_count) for id_ in range(placeholders_count)
} }
print(f"placeholders_visibility_by_id={placeholders_visibility_by_id}")
# Let's start by checking placeholders that should be visible: # Let's start by checking placeholders that should be visible:
for placeholder_id in last_screen_expected_placeholder_ids: for placeholder_id in last_screen_expected_placeholder_ids:
assert ( assert placeholders_visibility_by_id[placeholder_id] is True, (
placeholders_visibility_by_id[placeholder_id] is True f"Placeholder '{placeholder_id}' should be visible but isn't"
), f"Placeholder '{placeholder_id}' should be visible but isn't" f" :: placeholders_visibility_by_id={placeholders_visibility_by_id}"
)
# Ok, now for placeholders that should *not* be visible: # Ok, now for placeholders that should *not* be visible:
if last_screen_expected_out_of_viewport_placeholder_ids == "others": # We're simply going to check that all the placeholders that are not in
# We're simply going to check that all the placeholders that are not in # `last_screen_expected_placeholder_ids` are not on the screen:
# `last_screen_expected_placeholder_ids` are not on the screen: last_screen_expected_out_of_viewport_placeholder_ids = sorted(
last_screen_expected_out_of_viewport_placeholder_ids = sorted( tuple(
tuple( set(range(placeholders_count)) - set(last_screen_expected_placeholder_ids)
set(range(placeholders_count))
- set(last_screen_expected_placeholder_ids)
)
) )
)
for placeholder_id in last_screen_expected_out_of_viewport_placeholder_ids: for placeholder_id in last_screen_expected_out_of_viewport_placeholder_ids:
assert ( assert placeholders_visibility_by_id[placeholder_id] is False, (
placeholders_visibility_by_id[placeholder_id] is False f"Placeholder '{placeholder_id}' should not be visible but is"
), f"Placeholder '{placeholder_id}' should not be visible but is" f" :: placeholders_visibility_by_id={placeholders_visibility_by_id}"
)

View File

@@ -3,19 +3,25 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import io import io
import sys
from math import ceil
from pathlib import Path from pathlib import Path
from time import monotonic from time import monotonic
from typing import AsyncContextManager, cast, ContextManager, Callable from typing import AsyncContextManager, cast, ContextManager
from unittest import mock from unittest import mock
from rich.console import Console from rich.console import Console
from textual import events, errors from textual import events, errors
from textual._timer import Timer
from textual.app import App, ComposeResult from textual.app import App, ComposeResult
from textual.driver import Driver from textual.driver import Driver
from textual.geometry import Size from textual.geometry import Size
if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol
# N.B. These classes would better be named TestApp/TestConsole/TestDriver/etc, # N.B. These classes would better be named TestApp/TestConsole/TestDriver/etc,
# but it makes pytest emit warning as it will try to collect them as classes containing test cases :-/ # but it makes pytest emit warning as it will try to collect them as classes containing test cases :-/
@@ -24,6 +30,12 @@ from textual.geometry import Size
CLEAR_SCREEN_SEQUENCE = "\x1bP=1s\x1b\\" CLEAR_SCREEN_SEQUENCE = "\x1bP=1s\x1b\\"
class MockedTimeMoveClockForward(Protocol):
async def __call__(self, *, seconds: float) -> tuple[float, int]:
"""Returns the new current (mocked) monotonic time and the number of activated Timers"""
...
class AppTest(App): class AppTest(App):
def __init__( def __init__(
self, self,
@@ -64,52 +76,57 @@ class AppTest(App):
def in_running_state( def in_running_state(
self, self,
*, *,
time_mocking_ticks_granularity_fps: int = 60, # i.e. when moving forward by 1 second we'll do it though 60 ticks
waiting_duration_after_initialisation: float = 0.1, waiting_duration_after_initialisation: float = 0.1,
waiting_duration_post_yield: float = 0, waiting_duration_after_yield: float = 0,
time_acceleration: bool = True, ) -> AsyncContextManager[MockedTimeMoveClockForward]:
time_acceleration_factor: float = 10,
# force_timers_tick_after_yield: bool = True,
) -> AsyncContextManager:
async def run_app() -> None: async def run_app() -> None:
await self.process_messages() await self.process_messages()
if time_acceleration:
waiting_duration_after_initialisation /= time_acceleration_factor
waiting_duration_post_yield /= time_acceleration_factor
time_acceleration_context: ContextManager = (
textual_timers_accelerate_time(acceleration_factor=time_acceleration_factor)
if time_acceleration
else contextlib.nullcontext()
)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def get_running_state_context_manager(): async def get_running_state_context_manager():
self._set_active() self._set_active()
with time_acceleration_context:
with mock_textual_timers(
ticks_granularity_fps=time_mocking_ticks_granularity_fps
) as move_time_forward:
run_task = asyncio.create_task(run_app()) run_task = asyncio.create_task(run_app())
timeout_before_yielding_task = asyncio.create_task( await asyncio.sleep(0.001)
asyncio.sleep(waiting_duration_after_initialisation) # timeout_before_yielding_task = asyncio.create_task(
) # asyncio.sleep(waiting_duration_after_initialisation)
done, pending = await asyncio.wait( # )
( # done, pending = await asyncio.wait(
run_task, # (
timeout_before_yielding_task, # run_task,
), # timeout_before_yielding_task,
return_when=asyncio.FIRST_COMPLETED, # ),
) # return_when=asyncio.FIRST_COMPLETED,
if run_task in done or run_task not in pending: # )
raise RuntimeError( # if run_task in done or run_task not in pending:
"TestApp is no longer running after its initialization period" # raise RuntimeError(
) # "TestApp is no longer running after its initialization period"
yield # )
waiting_duration = max(
waiting_duration_post_yield or 0, await move_time_forward(seconds=waiting_duration_after_initialisation)
self.screen._update_timer._interval,
) assert self._driver is not None
await asyncio.sleep(waiting_duration)
self.force_screen_update()
yield move_time_forward
await move_time_forward(seconds=waiting_duration_after_yield)
self.force_screen_update()
# waiting_duration = max(
# waiting_duration_post_yield or 0,
# self.screen._update_timer._interval,
# )
# await asyncio.sleep(waiting_duration)
# if force_timers_tick_after_yield: # if force_timers_tick_after_yield:
# await textual_timers_force_tick() # await textual_timers_force_tick()
assert not run_task.done() assert not run_task.done()
await self.shutdown() await self.shutdown()
@@ -124,27 +141,10 @@ class AppTest(App):
"""Just a commodity shortcut for `async with app.in_running_state(): pass`, for simple cases""" """Just a commodity shortcut for `async with app.in_running_state(): pass`, for simple cases"""
async with self.in_running_state( async with self.in_running_state(
waiting_duration_after_initialisation=waiting_duration_after_initialisation, waiting_duration_after_initialisation=waiting_duration_after_initialisation,
waiting_duration_post_yield=waiting_duration_before_shutdown, waiting_duration_after_yield=waiting_duration_before_shutdown,
): ):
pass pass
def run(self):
raise NotImplementedError(
"Use `async with my_test_app.in_running_state()` rather than `my_test_app.run()`"
)
@property
def total_capture(self) -> str | None:
return self.console.file.getvalue()
@property
def last_display_capture(self) -> str | None:
total_capture = self.total_capture
if not total_capture:
return None
last_display_start_index = total_capture.rindex(CLEAR_SCREEN_SEQUENCE)
return total_capture[last_display_start_index:]
def get_char_at(self, x: int, y: int) -> str: def get_char_at(self, x: int, y: int) -> str:
"""Get the character at the given cell or empty string """Get the character at the given cell or empty string
@@ -175,6 +175,34 @@ class AppTest(App):
return segment.text[0] return segment.text[0]
return "" return ""
def force_screen_update(self, *, repaint: bool = True, layout: bool = True) -> None:
try:
self.screen.refresh(repaint=repaint, layout=layout)
self.screen._on_update()
except IndexError:
pass # the app may not have a screen yet
def on_exception(self, error: Exception) -> None:
# In tests we want the errors to be raised, rather than printed to a Console
raise error
def run(self):
raise NotImplementedError(
"Use `async with my_test_app.in_running_state()` rather than `my_test_app.run()`"
)
@property
def total_capture(self) -> str | None:
return self.console.file.getvalue()
@property
def last_display_capture(self) -> str | None:
total_capture = self.total_capture
if not total_capture:
return None
last_display_start_index = total_capture.rindex(CLEAR_SCREEN_SEQUENCE)
return total_capture[last_display_start_index:]
@property @property
def console(self) -> ConsoleTest: def console(self) -> ConsoleTest:
return self._console return self._console
@@ -231,43 +259,71 @@ class DriverTest(Driver):
pass pass
async def textual_timers_force_tick() -> None: def mock_textual_timers(
timer_instances_tick_tasks: list[asyncio.Task] = [] *,
for timer in Timer._instances: ticks_granularity_fps: int = 60,
task = asyncio.create_task(timer._tick(next_timer=0, count=0)) ) -> ContextManager[MockedTimeMoveClockForward]:
timer_instances_tick_tasks.append(task) single_tick_duration = 1.0 / ticks_granularity_fps
await asyncio.wait(timer_instances_tick_tasks)
pending_sleep_events: list[tuple[float, asyncio.Event]] = []
def textual_timers_accelerate_time(
*, acceleration_factor: float = 10
) -> ContextManager:
@contextlib.contextmanager @contextlib.contextmanager
def accelerate_time_for_timer_context_manager(): def mock_textual_timers_context_manager():
starting_time = monotonic() # N.B. `start_time` is not used, but it is useful to have when we set breakpoints there :-)
start_time = current_time = monotonic()
# Our replacement for "textual._timer.Timer._sleep": # Our replacement for "textual._timer.Timer._sleep":
async def timer_sleep(duration: float) -> None: async def sleep_mock(duration: float) -> None:
await asyncio.sleep(duration / acceleration_factor) event = asyncio.Event()
target_event_monotonic_time = current_time + duration
pending_sleep_events.append((target_event_monotonic_time, event))
# Ok, let's wait for this Event
# - which can only be "unlocked" by calls to `move_clock_forward()`
await event.wait()
# Our replacement for "textual._timer.Timer.get_time": # Our replacement for "textual._timer.Timer.get_time" and "textual.message.Message._get_time":
def timer_get_time() -> float: def get_time_mock() -> float:
real_now = monotonic() return current_time
real_elapsed_time = real_now - starting_time
accelerated_elapsed_time = real_elapsed_time * acceleration_factor
print(
f"timer_get_time:: accelerated_elapsed_time={accelerated_elapsed_time}"
)
return starting_time + accelerated_elapsed_time
with mock.patch("textual._timer.Timer._sleep") as timer_sleep_mock, mock.patch( async def move_clock_forward(*, seconds: float) -> tuple[float, int]:
"textual._timer.Timer.get_time" nonlocal current_time, start_time
) as timer_get_time_mock, mock.patch(
"textual.message.Message._get_time"
) as message_get_time_mock:
timer_sleep_mock.side_effect = timer_sleep
timer_get_time_mock.side_effect = timer_get_time
message_get_time_mock.side_effect = timer_get_time
yield
return accelerate_time_for_timer_context_manager() ticks_count = ceil(seconds * ticks_granularity_fps)
activated_timers_count_total = 0
for tick_counter in range(ticks_count):
current_time += single_tick_duration
activated_timers_count_total += check_sleep_timers_to_activate()
# Let's give an opportunity to asyncio-related stuff to happen,
# now that we unlocked some occurrences of `await sleep(duration)`:
await asyncio.sleep(0.0001)
return current_time, activated_timers_count_total
def check_sleep_timers_to_activate() -> int:
nonlocal pending_sleep_events
activated_timers_count = 0
for i, (target_event_monotonic_time, event) in enumerate(
pending_sleep_events
):
if target_event_monotonic_time < current_time:
continue
# Right, let's release this waiting event!
event.set()
activated_timers_count += 1
# ...and remove it from our pending sleep events list:
del pending_sleep_events[i]
return activated_timers_count
with mock.patch("textual._timer._TIMERS_CAN_SKIP", new=False), mock.patch(
"textual._timer.Timer._sleep", side_effect=sleep_mock
), mock.patch(
"textual._timer.Timer.get_time", side_effect=get_time_mock
), mock.patch(
"textual.message.Message._get_time", side_effect=get_time_mock
):
yield move_clock_forward
return mock_textual_timers_context_manager()