This commit is contained in:
Will McGugan
2022-10-29 11:44:31 +01:00
parent 264b4fe733
commit 2afb00f5b3
5 changed files with 35 additions and 26 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from asyncio import Task
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import inspect import inspect
import io import io
@@ -15,7 +16,6 @@ from pathlib import Path, PurePath
from time import perf_counter from time import perf_counter
from typing import ( from typing import (
Any, Any,
Awaitable,
Callable, Callable,
Coroutine, Coroutine,
Generic, Generic,
@@ -654,7 +654,7 @@ class App(Generic[ReturnType], DOMNode):
await app._process_messages(ready_callback=on_app_ready, headless=headless) await app._process_messages(ready_callback=on_app_ready, headless=headless)
# Launch the app in the "background" # Launch the app in the "background"
asyncio.create_task(run_app(app)) app_task = asyncio.create_task(run_app(app))
# Wait until the app has performed all startup routines. # Wait until the app has performed all startup routines.
await app_ready_event.wait() await app_ready_event.wait()
@@ -664,12 +664,13 @@ class App(Generic[ReturnType], DOMNode):
# Shutdown the app cleanly # Shutdown the app cleanly
await app._shutdown() await app._shutdown()
await app_task
async def run_async( async def run_async(
self, self,
*, *,
headless: bool = False, headless: bool = False,
auto_pilot: AutopilotCallbackType, auto_pilot: AutopilotCallbackType | None = None,
) -> ReturnType | None: ) -> ReturnType | None:
"""Run the app asynchronously. """Run the app asynchronously.
@@ -684,8 +685,11 @@ class App(Generic[ReturnType], DOMNode):
app = self app = self
auto_pilot_task: Task | None = None
async def app_ready() -> None: async def app_ready() -> None:
"""Called by the message loop when the app is ready.""" """Called by the message loop when the app is ready."""
nonlocal auto_pilot_task
if auto_pilot is not None: if auto_pilot is not None:
async def run_auto_pilot(pilot) -> None: async def run_auto_pilot(pilot) -> None:
@@ -696,17 +700,25 @@ class App(Generic[ReturnType], DOMNode):
raise raise
pilot = Pilot(app) pilot = Pilot(app)
asyncio.create_task(run_auto_pilot(pilot)) auto_pilot_task = asyncio.create_task(run_auto_pilot(pilot))
try:
await app._process_messages(
ready_callback=None if auto_pilot is None else app_ready,
headless=headless,
)
finally:
if auto_pilot_task is not None:
await auto_pilot_task
await app._shutdown()
await app._process_messages(ready_callback=app_ready, headless=headless)
await app._shutdown()
return app.return_value return app.return_value
def run( def run(
self, self,
*, *,
headless: bool = False, headless: bool = False,
auto_pilot: AutopilotCallbackType, auto_pilot: AutopilotCallbackType | None = None,
) -> ReturnType | None: ) -> ReturnType | None:
"""Run the app. """Run the app.
@@ -1287,8 +1299,10 @@ class App(Generic[ReturnType], DOMNode):
parent (Widget): The parent of the Widget. parent (Widget): The parent of the Widget.
widget (Widget): The Widget to start. widget (Widget): The Widget to start.
""" """
widget._attach(parent) widget._attach(parent)
widget._start_messages() widget._start_messages()
self.app._registry.add(widget)
def is_mounted(self, widget: Widget) -> bool: def is_mounted(self, widget: Widget) -> bool:
"""Check if a widget is mounted. """Check if a widget is mounted.
@@ -1321,6 +1335,7 @@ class App(Generic[ReturnType], DOMNode):
async def _shutdown(self) -> None: async def _shutdown(self) -> None:
driver = self._driver driver = self._driver
self._running = False
if driver is not None: if driver is not None:
driver.disable_input() driver.disable_input()
await self._close_all() await self._close_all()
@@ -1328,7 +1343,6 @@ class App(Generic[ReturnType], DOMNode):
await self._dispatch_message(events.UnMount(sender=self)) await self._dispatch_message(events.UnMount(sender=self))
self._running = False
self._print_error_renderables() self._print_error_renderables()
if self.devtools is not None and self.devtools.is_connected: if self.devtools is not None and self.devtools.is_connected:
await self._disconnect_devtools() await self._disconnect_devtools()

View File

@@ -155,7 +155,9 @@ class MessagePump(metaclass=MessagePumpMeta):
return self._pending_message return self._pending_message
finally: finally:
self._pending_message = None self._pending_message = None
message = await self._message_queue.get() message = await self._message_queue.get()
if message is None: if message is None:
self._closed = True self._closed = True
raise MessagePumpClosed("The message pump is now closed") raise MessagePumpClosed("The message pump is now closed")
@@ -289,7 +291,8 @@ class MessagePump(metaclass=MessagePumpMeta):
def _start_messages(self) -> None: def _start_messages(self) -> None:
"""Start messages task.""" """Start messages task."""
self._task = asyncio.create_task(self._process_messages()) if self.app._running:
self._task = asyncio.create_task(self._process_messages())
async def _process_messages(self) -> None: async def _process_messages(self) -> None:
self._running = True self._running = True

View File

@@ -595,7 +595,6 @@ class Widget(DOMNode):
vertical=False, name="horizontal", thickness=self.scrollbar_size_horizontal vertical=False, name="horizontal", thickness=self.scrollbar_size_horizontal
) )
self._horizontal_scrollbar.display = False self._horizontal_scrollbar.display = False
self.app._start_widget(self, scroll_bar) self.app._start_widget(self, scroll_bar)
return scroll_bar return scroll_bar

