This commit is contained in:
Will McGugan
2022-10-29 11:44:31 +01:00
parent 264b4fe733
commit 2afb00f5b3
5 changed files with 35 additions and 26 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
from asyncio import Task
from contextlib import asynccontextmanager
import inspect
import io
@@ -15,7 +16,6 @@ from pathlib import Path, PurePath
from time import perf_counter
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Generic,
@@ -654,7 +654,7 @@ class App(Generic[ReturnType], DOMNode):
await app._process_messages(ready_callback=on_app_ready, headless=headless)
# Launch the app in the "background"
asyncio.create_task(run_app(app))
app_task = asyncio.create_task(run_app(app))
# Wait until the app has performed all startup routines.
await app_ready_event.wait()
@@ -664,12 +664,13 @@ class App(Generic[ReturnType], DOMNode):
# Shutdown the app cleanly
await app._shutdown()
await app_task
async def run_async(
self,
*,
headless: bool = False,
auto_pilot: AutopilotCallbackType,
auto_pilot: AutopilotCallbackType | None = None,
) -> ReturnType | None:
"""Run the app asynchronously.
@@ -684,8 +685,11 @@ class App(Generic[ReturnType], DOMNode):
app = self
auto_pilot_task: Task | None = None
async def app_ready() -> None:
"""Called by the message loop when the app is ready."""
nonlocal auto_pilot_task
if auto_pilot is not None:
async def run_auto_pilot(pilot) -> None:
@@ -696,17 +700,25 @@ class App(Generic[ReturnType], DOMNode):
raise
pilot = Pilot(app)
asyncio.create_task(run_auto_pilot(pilot))
auto_pilot_task = asyncio.create_task(run_auto_pilot(pilot))
try:
await app._process_messages(
ready_callback=None if auto_pilot is None else app_ready,
headless=headless,
)
finally:
if auto_pilot_task is not None:
await auto_pilot_task
await app._shutdown()
await app._process_messages(ready_callback=app_ready, headless=headless)
await app._shutdown()
return app.return_value
def run(
self,
*,
headless: bool = False,
auto_pilot: AutopilotCallbackType,
auto_pilot: AutopilotCallbackType | None = None,
) -> ReturnType | None:
"""Run the app.
@@ -1287,8 +1299,10 @@ class App(Generic[ReturnType], DOMNode):
parent (Widget): The parent of the Widget.
widget (Widget): The Widget to start.
"""
widget._attach(parent)
widget._start_messages()
self.app._registry.add(widget)
def is_mounted(self, widget: Widget) -> bool:
"""Check if a widget is mounted.
@@ -1321,6 +1335,7 @@ class App(Generic[ReturnType], DOMNode):
async def _shutdown(self) -> None:
driver = self._driver
self._running = False
if driver is not None:
driver.disable_input()
await self._close_all()
@@ -1328,7 +1343,6 @@ class App(Generic[ReturnType], DOMNode):
await self._dispatch_message(events.UnMount(sender=self))
self._running = False
self._print_error_renderables()
if self.devtools is not None and self.devtools.is_connected:
await self._disconnect_devtools()

View File

@@ -155,7 +155,9 @@ class MessagePump(metaclass=MessagePumpMeta):
return self._pending_message
finally:
self._pending_message = None
message = await self._message_queue.get()
if message is None:
self._closed = True
raise MessagePumpClosed("The message pump is now closed")
@@ -289,7 +291,8 @@ class MessagePump(metaclass=MessagePumpMeta):
def _start_messages(self) -> None:
"""Start messages task."""
self._task = asyncio.create_task(self._process_messages())
if self.app._running:
self._task = asyncio.create_task(self._process_messages())
async def _process_messages(self) -> None:
self._running = True

View File

@@ -595,7 +595,6 @@ class Widget(DOMNode):
vertical=False, name="horizontal", thickness=self.scrollbar_size_horizontal
)
self._horizontal_scrollbar.display = False
self.app._start_widget(self, scroll_bar)
return scroll_bar

View File

@@ -59,7 +59,6 @@ def snap_compare(
"""
node = request.node
app = import_app(app_path)
compare.app = app
actual_screenshot = take_svg_screenshot(
app=app,
press=press,
@@ -69,7 +68,9 @@ def snap_compare(
if result is False:
# The split and join below is a mad hack, sorry...
node.stash[TEXTUAL_SNAPSHOT_SVG_KEY] = "\n".join(str(snapshot).splitlines()[1:-1])
node.stash[TEXTUAL_SNAPSHOT_SVG_KEY] = "\n".join(
str(snapshot).splitlines()[1:-1]
)
node.stash[TEXTUAL_ACTUAL_SVG_KEY] = actual_screenshot
node.stash[TEXTUAL_APP_KEY] = app
else:
@@ -85,6 +86,7 @@ class SvgSnapshotDiff:
"""Model representing a diff between current screenshot of an app,
and the snapshot on disk. This is ultimately intended to be used in
a Jinja2 template."""
snapshot: Optional[str]
actual: Optional[str]
test_name: str
@@ -119,7 +121,7 @@ def pytest_sessionfinish(
snapshot=str(snapshot_svg),
actual=str(actual_svg),
file_similarity=100
* difflib.SequenceMatcher(
* difflib.SequenceMatcher(
a=str(snapshot_svg), b=str(actual_svg)
).ratio(),
test_name=name,
@@ -176,7 +178,9 @@ def pytest_terminal_summary(
if diffs:
snapshot_report_location = config._textual_snapshot_html_report
console.rule("[b red]Textual Snapshot Report", style="red")
console.print(f"\n[black on red]{len(diffs)} mismatched snapshots[/]\n"
f"\n[b]View the [link=file://{snapshot_report_location}]failure report[/].\n")
console.print(
f"\n[black on red]{len(diffs)} mismatched snapshots[/]\n"
f"\n[b]View the [link=file://{snapshot_report_location}]failure report[/].\n"
)
console.print(f"[dim]{snapshot_report_location}\n")
console.rule(style="red")

View File

@@ -66,22 +66,11 @@ def test_input_and_focus(snap_compare):
]
assert snap_compare("docs/examples/widgets/input.py", press=press)
# Assert that the state of the Input is what we'd expect
# app: App = snap_compare.app
# input: Input = app.query_one(Input)
# assert input.value == "Darren"
# assert input.cursor_position == 6
# assert input.view_position == 0
def test_buttons_render(snap_compare):
# Testing button rendering. We press tab to focus the first button too.
assert snap_compare("docs/examples/widgets/button.py", press=["tab"])
# app = snap_compare.app
# button: Button = app.query_one(Button)
# assert app.focused is button
def test_datatable_render(snap_compare):
press = ["tab", "down", "down", "right", "up", "left"]