Merge pull request #1039 from Textualize/unmount

unmount event
This commit is contained in:
Will McGugan
2022-10-31 13:36:30 +00:00
committed by GitHub
26 changed files with 596 additions and 726 deletions

View File

@@ -7,9 +7,24 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
## [0.2.2] - Unreleased
### Fixed
- Fixed issue where scrollbars weren't being unmounted
### Changed
- DOMQuery now raises InvalidQueryFormat in response to invalid query strings, rather than cryptic CSS error
- Dropped quit_after, screenshot, and screenshot_title from App.run, which can all be done via auto_pilot
- Widgets are now closed in reversed DOM order
### Added
- Added Unmount event
- Added App.run_async method
- Added App.run_test context manager
- Added auto_pilot to App.run and App.run_async
- Added Widget._get_virtual_dom to get scrollbars
- Added size parameter to run and run_async
## [0.2.1] - 2022-10-23

1
docs/reference/pilot.md Normal file
View File

@@ -0,0 +1 @@
::: textual.pilot

View File

@@ -107,6 +107,7 @@ nav:
- "reference/index.md"
- "reference/message_pump.md"
- "reference/message.md"
- "reference/pilot.md"
- "reference/query.md"
- "reference/reactive.md"
- "reference/screen.md"

View File

@@ -51,7 +51,9 @@ class Logger:
try:
app = active_app.get()
except LookupError:
raise LoggerError("Unable to log without an active app.") from None
print_args = (*args, *[f"{key}={value!r}" for key, value in kwargs.items()])
print(*print_args)
return
if app.devtools is None or not app.devtools.is_connected:
return

View File