View File

@@ -59,7 +59,6 @@ def snap_compare(
""" """
node = request.node node = request.node
app = import_app(app_path) app = import_app(app_path)
compare.app = app
actual_screenshot = take_svg_screenshot( actual_screenshot = take_svg_screenshot(
app=app, app=app,
press=press, press=press,
@@ -69,7 +68,9 @@ def snap_compare(
if result is False: if result is False:
# The split and join below is a mad hack, sorry... # The split and join below is a mad hack, sorry...
node.stash[TEXTUAL_SNAPSHOT_SVG_KEY] = "\n".join(str(snapshot).splitlines()[1:-1]) node.stash[TEXTUAL_SNAPSHOT_SVG_KEY] = "\n".join(
str(snapshot).splitlines()[1:-1]
)
node.stash[TEXTUAL_ACTUAL_SVG_KEY] = actual_screenshot node.stash[TEXTUAL_ACTUAL_SVG_KEY] = actual_screenshot
node.stash[TEXTUAL_APP_KEY] = app node.stash[TEXTUAL_APP_KEY] = app
else: else:
@@ -85,6 +86,7 @@ class SvgSnapshotDiff:
"""Model representing a diff between current screenshot of an app, """Model representing a diff between current screenshot of an app,
and the snapshot on disk. This is ultimately intended to be used in and the snapshot on disk. This is ultimately intended to be used in
a Jinja2 template.""" a Jinja2 template."""
snapshot: Optional[str] snapshot: Optional[str]
actual: Optional[str] actual: Optional[str]
test_name: str test_name: str
@@ -119,7 +121,7 @@ def pytest_sessionfinish(
snapshot=str(snapshot_svg), snapshot=str(snapshot_svg),
actual=str(actual_svg), actual=str(actual_svg),
file_similarity=100 file_similarity=100
* difflib.SequenceMatcher( * difflib.SequenceMatcher(
a=str(snapshot_svg), b=str(actual_svg) a=str(snapshot_svg), b=str(actual_svg)
).ratio(), ).ratio(),
test_name=name, test_name=name,
@@ -176,7 +178,9 @@ def pytest_terminal_summary(
if diffs: if diffs:
snapshot_report_location = config._textual_snapshot_html_report snapshot_report_location = config._textual_snapshot_html_report
console.rule("[b red]Textual Snapshot Report", style="red") console.rule("[b red]Textual Snapshot Report", style="red")
console.print(f"\n[black on red]{len(diffs)} mismatched snapshots[/]\n" console.print(
f"\n[b]View the [link=file://{snapshot_report_location}]failure report[/].\n") f"\n[black on red]{len(diffs)} mismatched snapshots[/]\n"
f"\n[b]View the [link=file://{snapshot_report_location}]failure report[/].\n"
)
console.print(f"[dim]{snapshot_report_location}\n") console.print(f"[dim]{snapshot_report_location}\n")
console.rule(style="red") console.rule(style="red")

View File

@@ -66,22 +66,11 @@ def test_input_and_focus(snap_compare):
] ]
assert snap_compare("docs/examples/widgets/input.py", press=press) assert snap_compare("docs/examples/widgets/input.py", press=press)
# Assert that the state of the Input is what we'd expect
# app: App = snap_compare.app
# input: Input = app.query_one(Input)
# assert input.value == "Darren"
# assert input.cursor_position == 6
# assert input.view_position == 0
def test_buttons_render(snap_compare): def test_buttons_render(snap_compare):
# Testing button rendering. We press tab to focus the first button too. # Testing button rendering. We press tab to focus the first button too.
assert snap_compare("docs/examples/widgets/button.py", press=["tab"]) assert snap_compare("docs/examples/widgets/button.py", press=["tab"])
# app = snap_compare.app
# button: Button = app.query_one(Button)
# assert app.focused is button
def test_datatable_render(snap_compare): def test_datatable_render(snap_compare):
press = ["tab", "down", "down", "right", "up", "left"] press = ["tab", "down", "down", "right", "up", "left"]