From 56488945609b8b29d3a4920ef426e9b3b0a14dfa Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 6 Apr 2022 18:00:32 +0100 Subject: [PATCH] Additional testing for devtools client/server --- src/textual/app.py | 18 ++++--- src/textual/devtools.py | 55 +++++++++++++-------- src/textual/devtools_client.py | 3 +- tests/conftest.py | 24 ++++++++++ tests/test_devtools.py | 88 ++++++++++++++++++++++++++++++++++ tests/test_devtools_client.py | 46 ++++++------------ tests/utilities/render.py | 33 ++++++++++++- 7 files changed, 207 insertions(+), 60 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_devtools.py diff --git a/src/textual/app.py b/src/textual/app.py index 1e180bf49..31d7899a8 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -30,7 +30,7 @@ from ._event_broker import extract_handler_actions, NoHandler from ._profile import timer from .binding import Bindings, NoBinding from .css.stylesheet import Stylesheet, StylesheetParseError, StylesheetError -from .devtools_client import DevtoolsClient +from .devtools_client import DevtoolsClient, DevtoolsConnectionError from .dom import DOMNode from .driver import Driver from .file_monitor import FileMonitor @@ -415,10 +415,15 @@ class App(DOMNode): log(f"driver={self.driver_class}") if os.getenv("TEXTUAL_DEVTOOLS") == "1": - await self.devtools.connect() - self.log_file.write(f"Connected to devtools ({self.devtools.url})\n") - self.log_file.flush() - + try: + await self.devtools.connect() + self.log_file.write(f"Connected to devtools ({self.devtools.url})\n") + self.log_file.flush() + except DevtoolsConnectionError: + self.log_file.write( + f"Couldn't connect to devtools ({self.devtools.url})\n" + ) + self.log_file.flush() try: if self.css_file is not None: self.stylesheet.read(self.css_file) @@ -456,7 +461,8 @@ class App(DOMNode): await self.animator.start() await super().process_messages() log("PROCESS END") - await self._disconnect_devtools() + if self.devtools.is_connected: + await self._disconnect_devtools() with timer("animator.stop()"): await self.animator.stop() with timer("self.close_all()"): diff --git a/src/textual/devtools.py b/src/textual/devtools.py index 39aaff51b..a08bf8574 100644 --- a/src/textual/devtools.py +++ b/src/textual/devtools.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import base64 import json @@ -20,12 +22,13 @@ from aiohttp.web_ws import WebSocketResponse from dateutil.tz import tz from rich.align import Align from rich.console import Console, ConsoleOptions, RenderResult +from rich.markup import escape from rich.rule import Rule from rich.segment import Segments, Segment from rich.table import Table DEFAULT_PORT = 8081 -SIZE_CHANGE_POLL_DELAY_SECONDS = 2 +DEFAULT_SIZE_CHANGE_POLL_DELAY_SECONDS = 2 QUEUEABLE_TYPES = {"client_log", "client_spillover"} @@ -82,7 +85,9 @@ class DevtoolsInternalMessage: yield Rule(self.message, style=level_to_style.get(self.level, "dim")) -async def enqueue_size_changes(console: Console, outgoing_queue: Queue): +async def _enqueue_size_changes( + console: Console, outgoing_queue: Queue, poll_delay: int +): current_width = console.width current_height = console.height while True: @@ -90,13 +95,13 @@ async def enqueue_size_changes(console: Console, outgoing_queue: Queue): height = console.height dimensions_changed = width != current_width or height != current_height if dimensions_changed: - await enqueue_server_info(outgoing_queue, width, height) + await _enqueue_server_info(outgoing_queue, width, height) current_width = width current_height = height - await asyncio.sleep(SIZE_CHANGE_POLL_DELAY_SECONDS) + await asyncio.sleep(poll_delay) -async def enqueue_server_info(outgoing_queue: Queue, width: int, height: int) -> None: +async def _enqueue_server_info(outgoing_queue: Queue, width: int, height: int) -> None: await outgoing_queue.put( { "type": "server_info", @@ -108,7 +113,7 @@ async def enqueue_server_info(outgoing_queue: Queue, width: int, height: int) -> ) -async def consume_incoming(console: Console, incoming_queue: Queue[dict]) -> None: +async def _consume_incoming(console: Console, incoming_queue: Queue[dict]) -> None: while True: message_json = await incoming_queue.get() type = message_json["type"] @@ -137,7 +142,7 @@ async def consume_incoming(console: Console, incoming_queue: Queue[dict]) -> Non incoming_queue.task_done() -async def consume_outgoing( +async def _consume_outgoing( outgoing_queue: Queue[dict], websocket: WebSocketResponse ) -> None: while True: @@ -153,21 +158,28 @@ async def websocket_handler(request: Request): request.app["websockets"].add(websocket) console = request.app["console"] + size_change_poll_delay = request.app["size_change_poll_delay_secs"] incoming_queue: Queue[dict] = Queue() outgoing_queue: Queue[dict] = Queue() request.app["tasks"].extend( ( - asyncio.create_task(consume_outgoing(outgoing_queue, websocket)), - asyncio.create_task(enqueue_size_changes(console, outgoing_queue)), - asyncio.create_task(consume_incoming(console, incoming_queue)), + asyncio.create_task(_consume_outgoing(outgoing_queue, websocket)), + asyncio.create_task( + _enqueue_size_changes( + console, outgoing_queue, poll_delay=size_change_poll_delay + ) + ), + asyncio.create_task(_consume_incoming(console, incoming_queue)), ) ) - console.print(DevtoolsInternalMessage(f"Client '{request.remote}' connected")) + console.print( + DevtoolsInternalMessage(f"Client '{escape(request.remote)}' connected") + ) - await enqueue_server_info( + await _enqueue_server_info( outgoing_queue, width=console.width, height=console.height ) try: @@ -177,7 +189,7 @@ async def websocket_handler(request: Request): try: message_json = json.loads(message.data) except JSONDecodeError: - console.print(f"{message.data}") + console.print(escape(str(message.data))) continue type = message_json.get("type") @@ -196,13 +208,13 @@ async def websocket_handler(request: Request): request.app["websockets"].discard(websocket) console.print() console.print( - DevtoolsInternalMessage(f"Client '{request.remote}' disconnected") + DevtoolsInternalMessage(f"Client '{escape(request.remote)}' disconnected") ) return websocket -async def on_shutdown(app: Application) -> None: +async def _on_shutdown(app: Application) -> None: for task in app["tasks"]: task.cancel() with suppress(CancelledError): @@ -214,13 +226,16 @@ async def on_shutdown(app: Application) -> None: ) -def run_devtools(port: int) -> None: - app = make_aiohttp_app() +def _run_devtools(port: int) -> None: + app = _make_devtools_aiohttp_app() run_app(app, port=port) -def make_aiohttp_app(): +def _make_devtools_aiohttp_app( + size_change_poll_delay_secs: float = DEFAULT_SIZE_CHANGE_POLL_DELAY_SECONDS, +): app = Application() + app["size_change_poll_delay_secs"] = size_change_poll_delay_secs app["console"] = Console() app["websockets"] = weakref.WeakSet() app["tasks"] = [] @@ -229,7 +244,7 @@ def make_aiohttp_app(): get("/textual-devtools-websocket", websocket_handler), ] ) - app.on_shutdown.append(on_shutdown) + app.on_shutdown.append(_on_shutdown) return app @@ -238,4 +253,4 @@ if __name__ == "__main__": port = int(sys.argv[1]) else: port = DEFAULT_PORT - run_devtools(port) + _run_devtools(port) diff --git a/src/textual/devtools_client.py b/src/textual/devtools_client.py index fcb9581c2..2deeea509 100644 --- a/src/textual/devtools_client.py +++ b/src/textual/devtools_client.py @@ -44,6 +44,7 @@ class DevtoolsConnectionError(Exception): class DevtoolsClient: def __init__(self, address: str = "127.0.0.1", port: int = DEFAULT_PORT): self.url: str = f"ws://{address}:{port}" + self.session: aiohttp.ClientSession | None = None self.log_queue_task: Task | None = None self.update_console_task: Task | None = None self.console: DevtoolsConsole = DevtoolsConsole(file=StringIO()) @@ -52,7 +53,7 @@ class DevtoolsClient: self.spillover: int = 0 async def connect(self) -> None: - self.session: aiohttp.ClientSession = aiohttp.ClientSession() + self.session = aiohttp.ClientSession() self.log_queue: Queue[str | Type[DetachDevtools]] = Queue( maxsize=LOG_QUEUE_MAXSIZE ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..c9eb6fa3d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,24 @@ +import pytest + +from textual.devtools import _make_devtools_aiohttp_app +from textual.devtools_client import DevtoolsClient + + +@pytest.fixture +async def server(aiohttp_server, unused_tcp_port): + app = _make_devtools_aiohttp_app( + size_change_poll_delay_secs=0.001, + ) + server = await aiohttp_server(app, port=unused_tcp_port) + yield server + await server.close() + + +@pytest.fixture +async def devtools(aiohttp_client, server): + client = await aiohttp_client(server) + devtools = DevtoolsClient(address=client.host, port=client.port) + await devtools.connect() + yield devtools + await devtools.disconnect() + await client.close() diff --git a/tests/test_devtools.py b/tests/test_devtools.py new file mode 100644 index 000000000..82f08839d --- /dev/null +++ b/tests/test_devtools.py @@ -0,0 +1,88 @@ +import asyncio + +import pytest +import time_machine +from rich.align import Align +from rich.console import Console +from rich.segment import Segment + +from tests.utilities.render import wait_for_predicate +from textual.devtools import DevtoolsLogMessage, DevtoolsInternalMessage + +TIMESTAMP = 1649166819 +WIDTH = 40 +# The string "Hello, world!" is encoded in the payload below +EXAMPLE_LOG = { + "type": "client_log", + "payload": { + "encoded_segments": "gASVQgAAAAAAAABdlCiMDHJpY2guc2VnbWVudJSMB1NlZ" + "21lbnSUk5SMDUhlbGxvLCB3b3JsZCGUTk6HlIGUaAOMAQqUTk6HlIGUZS4=", + "line_number": 123, + "path": "abc/hello.py", + "timestamp": TIMESTAMP, + }, +} + + +@pytest.fixture(scope="module") +def console(): + return Console(width=WIDTH) + + +@time_machine.travel(TIMESTAMP) +def test_log_message_render(console): + message = DevtoolsLogMessage( + [Segment("content")], + path="abc/hello.py", + line_number=123, + unix_timestamp=TIMESTAMP, + ) + table = next(iter(message.__rich_console__(console, console.options))) + + assert len(table.rows) == 1 + + columns = list(table.columns) + left_cells = list(columns[0].cells) + left = left_cells[0] + right_cells = list(columns[1].cells) + right: Align = right_cells[0] + + assert left == " [#888177]15:53:39 [dim]BST[/]" + assert right.align == "right" + assert "hello.py:123" in right.renderable + + +def test_internal_message_render(console): + message = DevtoolsInternalMessage("hello") + rule = next(iter(message.__rich_console__(console, console.options))) + assert rule.title == "hello" + assert rule.characters == "─" + + +async def test_devtools_valid_client_log(devtools): + await devtools.websocket.send_json(EXAMPLE_LOG) + assert devtools.is_connected + + +async def test_devtools_string_not_json_message(devtools): + await devtools.websocket.send_str("ABCDEFG") + assert devtools.is_connected + + +async def test_devtools_invalid_json_message(devtools): + await devtools.websocket.send_json({"invalid": "json"}) + assert devtools.is_connected + + +async def test_devtools_spillover_message(devtools): + await devtools.websocket.send_json( + {"type": "client_spillover", "payload": {"spillover": 123}} + ) + assert devtools.is_connected + + +async def test_devtools_console_size_change(server, devtools): + # Update the width of the console on the server-side + server.app["console"].width = 124 + # Wait for the client side to update the console on their end + await wait_for_predicate(lambda: devtools.console.width == 124) diff --git a/tests/test_devtools_client.py b/tests/test_devtools_client.py index 5b0f1759b..10871c53c 100644 --- a/tests/test_devtools_client.py +++ b/tests/test_devtools_client.py @@ -9,29 +9,12 @@ from aiohttp.web_ws import WebSocketResponse from rich.console import ConsoleDimensions from rich.panel import Panel -from textual.devtools import make_aiohttp_app +from tests.utilities.render import wait_for_predicate from textual.devtools_client import DevtoolsClient TIMESTAMP = 1649166819 -@pytest.fixture -async def server(aiohttp_server, unused_tcp_port): - server = await aiohttp_server(make_aiohttp_app(), port=unused_tcp_port) - yield server - await server.close() - - -@pytest.fixture -async def devtools(aiohttp_client, server): - client = await aiohttp_client(server) - devtools = DevtoolsClient(address=client.host, port=client.port) - await devtools.connect() - yield devtools - await devtools.disconnect() - await client.close() - - def test_devtools_client_initialize_defaults(): devtools = DevtoolsClient() assert devtools.url == "ws://127.0.0.1:8081" @@ -50,7 +33,7 @@ async def test_devtools_log_places_encodes_and_queues_message(devtools): assert queued_log_json == { "payload": { "encoded_segments": "gASVQgAAAAAAAABdlCiMDHJpY2guc2VnbWVudJSMB1NlZ" - "21lbnSUk5SMDUhlbGxvLCB3b3JsZCGUTk6HlIGUaAOMAQqUTk6HlIGUZS4=", + "21lbnSUk5SMDUhlbGxvLCB3b3JsZCGUTk6HlIGUaAOMAQqUTk6HlIGUZS4=", "line_number": 0, "path": "", "timestamp": TIMESTAMP, @@ -99,12 +82,17 @@ async def test_devtools_log_spillover(devtools): # Ensure we're informing the server of spillover rate-limiting spillover_message = await devtools.log_queue.get() - assert json.loads(spillover_message) == {"type": "client_spillover", "payload": {"spillover": 2}} + assert json.loads(spillover_message) == { + "type": "client_spillover", + "payload": {"spillover": 2}, + } async def test_devtools_client_update_console_dimensions(devtools, server): - server_websocket: WebSocketResponse = next(iter(server.app["websockets"])) - # Send new server information from the server via the websocket + """Sending new server info through websocket from server to client should (eventually) + result in the dimensions of the devtools client console being updated to match. + """ + server_to_client: WebSocketResponse = next(iter(server.app["websockets"])) server_info = { "type": "server_info", "payload": { @@ -112,13 +100,7 @@ async def test_devtools_client_update_console_dimensions(devtools, server): "height": 456, }, } - await server_websocket.send_json(server_info) - timer = 0 - poll_period = .1 - while True: - if timer > 3: - pytest.fail("The devtools client dimensions did not update") - if devtools.console.size == ConsoleDimensions(123, 456): - break - await asyncio.sleep(.1) - timer += poll_period + await server_to_client.send_json(server_info) + await wait_for_predicate( + lambda: devtools.console.size == ConsoleDimensions(123, 456) + ) diff --git a/tests/utilities/render.py b/tests/utilities/render.py index 2a951f1f0..f1511e53c 100644 --- a/tests/utilities/render.py +++ b/tests/utilities/render.py @@ -1,9 +1,11 @@ +import asyncio import io import re +from typing import Callable +import pytest from rich.console import Console, RenderableType - re_link_ids = re.compile(r"id=[\d\.\-]*?;.*?\x1b") @@ -23,3 +25,32 @@ def render(renderable: RenderableType, no_wrap: bool = False) -> str: console.print(renderable, no_wrap=no_wrap, end="") output = replace_link_ids(capture.get()) return output + + +async def wait_for_predicate( + predicate: Callable[[], bool], + timeout_secs: float = 2, + poll_delay_secs: float = 0.001, +) -> None: + """Wait for the given predicate to become True by evaluating it every `poll_delay_secs` + seconds. Fail the pytest test if the predicate does not become True after `timeout_secs` + seconds. + + Args: + predicate (Callable[[], bool]): The predicate function which will be called repeatedly. + timeout_secs (float): If the predicate doesn't evaluate to True after this number of + seconds, the test will fail. + poll_delay_secs (float): The number of seconds to wait between each call to the + predicate function. + """ + time_taken = 0 + while True: + result = predicate() + if result: + return + await asyncio.sleep(poll_delay_secs) + time_taken += poll_delay_secs + if time_taken > timeout_secs: + pytest.fail( + f"Predicate {predicate} did not return True after {timeout_secs} seconds." + )