Additional testing for devtools client/server

This commit is contained in:
Darren Burns
2022-04-06 18:00:32 +01:00
parent 1067be927f
commit 5648894560
7 changed files with 207 additions and 60 deletions

View File

@@ -30,7 +30,7 @@ from ._event_broker import extract_handler_actions, NoHandler
from ._profile import timer
from .binding import Bindings, NoBinding
from .css.stylesheet import Stylesheet, StylesheetParseError, StylesheetError
from .devtools_client import DevtoolsClient
from .devtools_client import DevtoolsClient, DevtoolsConnectionError
from .dom import DOMNode
from .driver import Driver
from .file_monitor import FileMonitor
@@ -415,10 +415,15 @@ class App(DOMNode):
log(f"driver={self.driver_class}")
if os.getenv("TEXTUAL_DEVTOOLS") == "1":
try:
await self.devtools.connect()
self.log_file.write(f"Connected to devtools ({self.devtools.url})\n")
self.log_file.flush()
except DevtoolsConnectionError:
self.log_file.write(
f"Couldn't connect to devtools ({self.devtools.url})\n"
)
self.log_file.flush()
try:
if self.css_file is not None:
self.stylesheet.read(self.css_file)
@@ -456,6 +461,7 @@ class App(DOMNode):
await self.animator.start()
await super().process_messages()
log("PROCESS END")
if self.devtools.is_connected:
await self._disconnect_devtools()
with timer("animator.stop()"):
await self.animator.stop()

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import base64
import json
@@ -20,12 +22,13 @@ from aiohttp.web_ws import WebSocketResponse
from dateutil.tz import tz
from rich.align import Align
from rich.console import Console, ConsoleOptions, RenderResult
from rich.markup import escape
from rich.rule import Rule
from rich.segment import Segments, Segment
from rich.table import Table
DEFAULT_PORT = 8081
SIZE_CHANGE_POLL_DELAY_SECONDS = 2
DEFAULT_SIZE_CHANGE_POLL_DELAY_SECONDS = 2
QUEUEABLE_TYPES = {"client_log", "client_spillover"}
@@ -82,7 +85,9 @@ class DevtoolsInternalMessage:
yield Rule(self.message, style=level_to_style.get(self.level, "dim"))
async def enqueue_size_changes(console: Console, outgoing_queue: Queue):
async def _enqueue_size_changes(
console: Console, outgoing_queue: Queue, poll_delay: int
):
current_width = console.width
current_height = console.height
while True:
@@ -90,13 +95,13 @@ async def enqueue_size_changes(console: Console, outgoing_queue: Queue):
height = console.height
dimensions_changed = width != current_width or height != current_height
if dimensions_changed:
await enqueue_server_info(outgoing_queue, width, height)
await _enqueue_server_info(outgoing_queue, width, height)
current_width = width
current_height = height
await asyncio.sleep(SIZE_CHANGE_POLL_DELAY_SECONDS)
await asyncio.sleep(poll_delay)
async def enqueue_server_info(outgoing_queue: Queue, width: int, height: int) -> None:
async def _enqueue_server_info(outgoing_queue: Queue, width: int, height: int) -> None:
await outgoing_queue.put(
{
"type": "server_info",
@@ -108,7 +113,7 @@ async def enqueue_server_info(outgoing_queue: Queue, width: int, height: int) ->
)
async def consume_incoming(console: Console, incoming_queue: Queue[dict]) -> None:
async def _consume_incoming(console: Console, incoming_queue: Queue[dict]) -> None:
while True:
message_json = await incoming_queue.get()
type = message_json["type"]
@@ -137,7 +142,7 @@ async def consume_incoming(console: Console, incoming_queue: Queue[dict]) -> Non
incoming_queue.task_done()
async def consume_outgoing(
async def _consume_outgoing(
outgoing_queue: Queue[dict], websocket: WebSocketResponse
) -> None:
while True:
@@ -153,21 +158,28 @@ async def websocket_handler(request: Request):
request.app["websockets"].add(websocket)
console = request.app["console"]
size_change_poll_delay = request.app["size_change_poll_delay_secs"]
incoming_queue: Queue[dict] = Queue()
outgoing_queue: Queue[dict] = Queue()
request.app["tasks"].extend(
(
asyncio.create_task(consume_outgoing(outgoing_queue, websocket)),
asyncio.create_task(enqueue_size_changes(console, outgoing_queue)),
asyncio.create_task(consume_incoming(console, incoming_queue)),
asyncio.create_task(_consume_outgoing(outgoing_queue, websocket)),
asyncio.create_task(
_enqueue_size_changes(
console, outgoing_queue, poll_delay=size_change_poll_delay
)
),
asyncio.create_task(_consume_incoming(console, incoming_queue)),
)
)
console.print(DevtoolsInternalMessage(f"Client '{request.remote}' connected"))
console.print(
DevtoolsInternalMessage(f"Client '{escape(request.remote)}' connected")
)
await enqueue_server_info(
await _enqueue_server_info(
outgoing_queue, width=console.width, height=console.height
)
try:
@@ -177,7 +189,7 @@ async def websocket_handler(request: Request):
try:
message_json = json.loads(message.data)
except JSONDecodeError:
console.print(f"{message.data}")
console.print(escape(str(message.data)))
continue
type = message_json.get("type")
@@ -196,13 +208,13 @@ async def websocket_handler(request: Request):
request.app["websockets"].discard(websocket)
console.print()
console.print(
DevtoolsInternalMessage(f"Client '{request.remote}' disconnected")
DevtoolsInternalMessage(f"Client '{escape(request.remote)}' disconnected")
)
return websocket
async def on_shutdown(app: Application) -> None:
async def _on_shutdown(app: Application) -> None:
for task in app["tasks"]:
task.cancel()
with suppress(CancelledError):
@@ -214,13 +226,16 @@ async def on_shutdown(app: Application) -> None:
)
def run_devtools(port: int) -> None:
app = make_aiohttp_app()
def _run_devtools(port: int) -> None:
app = _make_devtools_aiohttp_app()
run_app(app, port=port)
def make_aiohttp_app():
def _make_devtools_aiohttp_app(
size_change_poll_delay_secs: float = DEFAULT_SIZE_CHANGE_POLL_DELAY_SECONDS,
):
app = Application()
app["size_change_poll_delay_secs"] = size_change_poll_delay_secs
app["console"] = Console()
app["websockets"] = weakref.WeakSet()
app["tasks"] = []
@@ -229,7 +244,7 @@ def make_aiohttp_app():
get("/textual-devtools-websocket", websocket_handler),
]
)
app.on_shutdown.append(on_shutdown)
app.on_shutdown.append(_on_shutdown)
return app
@@ -238,4 +253,4 @@ if __name__ == "__main__":
port = int(sys.argv[1])
else:
port = DEFAULT_PORT
run_devtools(port)
_run_devtools(port)

