Seperate server and client handling logic into classes for devtools

This commit is contained in:
Darren Burns
2022-04-11 14:53:10 +01:00
parent 678c6c60a4
commit a72e347ed9
5 changed files with 233 additions and 244 deletions

View File

@@ -1,145 +1,17 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import base64
import json
import pickle
import sys import sys
import weakref
from asyncio import Queue, Task
from json import JSONDecodeError
from typing import cast
from textual.devtools.client import DEFAULT_PORT
from textual.devtools.renderables import DevtoolsLogMessage, DevtoolsInternalMessage
from aiohttp import WSMessage, WSMsgType, WSCloseCode
from aiohttp.web import run_app from aiohttp.web import run_app
from aiohttp.web_app import Application from aiohttp.web_app import Application
from aiohttp.web_request import Request from aiohttp.web_request import Request
from aiohttp.web_routedef import get from aiohttp.web_routedef import get
from aiohttp.web_ws import WebSocketResponse from aiohttp.web_ws import WebSocketResponse
from rich.console import Console
from rich.markup import escape
from textual.devtools.client import DEFAULT_PORT
from textual.devtools.service import DevtoolsService
DEFAULT_SIZE_CHANGE_POLL_DELAY_SECONDS = 2 DEFAULT_SIZE_CHANGE_POLL_DELAY_SECONDS = 2
QUEUEABLE_TYPES = {"client_log", "client_spillover"}
async def _enqueue_size_changes(
console: Console,
outgoing_queue: Queue[dict | None],
poll_delay: int,
shutdown_event: asyncio.Event,
) -> None:
"""Poll console dimensions, and add a `server_info` message to the Queue
any time a change occurs
Args:
console (Console): The Console instance to poll for size changes on
outgoing_queue (Queue): The Queue to add to when a size change occurs
poll_delay (int): Time between polls
shutdown_event (asyncio.Event): When set, this coroutine will stop polling
and will eventually return (after the current poll completes)
"""
current_width = console.width
current_height = console.height
while not shutdown_event.is_set():
width = console.width
height = console.height
dimensions_changed = width != current_width or height != current_height
if dimensions_changed:
await _enqueue_server_info(outgoing_queue, width, height)
current_width = width
current_height = height
try:
await asyncio.wait_for(shutdown_event.wait(), timeout=poll_delay)
except asyncio.TimeoutError:
pass
async def _enqueue_server_info(
outgoing_queue: Queue[dict | None], width: int, height: int
) -> None:
"""Add `server_info` message to the queue
Args:
outgoing_queue (Queue[dict | None]): The Queue to add the message to
width (int): The new width of the server Console
height (int): The new height of the server Console
"""
await outgoing_queue.put(
{
"type": "server_info",
"payload": {
"width": width,
"height": height,
},
}
)
async def _consume_incoming(
console: Console, incoming_queue: Queue[dict | None]
) -> None:
"""Consume messages from the incoming (client -> server) Queue, and print
the corresponding renderables to the console for each message.
Args:
console (Console): The Console instance to print to
incoming_queue (Queue[dict | None]): The Queue containing messages to process
"""
while True:
message_json = await incoming_queue.get()
if message_json is None:
incoming_queue.task_done()
break
type = message_json["type"]
if type == "client_log":
path = message_json["payload"]["path"]
line_number = message_json["payload"]["line_number"]
timestamp = message_json["payload"]["timestamp"]
encoded_segments = message_json["payload"]["encoded_segments"]
decoded_segments = base64.b64decode(encoded_segments)
segments = pickle.loads(decoded_segments)
console.print(
DevtoolsLogMessage(
segments=segments,
path=path,
line_number=line_number,
unix_timestamp=timestamp,
)
)
elif type == "client_spillover":
spillover = int(message_json["payload"]["spillover"])
info_renderable = DevtoolsInternalMessage(
f"Discarded {spillover} messages", level="warning"
)
console.print(info_renderable)
incoming_queue.task_done()
async def _consume_outgoing(
outgoing_queue: Queue[dict | None], websocket: WebSocketResponse
) -> None:
"""Consume messages from the outgoing (server -> client) Queue.
Args:
outgoing_queue (Queue[dict | None]): The queue to consume from
websocket (WebSocketResponse): The websocket to write to
"""
while True:
message_json = await outgoing_queue.get()
if message_json is None:
outgoing_queue.task_done()
break
type = message_json["type"]
if type == "server_info":
await websocket.send_json(message_json)
outgoing_queue.task_done()
async def websocket_handler(request: Request) -> WebSocketResponse: async def websocket_handler(request: Request) -> WebSocketResponse:
@@ -151,115 +23,19 @@ async def websocket_handler(request: Request) -> WebSocketResponse:
Returns: Returns:
WebSocketResponse: The websocket response WebSocketResponse: The websocket response
""" """
websocket = WebSocketResponse() service: DevtoolsService = request.app["service"]
await websocket.prepare(request) return await service.handle(request)
request.app["websockets"].add(websocket)
console = request.app["console"]
size_change_poll_delay = request.app["size_change_poll_delay_secs"]
shutdown_event: asyncio.Event = request.app["shutdown_event"]
outgoing_queue: Queue[dict | None] = request.app["outgoing_queue"]
incoming_queue: Queue[dict | None] = request.app["incoming_queue"]
size_change_task = asyncio.create_task(
_enqueue_size_changes(
console,
outgoing_queue,
poll_delay=size_change_poll_delay,
shutdown_event=shutdown_event,
)
)
consume_outgoing_task = asyncio.create_task(
_consume_outgoing(outgoing_queue, websocket)
)
consume_incoming_task = asyncio.create_task(
_consume_incoming(console, incoming_queue)
)
request.app["tasks"].update(
{
"consume_incoming_task": consume_incoming_task,
"consume_outgoing_task": consume_outgoing_task,
"size_change_task": size_change_task,
}
)
if request.remote:
console.print(
DevtoolsInternalMessage(f"Client '{escape(request.remote)}' connected")
)
await _enqueue_server_info(
outgoing_queue, width=console.width, height=console.height
)
try:
async for message in websocket:
message = cast(WSMessage, message)
if message.type == WSMsgType.TEXT:
try:
message_json = json.loads(message.data)
except JSONDecodeError:
console.print(escape(str(message.data)))
continue
type = message_json.get("type")
if not type:
continue
if type in QUEUEABLE_TYPES and not shutdown_event.is_set():
await incoming_queue.put(message_json)
elif message.type == WSMsgType.ERROR:
console.print(
DevtoolsInternalMessage("Websocket error occurred", level="error")
)
break
except Exception as error:
console.print(DevtoolsInternalMessage(str(error), level="error"))
finally:
request.app["websockets"].discard(websocket)
console.print()
if request.remote:
console.print(
DevtoolsInternalMessage(
f"Client '{escape(request.remote)}' disconnected"
)
)
return websocket
async def _on_shutdown(app: Application) -> None: async def _on_shutdown(app: Application) -> None:
"""aiohttp shutdown handler, called when the aiohttp server is stopped""" """aiohttp shutdown handler, called when the aiohttp server is stopped"""
tasks: dict[str, Task] = app["tasks"] service: DevtoolsService = app["service"]
await service.shutdown()
# Close the websockets to stop most writes to the incoming queue
for websocket in set(app["websockets"]):
await websocket.close(
code=WSCloseCode.GOING_AWAY, message="Shutting down server"
)
# This task needs to shut down first as it writes to the outgoing queue async def _on_startup(app: Application) -> None:
shutdown_event: asyncio.Event = app["shutdown_event"] service: DevtoolsService = app["service"]
shutdown_event.set() await service.start()
size_change_task = tasks.get("size_change_task")
if size_change_task:
await size_change_task
# Now stop the tasks which read from the queues
incoming_queue: Queue[dict | None] = app["incoming_queue"]
await incoming_queue.put(None)
outgoing_queue: Queue[dict | None] = app["outgoing_queue"]
await outgoing_queue.put(None)
consume_incoming_task = tasks.get("consume_incoming_task")
if consume_incoming_task:
await consume_incoming_task
consume_outgoing_task = tasks.get("consume_outgoing_task")
if consume_outgoing_task:
await consume_outgoing_task
def _run_devtools(port: int) -> None: def _run_devtools(port: int) -> None:
@@ -271,19 +47,20 @@ def _make_devtools_aiohttp_app(
size_change_poll_delay_secs: float = DEFAULT_SIZE_CHANGE_POLL_DELAY_SECONDS, size_change_poll_delay_secs: float = DEFAULT_SIZE_CHANGE_POLL_DELAY_SECONDS,
) -> Application: ) -> Application:
app = Application() app = Application()
app["size_change_poll_delay_secs"] = size_change_poll_delay_secs
app["shutdown_event"] = asyncio.Event() app.on_shutdown.append(_on_shutdown)
app["console"] = Console() app.on_startup.append(_on_startup)
app["incoming_queue"] = Queue()
app["outgoing_queue"] = Queue() app["service"] = DevtoolsService(
app["websockets"] = weakref.WeakSet() poll_delay_seconds=size_change_poll_delay_secs,
app["tasks"] = {} )
app.add_routes( app.add_routes(
[ [
get("/textual-devtools-websocket", websocket_handler), get("/textual-devtools-websocket", websocket_handler),
] ]
) )
app.on_shutdown.append(_on_shutdown)
return app return app

