diff --git a/src/textual/_doc.py b/src/textual/_doc.py index 391c85d12..c976c34b9 100644 --- a/src/textual/_doc.py +++ b/src/textual/_doc.py @@ -1,10 +1,11 @@ from __future__ import annotations import hashlib +import inspect import os import shlex from pathlib import Path -from typing import Iterable, cast +from typing import Awaitable, Callable, Coroutine, Iterable, cast from textual._import_app import import_app from textual.app import App @@ -54,6 +55,7 @@ def take_svg_screenshot( press: Iterable[str] = (), title: str | None = None, terminal_size: tuple[int, int] = (80, 24), + run_before: Callable[[Pilot], Awaitable[None] | None] | None = None, ) -> str: """ @@ -63,11 +65,13 @@ def take_svg_screenshot( press: Key presses to run before taking screenshot. "_" is a short pause. title: The terminal title in the output image. terminal_size: A pair of integers (rows, columns), representing terminal size. + run_before: An arbitrary callable that runs arbitrary code before taking the + screenshot. Use this to simulate complex user interactions with the app + that cannot be simulated by key presses. Returns: An SVG string, showing the content of the terminal window at the time the screenshot was taken. - """ if app is None: @@ -90,7 +94,7 @@ def take_svg_screenshot( cache_key = f"{hash.hexdigest()}.svg" return cache_key - if app_path is not None: + if app_path is not None and run_before is None: screenshot_cache = Path(SCREENSHOT_CACHE) screenshot_cache.mkdir(exist_ok=True) @@ -100,6 +104,10 @@ def take_svg_screenshot( async def auto_pilot(pilot: Pilot) -> None: app = pilot.app + if run_before is not None: + result = run_before(pilot) + if inspect.isawaitable(result): + await result await pilot.press(*press) await pilot.wait_for_scheduled_animations() await pilot.pause() @@ -116,7 +124,7 @@ def take_svg_screenshot( ), ) - if app_path is not None: + if app_path is not None and run_before is None: screenshot_path.write_text(svg) assert svg is not None diff --git a/tests/snapshot_tests/conftest.py b/tests/snapshot_tests/conftest.py index c5d9ee21b..12769f5bd 100644 --- a/tests/snapshot_tests/conftest.py +++ b/tests/snapshot_tests/conftest.py @@ -7,7 +7,7 @@ from datetime import datetime from operator import attrgetter from os import PathLike from pathlib import Path, PurePath -from typing import Union, List, Optional, Callable, Iterable +from typing import Awaitable, Coroutine, Union, List, Optional, Callable, Iterable import pytest from _pytest.config import ExitCode @@ -21,6 +21,7 @@ from syrupy import SnapshotAssertion from textual._doc import take_svg_screenshot from textual._import_app import import_app from textual.app import App +from textual.pilot import Pilot TEXTUAL_SNAPSHOT_SVG_KEY = pytest.StashKey[str]() TEXTUAL_ACTUAL_SVG_KEY = pytest.StashKey[str]() @@ -42,6 +43,7 @@ def snap_compare( app_path: str | PurePath, press: Iterable[str] = ("_",), terminal_size: tuple[int, int] = (80, 24), + run_before: Callable[[Pilot], Awaitable[None] | None] | None = None, ) -> bool: """ Compare a current screenshot of the app running at app_path, with @@ -54,9 +56,12 @@ def snap_compare( test this function is called from. press (Iterable[str]): Key presses to run before taking screenshot. "_" is a short pause. terminal_size (tuple[int, int]): A pair of integers (WIDTH, HEIGHT), representing terminal size. + run_before: An arbitrary callable that runs arbitrary code before taking the + screenshot. Use this to simulate complex user interactions with the app + that cannot be simulated by key presses. Returns: - bool: True if the screenshot matches the snapshot. + Whether the screenshot matches the snapshot. """ node = request.node path = Path(app_path) @@ -74,6 +79,7 @@ def snap_compare( app=app, press=press, terminal_size=terminal_size, + run_before=run_before, ) result = snapshot == actual_screenshot