View File

@@ -44,6 +44,7 @@ class DevtoolsConnectionError(Exception):
class DevtoolsClient:
def __init__(self, address: str = "127.0.0.1", port: int = DEFAULT_PORT):
self.url: str = f"ws://{address}:{port}"
self.session: aiohttp.ClientSession | None = None
self.log_queue_task: Task | None = None
self.update_console_task: Task | None = None
self.console: DevtoolsConsole = DevtoolsConsole(file=StringIO())
@@ -52,7 +53,7 @@ class DevtoolsClient:
self.spillover: int = 0
async def connect(self) -> None:
self.session: aiohttp.ClientSession = aiohttp.ClientSession()
self.session = aiohttp.ClientSession()
self.log_queue: Queue[str | Type[DetachDevtools]] = Queue(
maxsize=LOG_QUEUE_MAXSIZE
)

24
tests/conftest.py Normal file
View File

@@ -0,0 +1,24 @@
import pytest
from textual.devtools import _make_devtools_aiohttp_app
from textual.devtools_client import DevtoolsClient
@pytest.fixture
async def server(aiohttp_server, unused_tcp_port):
app = _make_devtools_aiohttp_app(
size_change_poll_delay_secs=0.001,
)
server = await aiohttp_server(app, port=unused_tcp_port)
yield server
await server.close()
@pytest.fixture
async def devtools(aiohttp_client, server):
client = await aiohttp_client(server)
devtools = DevtoolsClient(address=client.host, port=client.port)
await devtools.connect()
yield devtools
await devtools.disconnect()
await client.close()

88
tests/test_devtools.py Normal file
View File

@@ -0,0 +1,88 @@
import asyncio
import pytest
import time_machine
from rich.align import Align
from rich.console import Console
from rich.segment import Segment
from tests.utilities.render import wait_for_predicate
from textual.devtools import DevtoolsLogMessage, DevtoolsInternalMessage
TIMESTAMP = 1649166819
WIDTH = 40
# The string "Hello, world!" is encoded in the payload below
EXAMPLE_LOG = {
"type": "client_log",
"payload": {
"encoded_segments": "gASVQgAAAAAAAABdlCiMDHJpY2guc2VnbWVudJSMB1NlZ"
"21lbnSUk5SMDUhlbGxvLCB3b3JsZCGUTk6HlIGUaAOMAQqUTk6HlIGUZS4=",
"line_number": 123,
"path": "abc/hello.py",
"timestamp": TIMESTAMP,
},
}
@pytest.fixture(scope="module")
def console():
return Console(width=WIDTH)
@time_machine.travel(TIMESTAMP)
def test_log_message_render(console):
message = DevtoolsLogMessage(
[Segment("content")],
path="abc/hello.py",
line_number=123,
unix_timestamp=TIMESTAMP,
)
table = next(iter(message.__rich_console__(console, console.options)))
assert len(table.rows) == 1
columns = list(table.columns)
left_cells = list(columns[0].cells)
left = left_cells[0]
right_cells = list(columns[1].cells)
right: Align = right_cells[0]
assert left == " [#888177]15:53:39 [dim]BST[/]"
assert right.align == "right"
assert "hello.py:123" in right.renderable
def test_internal_message_render(console):
message = DevtoolsInternalMessage("hello")
rule = next(iter(message.__rich_console__(console, console.options)))
assert rule.title == "hello"
assert rule.characters == ""
async def test_devtools_valid_client_log(devtools):
await devtools.websocket.send_json(EXAMPLE_LOG)
assert devtools.is_connected
async def test_devtools_string_not_json_message(devtools):
await devtools.websocket.send_str("ABCDEFG")
assert devtools.is_connected
async def test_devtools_invalid_json_message(devtools):
await devtools.websocket.send_json({"invalid": "json"})
assert devtools.is_connected
async def test_devtools_spillover_message(devtools):
await devtools.websocket.send_json(
{"type": "client_spillover", "payload": {"spillover": 123}}
)
assert devtools.is_connected
async def test_devtools_console_size_change(server, devtools):
# Update the width of the console on the server-side
server.app["console"].width = 124
# Wait for the client side to update the console on their end
await wait_for_predicate(lambda: devtools.console.width == 124)