@@ -5,6 +5,7 @@ import shlex
from typing import Iterable
from textual.app import App
from textual.pilot import Pilot
from textual._import_app import import_app
@@ -18,7 +19,7 @@ def format_svg(source, language, css_class, options, md, attrs, **kwargs) -> str
path = cmd[0]
_press = attrs.get("press", None)
press = [*_press.split(",")] if _press else ["_"]
press = [*_press.split(",")] if _press else []
title = attrs.get("title")
print(f"screenshotting {path!r}")
@@ -28,7 +29,7 @@ def format_svg(source, language, css_class, options, md, attrs, **kwargs) -> str
rows = int(attrs.get("lines", 24))
columns = int(attrs.get("columns", 80))
svg = take_svg_screenshot(
None, path, press, title, terminal_size=(rows, columns)
None, path, press, title, terminal_size=(columns, rows)
)
finally:
os.chdir(cwd)
@@ -45,9 +46,9 @@ def format_svg(source, language, css_class, options, md, attrs, **kwargs) -> str
def take_svg_screenshot(
app: App | None = None,
app_path: str | None = None,
press: Iterable[str] = ("_",),
press: Iterable[str] = (),
title: str | None = None,
terminal_size: tuple[int, int] = (24, 80),
terminal_size: tuple[int, int] = (80, 24),
) -> str:
"""
@@ -63,25 +64,29 @@ def take_svg_screenshot(
the screenshot was taken.
"""
rows, columns = terminal_size
os.environ["COLUMNS"] = str(columns)
os.environ["LINES"] = str(rows)
if app is None:
assert app_path is not None
app = import_app(app_path)
assert app is not None
if title is None:
title = app.title
app.run(
quit_after=5,
press=press or ["ctrl+c"],
async def auto_pilot(pilot: Pilot) -> None:
app = pilot.app
await pilot.press(*press)
svg = app.export_screenshot(title=title)
app.exit(svg)
svg = app.run(
headless=True,
screenshot=True,
screenshot_title=title,
auto_pilot=auto_pilot,
size=terminal_size,
)
svg = app._screenshot
assert svg is not None
return svg

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
from asyncio import Task
from contextlib import asynccontextmanager
import inspect
import io
import os
@@ -12,7 +14,18 @@ from contextlib import redirect_stderr, redirect_stdout
from datetime import datetime
from pathlib import Path, PurePath
from time import perf_counter
from typing import Any, Generic, Iterable, Type, TYPE_CHECKING, TypeVar, cast, Union
from typing import (
Any,
Callable,
Coroutine,
Generic,
Iterable,
Type,
TYPE_CHECKING,
TypeVar,
cast,
Union,
)
from weakref import WeakSet, WeakValueDictionary
from ._ansi_sequences import SYNC_END, SYNC_START
@@ -51,7 +64,12 @@ from .widget import AwaitMount, Widget
if TYPE_CHECKING:
from .devtools.client import DevtoolsClient
from .pilot import Pilot
if sys.version_info >= (3, 10):
from typing import TypeAlias
else: # pragma: no cover
from typing_extensions import TypeAlias
PLATFORM = platform.system()
WINDOWS = PLATFORM == "Windows"
@@ -89,6 +107,9 @@ ComposeResult = Iterable[Widget]
RenderResult = RenderableType
AutopilotCallbackType: TypeAlias = "Callable[[Pilot], Coroutine[Any, Any, None]]"
class AppError(Exception):
pass
@@ -170,7 +191,7 @@ class App(Generic[ReturnType], DOMNode):
if no_color is not None:
self._filter = Monochrome()
self.console = Console(
file=(_NullFile() if self.is_headless else sys.__stdout__),
file=sys.__stdout__ if sys.__stdout__ is not None else _NullFile(),
markup=False,
highlight=False,
emoji=False,
@@ -241,6 +262,11 @@ class App(Generic[ReturnType], DOMNode):
)
self._screenshot: str | None = None
@property
def return_value(self) -> ReturnType | None:
"""Get the return type."""
return self._return_value
def animate(
self,
attribute: str,
@@ -295,7 +321,7 @@ class App(Generic[ReturnType], DOMNode):
bool: True if the app is in headless mode.
"""
return "headless" in self.features
return False if self._driver is None else self._driver.is_headless
@property
def screen_stack(self) -> list[Screen]:
@@ -314,7 +340,7 @@ class App(Generic[ReturnType], DOMNode):
result (ReturnType | None, optional): Return value. Defaults to None.
"""
self._return_value = result
self._close_messages_no_wait()
self.post_message_no_wait(messages.ExitApp(sender=self))
@property
def focused(self) -> Widget | None:
@@ -418,7 +444,11 @@ class App(Generic[ReturnType], DOMNode):
Returns:
Size: Size of the terminal
"""
return Size(*self.console.size)
if self._driver is not None and self._driver._size is not None:
width, height = self._driver._size
else:
width, height = self.console.size
return Size(width, height)
@property
def log(self) -> Logger:
@@ -500,10 +530,11 @@ class App(Generic[ReturnType], DOMNode):
to use app title. Defaults to None.
"""
assert self._driver is not None, "App must be running"
width, height = self.size
console = Console(
width=self.console.width,
height=self.console.height,
width=width,
height=height,
file=io.StringIO(),
force_terminal=True,
color_system="truecolor",
@@ -567,96 +598,170 @@ class App(Generic[ReturnType], DOMNode):
keys, action, description, show=show, key_display=key_display
)
async def _press_keys(self, keys: Iterable[str]) -> None:
"""A task to send key events."""
app = self
driver = app._driver
assert driver is not None
await asyncio.sleep(0.02)
for key in keys:
if key == "_":
print("(pause 50ms)")
await asyncio.sleep(0.05)
elif key.startswith("wait:"):
_, wait_ms = key.split(":")
print(f"(pause {wait_ms}ms)")
await asyncio.sleep(float(wait_ms) / 1000)
else:
if len(key) == 1 and not key.isalnum():
key = (
unicodedata.name(key)
.lower()
.replace("-", "_")
.replace(" ", "_")
)
original_key = REPLACED_KEYS.get(key, key)
char: str | None
try:
char = unicodedata.lookup(original_key.upper().replace("_", " "))
except KeyError:
char = key if len(key) == 1 else None
print(f"press {key!r} (char={char!r})")
key_event = events.Key(app, key, char)
driver.send_event(key_event)
# TODO: A bit of a fudge - extra sleep after tabbing to help guard against race
# condition between widget-level key handling and app/screen level handling.
# More information here: https://github.com/Textualize/textual/issues/1009
# This conditional sleep can be removed after that issue is closed.
if key == "tab":
await asyncio.sleep(0.05)
await asyncio.sleep(0.02)
await app._animator.wait_for_idle()
@asynccontextmanager
async def run_test(
self,
*,
headless: bool = True,
size: tuple[int, int] | None = (80, 24),
):
"""An asynchronous context manager for testing app.
Args:
headless (bool, optional): Run in headless mode (no output or input). Defaults to True.
size (tuple[int, int] | None, optional): Force terminal size to `(WIDTH, HEIGHT)`,
or None to auto-detect. Defaults to None.
"""
from .pilot import Pilot
app = self
app_ready_event = asyncio.Event()
def on_app_ready() -> None:
"""Called when app is ready to process events."""
app_ready_event.set()
async def run_app(app) -> None:
await app._process_messages(
ready_callback=on_app_ready,
headless=headless,
terminal_size=size,
)
# Launch the app in the "background"
app_task = asyncio.create_task(run_app(app))
# Wait until the app has performed all startup routines.
await app_ready_event.wait()
# Context manager returns pilot object to manipulate the app
yield Pilot(app)
# Shutdown the app cleanly
await app._shutdown()
await app_task
async def run_async(
self,
*,
headless: bool = False,
size: tuple[int, int] | None = None,
auto_pilot: AutopilotCallbackType | None = None,
) -> ReturnType | None:
"""Run the app asynchronously.
Args:
headless (bool, optional): Run in headless mode (no output). Defaults to False.
size (tuple[int, int] | None, optional): Force terminal size to `(WIDTH, HEIGHT)`,
or None to auto-detect. Defaults to None.
auto_pilot (AutopilotCallbackType): An auto pilot coroutine.
Returns:
ReturnType | None: App return value.
"""
from .pilot import Pilot
app = self
auto_pilot_task: Task | None = None
async def app_ready() -> None:
"""Called by the message loop when the app is ready."""
nonlocal auto_pilot_task
if auto_pilot is not None:
async def run_auto_pilot(
auto_pilot: AutopilotCallbackType, pilot: Pilot
) -> None:
try:
await auto_pilot(pilot)
except Exception:
app.exit()
raise
pilot = Pilot(app)
auto_pilot_task = asyncio.create_task(run_auto_pilot(auto_pilot, pilot))
try:
await app._process_messages(
ready_callback=None if auto_pilot is None else app_ready,
headless=headless,
terminal_size=size,
)
finally:
if auto_pilot_task is not None:
await auto_pilot_task
await app._shutdown()
return app.return_value
def run(
self,
*,
quit_after: float | None = None,
headless: bool = False,
press: Iterable[str] | None = None,
screenshot: bool = False,
screenshot_title: str | None = None,
size: tuple[int, int] | None = None,
auto_pilot: AutopilotCallbackType | None = None,
) -> ReturnType | None:
"""The main entry point for apps.
"""Run the app.
Args:
quit_after (float | None, optional): Quit after a given number of seconds, or None
to run forever. Defaults to None.
headless (bool, optional): Run in "headless" mode (don't write to stdout).
press (str, optional): An iterable of keys to simulate being pressed.
screenshot (bool, optional): Take a screenshot after pressing keys (svg data stored in self._screenshot). Defaults to False.
screenshot_title (str | None, optional): Title of screenshot, or None to use App title. Defaults to None.
headless (bool, optional): Run in headless mode (no output). Defaults to False.
size (tuple[int, int] | None, optional): Force terminal size to `(WIDTH, HEIGHT)`,
or None to auto-detect. Defaults to None.
auto_pilot (AutopilotCallbackType): An auto pilot coroutine.
Returns:
ReturnType | None: The return value specified in `App.exit` or None if exit wasn't called.
ReturnType | None: App return value.
"""
if headless:
self.features = cast(
"frozenset[FeatureFlag]", self.features.union({"headless"})
)
async def run_app() -> None:
if quit_after is not None:
self.set_timer(quit_after, self.shutdown)
if press is not None:
app = self
async def press_keys() -> None:
"""A task to send key events."""
assert press
driver = app._driver
assert driver is not None
await asyncio.sleep(0.02)
for key in press:
if key == "_":
print("(pause 50ms)")
await asyncio.sleep(0.05)
elif key.startswith("wait:"):
_, wait_ms = key.split(":")
print(f"(pause {wait_ms}ms)")
await asyncio.sleep(float(wait_ms) / 1000)
else:
if len(key) == 1 and not key.isalnum():
key = (
unicodedata.name(key)
.lower()
.replace("-", "_")
.replace(" ", "_")
)
original_key = REPLACED_KEYS.get(key, key)
try:
char = unicodedata.lookup(
original_key.upper().replace("_", " ")
)
except KeyError:
char = key if len(key) == 1 else None
print(f"press {key!r} (char={char!r})")
key_event = events.Key(self, key, char)
driver.send_event(key_event)
# TODO: A bit of a fudge - extra sleep after tabbing to help guard against race
# condition between widget-level key handling and app/screen level handling.
# More information here: https://github.com/Textualize/textual/issues/1009
# This conditional sleep can be removed after that issue is closed.
if key == "tab":
await asyncio.sleep(0.05)
await asyncio.sleep(0.02)
await app._animator.wait_for_idle()
await asyncio.sleep(0.05)
if screenshot:
self._screenshot = self.export_screenshot(
title=screenshot_title
)
await self.shutdown()
async def press_keys_task():
"""Press some keys in the background."""
asyncio.create_task(press_keys())
await self._process_messages(ready_callback=press_keys_task)
else:
await self._process_messages()
"""Run the app."""
await self.run_async(
headless=headless,
size=size,
auto_pilot=auto_pilot,
)
if _ASYNCIO_GET_EVENT_LOOP_IS_DEPRECATED:
# N.B. This doesn't work with Python<3.10, as we end up with 2 event loops:
@@ -665,8 +770,7 @@ class App(Generic[ReturnType], DOMNode):
# However, this works with Python<3.10:
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(run_app())
return self._return_value
return self.return_value
async def _on_css_change(self) -> None:
"""Called when the CSS changes (if watch_css is True)."""
@@ -993,7 +1097,10 @@ class App(Generic[ReturnType], DOMNode):
self._exit_renderables.clear()
async def _process_messages(
self, ready_callback: CallbackType | None = None
self,
ready_callback: CallbackType | None = None,
headless: bool = False,
terminal_size: tuple[int, int] | None = None,
) -> None:
self._set_active()
@@ -1039,22 +1146,31 @@ class App(Generic[ReturnType], DOMNode):
self.log.system("[b green]STARTED[/]", self.css_monitor)
async def run_process_messages():
"""The main message loop, invoke below."""
async def invoke_ready_callback() -> None:
if ready_callback is not None:
ready_result = ready_callback()
if inspect.isawaitable(ready_result):
await ready_result
try:
await self._dispatch_message(events.Compose(sender=self))
await self._dispatch_message(events.Mount(sender=self))
try:
await self._dispatch_message(events.Compose(sender=self))
await self._dispatch_message(events.Mount(sender=self))
finally:
self._mounted_event.set()
Reactive._initialize_object(self)
self.stylesheet.update(self)
self.refresh()
await self.animator.start()
finally:
self._mounted_event.set()
Reactive._initialize_object(self)
self.stylesheet.update(self)
self.refresh()
await self.animator.start()
await self._ready()
if ready_callback is not None:
await ready_callback()
await self._ready()
await invoke_ready_callback()
self._running = True
@@ -1068,7 +1184,6 @@ class App(Generic[ReturnType], DOMNode):
await timer.stop()
await self.animator.stop()
await self._close_all()
self._running = True
try:
@@ -1078,13 +1193,13 @@ class App(Generic[ReturnType], DOMNode):
driver: Driver
driver_class = cast(
"type[Driver]",
HeadlessDriver if self.is_headless else self.driver_class,
HeadlessDriver if headless else self.driver_class,
)
driver = self._driver = driver_class(self.console, self)
driver = self._driver = driver_class(self.console, self, size=terminal_size)
driver.start_application_mode()
try:
if self.is_headless:
if headless:
await run_process_messages()
else:
if self.devtools is not None:
@@ -1106,11 +1221,6 @@ class App(Generic[ReturnType], DOMNode):
driver.stop_application_mode()
except Exception as error:
self._handle_exception(error)
finally:
self._running = False
self._print_error_renderables()
if self.devtools is not None and self.devtools.is_connected:
await self._disconnect_devtools()
async def _pre_process(self) -> None:
pass
@@ -1135,7 +1245,7 @@ class App(Generic[ReturnType], DOMNode):
"""Used by docs plugin."""
svg = self.export_screenshot(title=screenshot_title)
self._screenshot = svg # type: ignore
await self.shutdown()
self.exit()
self.set_timer(screenshot_timer, on_screenshot, name="screenshot timer")
@@ -1219,8 +1329,10 @@ class App(Generic[ReturnType], DOMNode):
parent (Widget): The parent of the Widget.
widget (Widget): The Widget to start.
"""
widget._attach(parent)
widget._start_messages()
self.app._registry.add(widget)
def is_mounted(self, widget: Widget) -> bool:
"""Check if a widget is mounted.
@@ -1234,17 +1346,43 @@ class App(Generic[ReturnType], DOMNode):
return widget in self._registry
async def _close_all(self) -> None:
while self._registry:
child = self._registry.pop()
"""Close all message pumps."""
# Close all screens on the stack
for screen in self._screen_stack:
if screen._running:
await self._prune_node(screen)
self._screen_stack.clear()
# Close pre-defined screens
for screen in self.SCREENS.values():
if screen._running:
await self._prune_node(screen)
# Close any remaining nodes
# Should be empty by now
remaining_nodes = list(self._registry)
for child in remaining_nodes:
await child._close_messages()
async def shutdown(self):
await self._disconnect_devtools()
async def _shutdown(self) -> None:
driver = self._driver
self._running = False
if driver is not None:
driver.disable_input()
await self._close_all()
await self._close_messages()
await self._dispatch_message(events.Unmount(sender=self))
self._print_error_renderables()
if self.devtools is not None and self.devtools.is_connected:
await self._disconnect_devtools()
async def _on_exit_app(self) -> None:
await self._message_queue.put(None)
def refresh(self, *, repaint: bool = True, layout: bool = False) -> None:
if self._screen_stack:
self.screen.refresh(repaint=repaint, layout=layout)
@@ -1498,18 +1636,61 @@ class App(Generic[ReturnType], DOMNode):
[to_remove for to_remove in remove_widgets if to_remove.can_focus],
)
for child in remove_widgets:
await child._close_messages()
self._unregister(child)
await self._prune_node(widget)
if parent is not None:
parent.refresh(layout=True)
def _walk_children(self, root: Widget) -> Iterable[list[Widget]]:
"""Walk children depth first, generating widgets and a list of their siblings.
Returns:
Iterable[list[Widget]]: The child widgets of root.
"""
stack: list[Widget] = [root]
pop = stack.pop
push = stack.append
while stack:
widget = pop()
if widget.children:
yield [*widget.children, *widget._get_virtual_dom()]
for child in widget.children:
push(child)
async def _prune_node(self, root: Widget) -> None:
"""Remove a node and its children. Children are removed before parents.
Args:
root (Widget): Node to remove.
"""
# Pruning a node that has been removed is a no-op
if root not in self._registry:
return
node_children = list(self._walk_children(root))
for children in reversed(node_children):
# Closing children can be done asynchronously.
close_messages = [
child._close_messages() for child in children if child._running
]
# TODO: What if a message pump refuses to exit?
if close_messages:
await asyncio.gather(*close_messages)
for child in children:
self._unregister(child)
await root._close_messages()
self._unregister(root)
async def action_check_bindings(self, key: str) -> None:
await self.check_bindings(key)
async def action_quit(self) -> None:
"""Quit the app as soon as possible."""
await self.shutdown()
self.exit()
async def action_bang(self) -> None:
1 / 0

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import click
from importlib_metadata import version
from textual.pilot import Pilot
from textual._import_app import import_app, AppFail
@@ -84,7 +85,12 @@ def run_app(import_name: str, dev: bool, press: str) -> None:
sys.exit(1)
press_keys = press.split(",") if press else None
result = app.run(press=press_keys)
async def run_press_keys(pilot: Pilot) -> None:
if press_keys is not None:
await pilot.press(*press_keys)
result = app.run(auto_pilot=run_press_keys)
if result is not None:
from rich.console import Console

View File

@@ -13,14 +13,25 @@ if TYPE_CHECKING:
class Driver(ABC):
def __init__(
self, console: "Console", target: "MessageTarget", debug: bool = False
self,
console: "Console",
target: "MessageTarget",
*,
debug: bool = False,
size: tuple[int, int] | None = None,
) -> None:
self.console = console
self._target = target
self._debug = debug
self._size = size
self._loop = asyncio.get_running_loop()
self._mouse_down_time = _clock.get_time_no_wait()
@property
def is_headless(self) -> bool:
"""Check if the driver is 'headless'"""
return False
def send_event(self, event: events.Event) -> None:
asyncio.run_coroutine_threadsafe(
self._target.post_message(event), loop=self._loop

View File

@@ -9,7 +9,13 @@ from .. import events
class HeadlessDriver(Driver):
"""A do-nothing driver for testing."""
@property
def is_headless(self) -> bool:
return True
def _get_terminal_size(self) -> tuple[int, int]:
if self._size is not None:
return self._size
width: int | None = 80
height: int | None = 25
import shutil

View File

@@ -30,9 +30,14 @@ class LinuxDriver(Driver):
"""Powers display and input for Linux / MacOS"""
def __init__(
self, console: "Console", target: "MessageTarget", debug: bool = False
self,
console: "Console",
target: "MessageTarget",
*,
debug: bool = False,
size: tuple[int, int] | None = None,
) -> None:
super().__init__(console, target, debug)
super().__init__(console, target, debug=debug, size=size)
self.fileno = sys.stdin.fileno()
self.attrs_before: list[Any] | None = None
self.exit_event = Event()

View File

@@ -18,9 +18,14 @@ class WindowsDriver(Driver):
"""Powers display and input for Windows."""
def __init__(
self, console: "Console", target: "MessageTarget", debug: bool = False
self,
console: "Console",
target: "MessageTarget",
*,
debug: bool = False,
size: tuple[int, int] | None = None,
) -> None:
super().__init__(console, target, debug)
super().__init__(console, target, debug=debug, size=size)
self.in_fileno = sys.stdin.fileno()
self.out_fileno = sys.stdout.fileno()

View File

@@ -119,10 +119,14 @@ class Compose(Event, bubble=False, verbose=True):
"""Sent to a widget to request it to compose and mount children."""
class Mount(Event, bubble=False, verbose=True):
class Mount(Event, bubble=False, verbose=False):
"""Sent when a widget is *mounted* and may receive messages."""
class Unmount(Mount, bubble=False, verbose=False):
"""Sent when a widget is unmounted and may not longer receive messages."""
class Remove(Event, bubble=False):
"""Sent to a widget to ask it to remove itself from the DOM."""

View File

@@ -155,7 +155,9 @@ class MessagePump(metaclass=MessagePumpMeta):
return self._pending_message
finally:
self._pending_message = None
message = await self._message_queue.get()
if message is None:
self._closed = True
raise MessagePumpClosed("The message pump is now closed")
@@ -266,8 +268,11 @@ class MessagePump(metaclass=MessagePumpMeta):
self.app.screen._invoke_later(message.callback)
def _close_messages_no_wait(self) -> None:
"""Request the message queue to exit."""
self._message_queue.put_nowait(None)
"""Request the message queue to immediately exit."""
self._message_queue.put_nowait(messages.CloseMessages(sender=self))
async def _on_close_messages(self, message: messages.CloseMessages) -> None:
await self._close_messages()
async def _close_messages(self) -> None:
"""Close message queue, and optionally wait for queue to finish processing."""
@@ -278,6 +283,7 @@ class MessagePump(metaclass=MessagePumpMeta):
for timer in stop_timers:
await timer.stop()
self._timers.clear()
await self._message_queue.put(events.Unmount(sender=self))
await self._message_queue.put(None)
if self._task is not None and asyncio.current_task() != self._task:
# Ensure everything is closed before returning
@@ -285,7 +291,8 @@ class MessagePump(metaclass=MessagePumpMeta):
def _start_messages(self) -> None:
"""Start messages task."""
self._task = asyncio.create_task(self._process_messages())
if self.app._running:
self._task = asyncio.create_task(self._process_messages())
async def _process_messages(self) -> None:
self._running = True
@@ -370,8 +377,6 @@ class MessagePump(metaclass=MessagePumpMeta):
self.app._handle_exception(error)
break
log("CLOSED", self)
async def _dispatch_message(self, message: Message) -> None:
"""Dispatch a message received from the message queue.
@@ -424,6 +429,7 @@ class MessagePump(metaclass=MessagePumpMeta):
handler_name = message._handler_name
# Look through the MRO to find a handler
dispatched = False
for cls, method in self._get_dispatch_methods(handler_name, message):
log.event.verbosity(message.verbose)(
message,
@@ -431,7 +437,10 @@ class MessagePump(metaclass=MessagePumpMeta):
self,
f"method=<{cls.__name__}.{handler_name}>",
)
dispatched = True
await invoke(method, message)
if not dispatched:
log.event.verbosity(message.verbose)(message, ">>>", self, "method=None")
# Bubble messages up the DOM (if enabled on the message)
if message.bubble and self._parent and not message._stop_propagation:

View File

@@ -13,6 +13,16 @@ if TYPE_CHECKING:
from .widget import Widget
@rich.repr.auto
class CloseMessages(Message, verbose=True):
"""Requests message pump to close."""
@rich.repr.auto
class ExitApp(Message, verbose=True):
"""Exit the app."""
@rich.repr.auto
class Update(Message, verbose=True):
def __init__(self, sender: MessagePump, widget: Widget):

55
src/textual/pilot.py Normal file
View File

@@ -0,0 +1,55 @@
from __future__ import annotations
import rich.repr
import asyncio
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .app import App
@rich.repr.auto(angular=True)
class Pilot:
"""Pilot object to drive an app."""
def __init__(self, app: App) -> None:
self._app = app
def __rich_repr__(self) -> rich.repr.Result:
yield "app", self._app
@property
def app(self) -> App:
"""Get a reference to the application.
Returns:
App: The App instance.
"""
return self._app
async def press(self, *keys: str) -> None:
"""Simulate key-presses.
Args:
*key: Keys to press.
"""
if keys:
await self._app._press_keys(keys)
async def pause(self, delay: float = 50 / 1000) -> None:
"""Insert a pause.
Args:
delay (float, optional): Seconds to pause. Defaults to 50ms.
"""
await asyncio.sleep(delay)
async def exit(self, result: object) -> None:
"""Exit the app with the given result.
Args:
result (object): The app result returned by `run` or `run_async`.
"""
self.app.exit(result)

View File

@@ -359,6 +359,20 @@ class Widget(DOMNode):
"""Clear arrangement cache, forcing a new arrange operation."""
self._arrangement = None
def _get_virtual_dom(self) -> Iterable[Widget]:
"""Get widgets not part of the DOM.
Returns:
Iterable[Widget]: An iterable of Widgets.
"""
if self._horizontal_scrollbar is not None:
yield self._horizontal_scrollbar
if self._vertical_scrollbar is not None:
yield self._vertical_scrollbar
if self._scrollbar_corner is not None:
yield self._scrollbar_corner
def mount(self, *anon_widgets: Widget, **widgets: Widget) -> AwaitMount:
"""Mount child widgets (making this widget a container).
@@ -587,6 +601,7 @@ class Widget(DOMNode):
Returns:
ScrollBar: ScrollBar Widget.
"""
from .scrollbar import ScrollBar
if self._horizontal_scrollbar is not None:
@@ -595,13 +610,12 @@ class Widget(DOMNode):
vertical=False, name="horizontal", thickness=self.scrollbar_size_horizontal
)
self._horizontal_scrollbar.display = False
self.app._start_widget(self, scroll_bar)
return scroll_bar
def _refresh_scrollbars(self) -> None:
"""Refresh scrollbar visibility."""
if not self.is_scrollable:
if not self.is_scrollable or not self.container_size:
return
styles = self.styles

View File

@@ -18,8 +18,6 @@ from textual.css.styles import Styles, RenderStyles
from textual.dom import DOMNode
from textual.widget import Widget
from tests.utilities.test_app import AppTest
def test_styles_reset():
styles = Styles()
@@ -206,88 +204,3 @@ def test_widget_style_size_fails_if_data_type_is_not_supported(size_dimension_in
with pytest.raises(StyleValueError):
widget.styles.width = size_dimension_input
@pytest.mark.asyncio
@pytest.mark.parametrize(
"overflow_y,scrollbar_gutter,scrollbar_size,text_length,expected_text_widget_width,expects_vertical_scrollbar",
(
# ------------------------------------------------
# ----- Let's start with `overflow-y: auto`:
# short text: full width, no scrollbar
["auto", "auto", 1, "short_text", 80, False],
# long text: reduced width, scrollbar
["auto", "auto", 1, "long_text", 78, True],
# short text, `scrollbar-gutter: stable`: reduced width, no scrollbar
["auto", "stable", 1, "short_text", 78, False],
# long text, `scrollbar-gutter: stable`: reduced width, scrollbar
["auto", "stable", 1, "long_text", 78, True],
# ------------------------------------------------
# ----- And now let's see the behaviour with `overflow-y: scroll`:
# short text: reduced width, scrollbar
["scroll", "auto", 1, "short_text", 78, True],
# long text: reduced width, scrollbar
["scroll", "auto", 1, "long_text", 78, True],
# short text, `scrollbar-gutter: stable`: reduced width, scrollbar
["scroll", "stable", 1, "short_text", 78, True],
# long text, `scrollbar-gutter: stable`: reduced width, scrollbar
["scroll", "stable", 1, "long_text", 78, True],
# ------------------------------------------------
# ----- Finally, let's check the behaviour with `overflow-y: hidden`:
# short text: full width, no scrollbar
["hidden", "auto", 1, "short_text", 80, False],
# long text: full width, no scrollbar
["hidden", "auto", 1, "long_text", 80, False],
# short text, `scrollbar-gutter: stable`: reduced width, no scrollbar
["hidden", "stable", 1, "short_text", 78, False],
# long text, `scrollbar-gutter: stable`: reduced width, no scrollbar
["hidden", "stable", 1, "long_text", 78, False],
# ------------------------------------------------
# ----- Bonus round with a custom scrollbar size, now that we can set this:
["auto", "auto", 3, "short_text", 80, False],
["auto", "auto", 3, "long_text", 77, True],
["scroll", "auto", 3, "short_text", 77, True],
["scroll", "stable", 3, "short_text", 77, True],
["hidden", "auto", 3, "long_text", 80, False],
["hidden", "stable", 3, "short_text", 77, False],
),
)
async def test_scrollbar_gutter(
overflow_y: str,
scrollbar_gutter: str,
scrollbar_size: int,
text_length: Literal["short_text", "long_text"],
expected_text_widget_width: int,
expects_vertical_scrollbar: bool,
):
from rich.text import Text
from textual.geometry import Size
class TextWidget(Widget):
def render(self) -> Text:
text_multiplier = 10 if text_length == "long_text" else 2
return Text(
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. In velit liber a a a."
* text_multiplier
)
container = Widget()
container.styles.height = 3
container.styles.overflow_y = overflow_y
container.styles.scrollbar_gutter = scrollbar_gutter
if scrollbar_size > 1:
container.styles.scrollbar_size_vertical = scrollbar_size
text_widget = TextWidget()
text_widget.styles.height = "auto"
container._add_child(text_widget)
class MyTestApp(AppTest):
def compose(self) -> ComposeResult:
yield container
app = MyTestApp(test_name="scrollbar_gutter", size=Size(80, 10))
await app.boot_and_shutdown()
assert text_widget.outer_size.width == expected_text_widget_width
assert container.scrollbars_enabled[0] is expects_vertical_scrollbar

View File

@@ -41,7 +41,7 @@ def snap_compare(
def compare(
app_path: str,
press: Iterable[str] = ("_",),
terminal_size: tuple[int, int] = (24, 80),
terminal_size: tuple[int, int] = (80, 24),
) -> bool:
"""
Compare a current screenshot of the app running at app_path, with
@@ -52,14 +52,13 @@ def snap_compare(
Args:
app_path (str): The path of the app.
press (Iterable[str]): Key presses to run before taking screenshot. "_" is a short pause.
terminal_size (tuple[int, int]): A pair of integers (rows, columns), representing terminal size.
terminal_size (tuple[int, int]): A pair of integers (WIDTH, SIZE), representing terminal size.
Returns:
bool: True if the screenshot matches the snapshot.
"""
node = request.node
app = import_app(app_path)
compare.app = app
actual_screenshot = take_svg_screenshot(
app=app,
press=press,
@@ -69,7 +68,9 @@ def snap_compare(
if result is False:
# The split and join below is a mad hack, sorry...
node.stash[TEXTUAL_SNAPSHOT_SVG_KEY] = "\n".join(str(snapshot).splitlines()[1:-1])
node.stash[TEXTUAL_SNAPSHOT_SVG_KEY] = "\n".join(
str(snapshot).splitlines()[1:-1]
)
node.stash[TEXTUAL_ACTUAL_SVG_KEY] = actual_screenshot
node.stash[TEXTUAL_APP_KEY] = app
else:
@@ -85,6 +86,7 @@ class SvgSnapshotDiff:
"""Model representing a diff between current screenshot of an app,
and the snapshot on disk. This is ultimately intended to be used in
a Jinja2 template."""
snapshot: Optional[str]
actual: Optional[str]
test_name: str
@@ -119,7 +121,7 @@ def pytest_sessionfinish(
snapshot=str(snapshot_svg),
actual=str(actual_svg),
file_similarity=100
* difflib.SequenceMatcher(
* difflib.SequenceMatcher(
a=str(snapshot_svg), b=str(actual_svg)
).ratio(),
test_name=name,
@@ -176,7 +178,9 @@ def pytest_terminal_summary(
if diffs:
snapshot_report_location = config._textual_snapshot_html_report
console.rule("[b red]Textual Snapshot Report", style="red")
console.print(f"\n[black on red]{len(diffs)} mismatched snapshots[/]\n"
f"\n[b]View the [link=file://{snapshot_report_location}]failure report[/].\n")
console.print(
f"\n[black on red]{len(diffs)} mismatched snapshots[/]\n"
f"\n[b]View the [link=file://{snapshot_report_location}]failure report[/].\n"
)
console.print(f"[dim]{snapshot_report_location}\n")
console.rule(style="red")

View File

@@ -8,6 +8,7 @@ from textual.widgets import Input, Button
# --- Layout related stuff ---
def test_grid_layout_basic(snap_compare):
assert snap_compare("docs/examples/guide/layout/grid_layout1.py")
@@ -41,6 +42,7 @@ def test_dock_layout_sidebar(snap_compare):
# When adding a new widget, ideally we should also create a snapshot test
# from these examples which test rendering and simple interactions with it.
def test_checkboxes(snap_compare):
"""Tests checkboxes but also acts a regression test for using
width: auto in a Horizontal layout context."""
@@ -64,22 +66,11 @@ def test_input_and_focus(snap_compare):
]
assert snap_compare("docs/examples/widgets/input.py", press=press)
# Assert that the state of the Input is what we'd expect
app: App = snap_compare.app
input: Input = app.query_one(Input)
assert input.value == "Darren"
assert input.cursor_position == 6
assert input.view_position == 0
def test_buttons_render(snap_compare):
# Testing button rendering. We press tab to focus the first button too.
assert snap_compare("docs/examples/widgets/button.py", press=["tab"])
app = snap_compare.app
button: Button = app.query_one(Button)
assert app.focused is button
def test_datatable_render(snap_compare):
press = ["tab", "down", "down", "right", "up", "left"]
@@ -99,7 +90,9 @@ def test_header_render(snap_compare):
# If any of these change, something has likely broken, so snapshot each of them.
PATHS = [
str(PurePosixPath(path)) for path in Path("docs/examples/styles").iterdir() if path.suffix == ".py"
str(PurePosixPath(path))
for path in Path("docs/examples/styles").iterdir()
if path.suffix == ".py"
]

23
tests/test_auto_pilot.py Normal file
View File

@@ -0,0 +1,23 @@
from textual.app import App
from textual.pilot import Pilot
from textual import events
def test_auto_pilot() -> None:
keys_pressed: list[str] = []
class TestApp(App):
def on_key(self, event: events.Key) -> None:
keys_pressed.append(event.key)
async def auto_pilot(pilot: Pilot) -> None:
await pilot.press("tab", *"foo")
await pilot.pause(1 / 100)
await pilot.exit("bar")
app = TestApp()
result = app.run(headless=True, auto_pilot=auto_pilot)
assert result == "bar"
assert keys_pressed == ["tab", "f", "o", "o"]

View File

@@ -1,6 +1,8 @@
import asyncio
from time import time
from textual.app import App
from textual.pilot import Pilot
class RefreshApp(App[float]):
@@ -22,7 +24,10 @@ class RefreshApp(App[float]):
def test_auto_refresh():
app = RefreshApp()
elapsed = app.run(quit_after=1, headless=True)
async def quit_after(pilot: Pilot) -> None:
await asyncio.sleep(1)
elapsed = app.run(auto_pilot=quit_after, headless=True)
assert elapsed is not None
# CI can run slower, so we need to give this a bit of margin
assert 0.2 <= elapsed < 0.8

View File

@@ -1,116 +0,0 @@
from __future__ import annotations
from typing import Sequence, cast
import pytest
from tests.utilities.test_app import AppTest
from textual.app import ComposeResult
from textual.geometry import Size
from textual.widget import Widget
from textual.widgets import Placeholder
pytestmark = pytest.mark.integration_test
SCREEN_SIZE = Size(100, 30)
@pytest.mark.skip("Needs a rethink")
@pytest.mark.asyncio
@pytest.mark.parametrize(
(
"screen_size",
"placeholders_count",
"scroll_to_placeholder_id",
"scroll_to_animate",
"waiting_duration",
"last_screen_expected_placeholder_ids",
),
(
[SCREEN_SIZE, 10, None, None, 0.01, (0, 1, 2, 3, 4)],
[SCREEN_SIZE, 10, "placeholder_3", False, 0.01, (0, 1, 2, 3, 4)],
[SCREEN_SIZE, 10, "placeholder_5", False, 0.01, (1, 2, 3, 4, 5)],
[SCREEN_SIZE, 10, "placeholder_7", False, 0.01, (3, 4, 5, 6, 7)],
[SCREEN_SIZE, 10, "placeholder_9", False, 0.01, (5, 6, 7, 8, 9)],
# N.B. Scroll duration is hard-coded to 0.2 in the `scroll_to_widget` method atm
# Waiting for this duration should allow us to see the scroll finished:
[SCREEN_SIZE, 10, "placeholder_9", True, 0.21, (5, 6, 7, 8, 9)],
# After having waited for approximately half of the scrolling duration, we should
# see the middle Placeholders as we're scrolling towards the last of them.
[SCREEN_SIZE, 10, "placeholder_9", True, 0.1, (4, 5, 6, 7, 8)],
),
)
async def test_scroll_to_widget(
screen_size: Size,
placeholders_count: int,
scroll_to_animate: bool | None,
scroll_to_placeholder_id: str | None,
waiting_duration: float | None,
last_screen_expected_placeholder_ids: Sequence[int],
):
class VerticalContainer(Widget):
DEFAULT_CSS = """
VerticalContainer {
layout: vertical;
overflow: hidden auto;
}
VerticalContainer Placeholder {
margin: 1 0;
height: 5;
}
"""
class MyTestApp(AppTest):
DEFAULT_CSS = """
Placeholder {
height: 5; /* minimal height to see the name of a Placeholder */
}
"""
def compose(self) -> ComposeResult:
placeholders = [
Placeholder(id=f"placeholder_{i}", name=f"Placeholder #{i}")
for i in range(placeholders_count)
]
yield VerticalContainer(*placeholders, id="root")
app = MyTestApp(size=screen_size, test_name="scroll_to_widget")
async with app.in_running_state(waiting_duration_after_yield=waiting_duration or 0):
if scroll_to_placeholder_id:
target_widget_container = cast(Widget, app.query("#root").first())
target_widget = cast(
Widget, app.query(f"#{scroll_to_placeholder_id}").first()
)
target_widget_container.scroll_to_widget(
target_widget, animate=scroll_to_animate
)
last_display_capture = app.last_display_capture
placeholders_visibility_by_id = {
id_: f"placeholder_{id_}" in last_display_capture
for id_ in range(placeholders_count)
}
print(placeholders_visibility_by_id)
# Let's start by checking placeholders that should be visible:
for placeholder_id in last_screen_expected_placeholder_ids:
assert placeholders_visibility_by_id[placeholder_id] is True, (
f"Placeholder '{placeholder_id}' should be visible but isn't"
f" :: placeholders_visibility_by_id={placeholders_visibility_by_id}"
)
# Ok, now for placeholders that should *not* be visible:
# We're simply going to check that all the placeholders that are not in
# `last_screen_expected_placeholder_ids` are not on the screen:
last_screen_expected_out_of_viewport_placeholder_ids = sorted(
tuple(
set(range(placeholders_count)) - set(last_screen_expected_placeholder_ids)
)
)
for placeholder_id in last_screen_expected_out_of_viewport_placeholder_ids:
assert placeholders_visibility_by_id[placeholder_id] is False, (
f"Placeholder '{placeholder_id}' should not be visible but is"
f" :: placeholders_visibility_by_id={placeholders_visibility_by_id}"
)

View File

@@ -89,4 +89,4 @@ async def test_screens():
screen1.remove()
screen2.remove()
screen3.remove()
await app.shutdown()
await app._shutdown()

21
tests/test_test_runner.py Normal file
View File

@@ -0,0 +1,21 @@
from textual.app import App
from textual import events
async def test_run_test() -> None:
"""Test the run_test context manager."""
keys_pressed: list[str] = []
class TestApp(App[str]):
def on_key(self, event: events.Key) -> None:
keys_pressed.append(event.key)
app = TestApp()
async with app.run_test() as pilot:
assert str(pilot) == "<Pilot app=TestApp(title='TestApp')>"
await pilot.press("tab", *"foo")
await pilot.pause(1 / 100)
await pilot.exit("bar")
assert app.return_value == "bar"
assert keys_pressed == ["tab", "f", "o", "o"]

50
tests/test_unmount.py Normal file
View File

@@ -0,0 +1,50 @@
from textual.app import App, ComposeResult
from textual import events
from textual.containers import Container
from textual.screen import Screen
async def test_unmount():
"""Test unmount events are received in reverse DOM order."""
unmount_ids: list[str] = []
class UnmountWidget(Container):
def on_unmount(self, event: events.Unmount):
unmount_ids.append(f"{self.__class__.__name__}#{self.id}")
class MyScreen(Screen):
def compose(self) -> ComposeResult:
yield UnmountWidget(
UnmountWidget(
UnmountWidget(id="bar1"), UnmountWidget(id="bar2"), id="bar"
),
UnmountWidget(
UnmountWidget(id="baz1"), UnmountWidget(id="baz2"), id="baz"
),
id="top",
)
def on_unmount(self, event: events.Unmount):
unmount_ids.append(f"{self.__class__.__name__}#{self.id}")
class UnmountApp(App):
async def on_mount(self) -> None:
self.push_screen(MyScreen(id="main"))
app = UnmountApp()
async with app.run_test() as pilot:
await pilot.pause() # TODO remove when push_screen is awaitable
await pilot.exit(None)
expected = [
"UnmountWidget#bar1",
"UnmountWidget#bar2",
"UnmountWidget#baz1",
"UnmountWidget#baz2",
"UnmountWidget#bar",
"UnmountWidget#baz",
"UnmountWidget#top",
"MyScreen#main",
]
assert unmount_ids == expected

View File

@@ -1,353 +0,0 @@
from __future__ import annotations
import asyncio
import contextlib
import io
from math import ceil
from pathlib import Path
from time import monotonic
from typing import AsyncContextManager, cast, ContextManager
from unittest import mock
from rich.console import Console
from textual import events, errors
from textual._ansi_sequences import SYNC_START
from textual._clock import _Clock
from textual._context import active_app
from textual.app import App, ComposeResult
from textual.app import WINDOWS
from textual.driver import Driver
from textual.geometry import Size, Region
# N.B. These classes would better be named TestApp/TestConsole/TestDriver/etc,
# but it makes pytest emit warning as it will try to collect them as classes containing test cases :-/
class AppTest(App):
def __init__(self, *, test_name: str, size: Size):
# Tests will log in "/tests/test.[test name].log":
log_path = Path(__file__).parent.parent / f"test.{test_name}.log"
super().__init__(
driver_class=DriverTest,
)
# Let's disable all features by default
self.features = frozenset()
# We need this so the "start buffeting"` is always sent for a screen refresh,
# whatever the environment:
# (we use it to slice the output into distinct full screens displays)
self._sync_available = True
self._size = size
self._console = ConsoleTest(width=size.width, height=size.height)
self._error_console = ConsoleTest(width=size.width, height=size.height)
def log_tree(self) -> None:
"""Handy shortcut when testing stuff"""
self.log(self.tree)
def compose(self) -> ComposeResult:
raise NotImplementedError(
"Create a subclass of TestApp and override its `compose()` method, rather than using TestApp directly"
)
def in_running_state(
self,
*,
time_mocking_ticks_granularity_fps: int = 60, # i.e. when moving forward by 1 second we'll do it though 60 ticks
waiting_duration_after_initialisation: float = 1,
waiting_duration_after_yield: float = 0,
) -> AsyncContextManager[ClockMock]:
async def run_app() -> None:
await self._process_messages()
@contextlib.asynccontextmanager
async def get_running_state_context_manager():
with mock_textual_timers(
ticks_granularity_fps=time_mocking_ticks_granularity_fps
) as clock_mock:
run_task = asyncio.create_task(run_app())
# We have to do this because `run_app()` is running in its own async task, and our test is going to
# run in this one - so the app must also be the active App in our current context:
self._set_active()
await clock_mock.advance_clock(waiting_duration_after_initialisation)
# make sure the App has entered its main loop at this stage:
assert self._driver is not None
await self.force_full_screen_update()
# And now it's time to pass the torch on to the test function!
# We provide the `move_clock_forward` function to it,
# so it can also do some time-based Textual stuff if it needs to:
yield clock_mock
await clock_mock.advance_clock(waiting_duration_after_yield)
# Make sure our screen is up-to-date before exiting the context manager,
# so tests using our `last_display_capture` for example can assert things on a fully refreshed screen:
await self.force_full_screen_update()
# End of simulated time: we just shut down ourselves:
assert not run_task.done()
await self.shutdown()
await run_task
return get_running_state_context_manager()
async def boot_and_shutdown(
self,
*,
waiting_duration_after_initialisation: float = 0.001,
waiting_duration_before_shutdown: float = 0,
):
"""Just a commodity shortcut for `async with app.in_running_state(): pass`, for simple cases"""
async with self.in_running_state(
waiting_duration_after_initialisation=waiting_duration_after_initialisation,
waiting_duration_after_yield=waiting_duration_before_shutdown,
):
pass
def get_char_at(self, x: int, y: int) -> str:
"""Get the character at the given cell or empty string
Args:
x (int): X position within the Layout
y (int): Y position within the Layout
Returns:
str: The character at the cell (x, y) within the Layout
"""
# N.B. Basically a copy-paste-and-slightly-adapt of `Compositor.get_style_at()`
try:
widget, region = self.get_widget_at(x, y)
except errors.NoWidget:
return ""
if widget not in self.screen._compositor.visible_widgets:
return ""
x -= region.x
y -= region.y
lines = widget.render_lines(Region(0, y, region.width, 1))
if not lines:
return ""
end = 0
for segment in lines[0]:
end += segment.cell_length
if x < end:
return segment.text[0]
return ""
async def force_full_screen_update(
self, *, repaint: bool = True, layout: bool = True
) -> None:
try:
screen = self.screen
except IndexError:
return # the app may not have a screen yet
# We artificially tell the Compositor that the whole area should be refreshed
screen._compositor._dirty_regions = {
Region(0, 0, screen.outer_size.width, screen.outer_size.height),
}
screen.refresh(repaint=repaint, layout=layout)
# We also have to make sure we have at least one dirty widget, or `screen._on_update()` will early return:
screen._dirty_widgets.add(screen)
screen._on_timer_update()
await let_asyncio_process_some_events()
def _handle_exception(self, error: Exception) -> None:
# In tests we want the errors to be raised, rather than printed to a Console
raise error
def run(self):
raise NotImplementedError(
"Use `async with my_test_app.in_running_state()` rather than `my_test_app.run()`"
)
@property
def active_app(self) -> App | None:
return active_app.get()
@property
def total_capture(self) -> str | None:
return self.console.file.getvalue()
@property
def last_display_capture(self) -> str | None:
total_capture = self.total_capture
if not total_capture:
return None
screen_captures = total_capture.split(SYNC_START)
for single_screen_capture in reversed(screen_captures):
if len(single_screen_capture) > 30:
# let's return the last occurrence of a screen that seem to be properly "fully-paint"
return single_screen_capture
return None
@property
def console(self) -> ConsoleTest:
return self._console
@console.setter
def console(self, console: Console) -> None:
"""This is a no-op, the console is always a TestConsole"""
return
@property
def error_console(self) -> ConsoleTest:
return self._error_console
@error_console.setter
def error_console(self, console: Console) -> None:
"""This is a no-op, the error console is always a TestConsole"""
return
class ConsoleTest(Console):
def __init__(self, *, width: int, height: int):
file = io.StringIO()
super().__init__(
color_system="256",
file=file,
width=width,
height=height,
force_terminal=False,
legacy_windows=False,
)
@property
def file(self) -> io.StringIO:
return cast(io.StringIO, self._file)
@property
def is_dumb_terminal(self) -> bool:
return False
class DriverTest(Driver):
def start_application_mode(self) -> None:
size = Size(self.console.size.width, self.console.size.height)
event = events.Resize(self._target, size, size)
asyncio.run_coroutine_threadsafe(
self._target.post_message(event),
loop=asyncio.get_running_loop(),
)
def disable_input(self) -> None:
pass
def stop_application_mode(self) -> None:
pass
# It seems that we have to give _way more_ time to `asyncio` on Windows in order to see our different awaiters
# properly triggered when we pause our own "move clock forward" loop.
# It could be caused by the fact that the time resolution for `asyncio` on this platform seems rather low:
# > The resolution of the monotonic clock on Windows is usually around 15.6 msec.
# > The best resolution is 0.5 msec.
# @link https://docs.python.org/3/library/asyncio-platforms.html:
ASYNCIO_EVENTS_PROCESSING_REQUIRED_PERIOD = 0.025 if WINDOWS else 0.005
async def let_asyncio_process_some_events() -> None:
await asyncio.sleep(ASYNCIO_EVENTS_PROCESSING_REQUIRED_PERIOD)
class ClockMock(_Clock):
# To avoid issues with floats we will store the current time as an integer internally.
# Tenths of microseconds should be a good enough granularity:
TIME_RESOLUTION = 10_000_000
def __init__(
self,
*,
ticks_granularity_fps: int = 60,
):
self._ticks_granularity_fps = ticks_granularity_fps
self._single_tick_duration = int(self.TIME_RESOLUTION / ticks_granularity_fps)
self._start_time: int = -1
self._current_time: int = -1
# For each call to our `sleep` method we will store an asyncio.Event
# and the time at which we should trigger it:
self._pending_sleep_events: dict[int, list[asyncio.Event]] = {}
def get_time_no_wait(self) -> float:
if self._current_time == -1:
self._start_clock()
return self._current_time / self.TIME_RESOLUTION
async def sleep(self, seconds: float) -> None:
event = asyncio.Event()
internal_waiting_duration = int(seconds * self.TIME_RESOLUTION)
target_event_monotonic_time = self._current_time + internal_waiting_duration
self._pending_sleep_events.setdefault(target_event_monotonic_time, []).append(
event
)
# Ok, let's wait for this Event
# (which can only be "unlocked" by calls to `advance_clock()`)
await event.wait()
async def advance_clock(self, seconds: float) -> None:
"""
Artificially advance the Textual clock forward.
Args:
seconds: for each second we will artificially tick `ticks_granularity_fps` times
"""
if self._current_time == -1:
self._start_clock()
ticks_count = ceil(seconds * self._ticks_granularity_fps)
activated_timers_count_total = 0 # useful when debugging this code :-)
for tick_counter in range(ticks_count):
self._current_time += self._single_tick_duration
activated_timers_count = self._check_sleep_timers_to_activate()
activated_timers_count_total += activated_timers_count
# Now that we likely unlocked some occurrences of `await sleep(duration)`,
# let's give an opportunity to asyncio-related stuff to happen:
if activated_timers_count:
await let_asyncio_process_some_events()
await let_asyncio_process_some_events()
def _start_clock(self) -> None:
# N.B. `start_time` is not actually used, but it is useful to have when we set breakpoints there :-)
self._start_time = self._current_time = int(monotonic() * self.TIME_RESOLUTION)
def _check_sleep_timers_to_activate(self) -> int:
activated_timers_count = 0
activated_events_times_to_clear: list[int] = []
for (monotonic_time, target_events) in self._pending_sleep_events.items():
if self._current_time < monotonic_time:
continue # not time for you yet, dear awaiter...
# Right, let's release these waiting events!
for event in target_events:
event.set()
activated_timers_count += len(target_events)
# ...and let's mark it for removal:
activated_events_times_to_clear.append(monotonic_time)
for event_time_to_clear in activated_events_times_to_clear:
del self._pending_sleep_events[event_time_to_clear]
return activated_timers_count
def mock_textual_timers(
*,
ticks_granularity_fps: int = 60,
) -> ContextManager[ClockMock]:
@contextlib.contextmanager
def mock_textual_timers_context_manager():
clock_mock = ClockMock(ticks_granularity_fps=ticks_granularity_fps)
with mock.patch("textual._clock._clock", new=clock_mock):
yield clock_mock
return mock_textual_timers_context_manager()