View File

@@ -1,5 +1,214 @@
"""Manages a running devtools instance""" """Manages a running devtools instance"""
from __future__ import annotations
import asyncio
import base64
import json
import pickle
from json import JSONDecodeError
from typing import cast
from aiohttp import WSMessage, WSMsgType
from aiohttp.abc import Request
from aiohttp.web_ws import WebSocketResponse
from rich.console import Console
from rich.markup import escape
from textual.devtools.renderables import DevtoolsLogMessage, DevtoolsInternalMessage
QUEUEABLE_TYPES = {"client_log", "client_spillover"}
class DevtoolsService: class DevtoolsService:
pass def __init__(self, poll_delay_seconds: float) -> None:
self.clients: list[ClientHandler] = []
self.poll_delay_seconds = poll_delay_seconds
self.shutdown_event = asyncio.Event()
self.console = Console()
async def start(self):
self.size_poll_task = asyncio.create_task(self._console_size_poller())
@property
def clients_connected(self):
return len(self.clients) > 0
async def _console_size_poller(self) -> None:
"""Poll console dimensions, and add a `server_info` message to the Queue
any time a change occurs. We only poll if there are clients connected,
and if we're not shutting down the server.
"""
current_width = self.console.width
current_height = self.console.height
while not self.shutdown_event.is_set():
width = self.console.width
height = self.console.height
dimensions_changed = width != current_width or height != current_height
if dimensions_changed:
await self._send_server_info_to_all()
current_width = width
current_height = height
try:
await asyncio.wait_for(
self.shutdown_event.wait(), timeout=self.poll_delay_seconds
)
except asyncio.TimeoutError:
pass
async def _send_server_info_to_all(self) -> None:
"""Add `server_info` message to the queues of every client"""
for client_handler in self.clients:
await self.send_server_info(client_handler)
async def send_server_info(self, client_handler: ClientHandler) -> None:
await client_handler.send_message(
{
"type": "server_info",
"payload": {
"width": self.console.width,
"height": self.console.height,
},
}
)
async def handle(self, request: Request) -> WebSocketResponse:
client = ClientHandler(request, service=self)
self.clients.append(client)
websocket = await client.start()
self.clients.remove(client)
return websocket
async def shutdown(self):
# Stop polling Console dimensions
self.shutdown_event.set()
await self.size_poll_task
# Close the websockets
for client in self.clients:
await client.close()
self.clients.clear()
class ClientHandler:
"""Handles a single client connection to the devtools.
A single DevtoolsService managers many ClientHandlers. A single ClientHandler
corresponds to a single running Textual application instance.
"""
def __init__(self, request: Request, service: DevtoolsService):
self.request = request
self.service = service
self.websocket = WebSocketResponse()
async def send_message(self, message: dict[str, object]) -> None:
await self.outgoing_queue.put(message)
async def _consume_outgoing(self) -> None:
"""Consume messages from the outgoing (server -> client) Queue."""
while True:
message_json = await self.outgoing_queue.get()
if message_json is None:
self.outgoing_queue.task_done()
break
type = message_json["type"]
if type == "server_info":
await self.websocket.send_json(message_json)
self.outgoing_queue.task_done()
async def _consume_incoming(self) -> None:
"""Consume messages from the incoming (client -> server) Queue, and print
the corresponding renderables to the console for each message.
"""
while True:
message_json = await self.incoming_queue.get()
if message_json is None:
self.incoming_queue.task_done()
break
type = message_json["type"]
if type == "client_log":
path = message_json["payload"]["path"]
line_number = message_json["payload"]["line_number"]
timestamp = message_json["payload"]["timestamp"]
encoded_segments = message_json["payload"]["encoded_segments"]
decoded_segments = base64.b64decode(encoded_segments)
segments = pickle.loads(decoded_segments)
self.service.console.print(
DevtoolsLogMessage(
segments=segments,
path=path,
line_number=line_number,
unix_timestamp=timestamp,
)
)
elif type == "client_spillover":
spillover = int(message_json["payload"]["spillover"])
info_renderable = DevtoolsInternalMessage(
f"Discarded {spillover} messages", level="warning"
)
self.service.console.print(info_renderable)
self.incoming_queue.task_done()
async def start(self) -> WebSocketResponse:
await self.websocket.prepare(self.request)
self.incoming_queue: asyncio.Queue[dict | None] = asyncio.Queue()
self.outgoing_queue: asyncio.Queue[dict | None] = asyncio.Queue()
self.outgoing_messages_task = asyncio.create_task(self._consume_outgoing())
self.incoming_messages_task = asyncio.create_task(self._consume_incoming())
self.service.console.print(
DevtoolsInternalMessage(f"Client '{escape(self.request.remote)}' connected")
)
try:
await self.service.send_server_info(client_handler=self)
async for message in self.websocket:
message = cast(WSMessage, message)
if message.type == WSMsgType.TEXT:
try:
message_json = json.loads(message.data)
except JSONDecodeError:
self.service.console.print(escape(str(message.data)))
continue
type = message_json.get("type")
if not type:
continue
if (
type in QUEUEABLE_TYPES
and not self.service.shutdown_event.is_set()
):
await self.incoming_queue.put(message_json)
elif message.type == WSMsgType.ERROR:
self.service.console.print(
DevtoolsInternalMessage(
"Websocket error occurred", level="error"
)
)
break
except Exception as error:
self.service.console.print(
DevtoolsInternalMessage(str(error), level="error")
)
finally:
self.service.console.print()
if self.request.remote:
self.service.console.print(
DevtoolsInternalMessage(
f"Client '{escape(self.request.remote)}' disconnected"
)
)
await self.close()
return self.websocket
async def close(self) -> None:
# Stop any writes to the websocket first
await self.outgoing_queue.put(None)
await self.outgoing_messages_task
# Now we can shut the socket down
await self.websocket.close()
# This task is independent of the websocket
await self.incoming_queue.put(None)
await self.incoming_messages_task