View File

@@ -9,29 +9,12 @@ from aiohttp.web_ws import WebSocketResponse
from rich.console import ConsoleDimensions
from rich.panel import Panel
from textual.devtools import make_aiohttp_app
from tests.utilities.render import wait_for_predicate
from textual.devtools_client import DevtoolsClient
TIMESTAMP = 1649166819
@pytest.fixture
async def server(aiohttp_server, unused_tcp_port):
server = await aiohttp_server(make_aiohttp_app(), port=unused_tcp_port)
yield server
await server.close()
@pytest.fixture
async def devtools(aiohttp_client, server):
client = await aiohttp_client(server)
devtools = DevtoolsClient(address=client.host, port=client.port)
await devtools.connect()
yield devtools
await devtools.disconnect()
await client.close()
def test_devtools_client_initialize_defaults():
devtools = DevtoolsClient()
assert devtools.url == "ws://127.0.0.1:8081"
@@ -99,12 +82,17 @@ async def test_devtools_log_spillover(devtools):
# Ensure we're informing the server of spillover rate-limiting
spillover_message = await devtools.log_queue.get()
assert json.loads(spillover_message) == {"type": "client_spillover", "payload": {"spillover": 2}}
assert json.loads(spillover_message) == {
"type": "client_spillover",
"payload": {"spillover": 2},
}
async def test_devtools_client_update_console_dimensions(devtools, server):
server_websocket: WebSocketResponse = next(iter(server.app["websockets"]))
# Send new server information from the server via the websocket
"""Sending new server info through websocket from server to client should (eventually)
result in the dimensions of the devtools client console being updated to match.
"""
server_to_client: WebSocketResponse = next(iter(server.app["websockets"]))
server_info = {
"type": "server_info",
"payload": {
@@ -112,13 +100,7 @@ async def test_devtools_client_update_console_dimensions(devtools, server):
"height": 456,
},
}
await server_websocket.send_json(server_info)
timer = 0
poll_period = .1
while True:
if timer > 3:
pytest.fail("The devtools client dimensions did not update")
if devtools.console.size == ConsoleDimensions(123, 456):
break
await asyncio.sleep(.1)
timer += poll_period
await server_to_client.send_json(server_info)
await wait_for_predicate(
lambda: devtools.console.size == ConsoleDimensions(123, 456)
)

View File

@@ -1,9 +1,11 @@
import asyncio
import io
import re
from typing import Callable
import pytest
from rich.console import Console, RenderableType
re_link_ids = re.compile(r"id=[\d\.\-]*?;.*?\x1b")
@@ -23,3 +25,32 @@ def render(renderable: RenderableType, no_wrap: bool = False) -> str:
console.print(renderable, no_wrap=no_wrap, end="")
output = replace_link_ids(capture.get())
return output
async def wait_for_predicate(
predicate: Callable[[], bool],
timeout_secs: float = 2,
poll_delay_secs: float = 0.001,
) -> None:
"""Wait for the given predicate to become True by evaluating it every `poll_delay_secs`
seconds. Fail the pytest test if the predicate does not become True after `timeout_secs`
seconds.
Args:
predicate (Callable[[], bool]): The predicate function which will be called repeatedly.
timeout_secs (float): If the predicate doesn't evaluate to True after this number of
seconds, the test will fail.
poll_delay_secs (float): The number of seconds to wait between each call to the
predicate function.
"""
time_taken = 0
while True:
result = predicate()
if result:
return
await asyncio.sleep(poll_delay_secs)
time_taken += poll_delay_secs
if time_taken > timeout_secs:
pytest.fail(
f"Predicate {predicate} did not return True after {timeout_secs} seconds."
)