View File

@@ -2,6 +2,7 @@ import pytest
from textual.devtools.server import _make_devtools_aiohttp_app from textual.devtools.server import _make_devtools_aiohttp_app
from textual.devtools.client import DevtoolsClient from textual.devtools.client import DevtoolsClient
from textual.devtools.service import DevtoolsService
@pytest.fixture @pytest.fixture
@@ -10,7 +11,9 @@ async def server(aiohttp_server, unused_tcp_port):
size_change_poll_delay_secs=0.001, size_change_poll_delay_secs=0.001,
) )
server = await aiohttp_server(app, port=unused_tcp_port) server = await aiohttp_server(app, port=unused_tcp_port)
service: DevtoolsService = app["service"]
yield server yield server
await service.shutdown()
await server.close() await server.close()

View File

@@ -92,6 +92,6 @@ async def test_devtools_spillover_message(devtools):
async def test_devtools_console_size_change(server, devtools): async def test_devtools_console_size_change(server, devtools):
# Update the width of the console on the server-side # Update the width of the console on the server-side
server.app["console"].width = 124 server.app["service"].console.width = 124
# Wait for the client side to update the console on their end # Wait for the client side to update the console on their end
await wait_for_predicate(lambda: devtools.console.width == 124) await wait_for_predicate(lambda: devtools.console.width == 124)

View File

@@ -89,7 +89,7 @@ async def test_devtools_client_update_console_dimensions(devtools, server):
"""Sending new server info through websocket from server to client should (eventually) """Sending new server info through websocket from server to client should (eventually)
result in the dimensions of the devtools client console being updated to match. result in the dimensions of the devtools client console being updated to match.
""" """
server_to_client: WebSocketResponse = next(iter(server.app["websockets"])) server_to_client: WebSocketResponse = next(iter(server.app["service"].clients)).websocket
server_info = { server_info = {
"type": "server_info", "type": "server_info",
"payload": { "payload": {