mirror of
https://github.com/Textualize/textual.git
synced 2025-10-17 02:38:12 +03:00
Seperate server and client handling logic into classes for devtools
This commit is contained in:
@@ -1,145 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import pickle
|
||||
import sys
|
||||
import weakref
|
||||
from asyncio import Queue, Task
|
||||
from json import JSONDecodeError
|
||||
from typing import cast
|
||||
|
||||
from textual.devtools.client import DEFAULT_PORT
|
||||
from textual.devtools.renderables import DevtoolsLogMessage, DevtoolsInternalMessage
|
||||
|
||||
|
||||
from aiohttp import WSMessage, WSMsgType, WSCloseCode
|
||||
from aiohttp.web import run_app
|
||||
from aiohttp.web_app import Application
|
||||
from aiohttp.web_request import Request
|
||||
from aiohttp.web_routedef import get
|
||||
from aiohttp.web_ws import WebSocketResponse
|
||||
from rich.console import Console
|
||||
from rich.markup import escape
|
||||
|
||||
from textual.devtools.client import DEFAULT_PORT
|
||||
from textual.devtools.service import DevtoolsService
|
||||
|
||||
DEFAULT_SIZE_CHANGE_POLL_DELAY_SECONDS = 2
|
||||
QUEUEABLE_TYPES = {"client_log", "client_spillover"}
|
||||
|
||||
|
||||
async def _enqueue_size_changes(
|
||||
console: Console,
|
||||
outgoing_queue: Queue[dict | None],
|
||||
poll_delay: int,
|
||||
shutdown_event: asyncio.Event,
|
||||
) -> None:
|
||||
"""Poll console dimensions, and add a `server_info` message to the Queue
|
||||
any time a change occurs
|
||||
|
||||
Args:
|
||||
console (Console): The Console instance to poll for size changes on
|
||||
outgoing_queue (Queue): The Queue to add to when a size change occurs
|
||||
poll_delay (int): Time between polls
|
||||
shutdown_event (asyncio.Event): When set, this coroutine will stop polling
|
||||
and will eventually return (after the current poll completes)
|
||||
"""
|
||||
current_width = console.width
|
||||
current_height = console.height
|
||||
while not shutdown_event.is_set():
|
||||
width = console.width
|
||||
height = console.height
|
||||
dimensions_changed = width != current_width or height != current_height
|
||||
if dimensions_changed:
|
||||
await _enqueue_server_info(outgoing_queue, width, height)
|
||||
current_width = width
|
||||
current_height = height
|
||||
try:
|
||||
await asyncio.wait_for(shutdown_event.wait(), timeout=poll_delay)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
|
||||
async def _enqueue_server_info(
|
||||
outgoing_queue: Queue[dict | None], width: int, height: int
|
||||
) -> None:
|
||||
"""Add `server_info` message to the queue
|
||||
|
||||
Args:
|
||||
outgoing_queue (Queue[dict | None]): The Queue to add the message to
|
||||
width (int): The new width of the server Console
|
||||
height (int): The new height of the server Console
|
||||
"""
|
||||
await outgoing_queue.put(
|
||||
{
|
||||
"type": "server_info",
|
||||
"payload": {
|
||||
"width": width,
|
||||
"height": height,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def _consume_incoming(
|
||||
console: Console, incoming_queue: Queue[dict | None]
|
||||
) -> None:
|
||||
"""Consume messages from the incoming (client -> server) Queue, and print
|
||||
the corresponding renderables to the console for each message.
|
||||
|
||||
Args:
|
||||
console (Console): The Console instance to print to
|
||||
incoming_queue (Queue[dict | None]): The Queue containing messages to process
|
||||
"""
|
||||
while True:
|
||||
message_json = await incoming_queue.get()
|
||||
if message_json is None:
|
||||
incoming_queue.task_done()
|
||||
break
|
||||
|
||||
type = message_json["type"]
|
||||
if type == "client_log":
|
||||
path = message_json["payload"]["path"]
|
||||
line_number = message_json["payload"]["line_number"]
|
||||
timestamp = message_json["payload"]["timestamp"]
|
||||
encoded_segments = message_json["payload"]["encoded_segments"]
|
||||
decoded_segments = base64.b64decode(encoded_segments)
|
||||
segments = pickle.loads(decoded_segments)
|
||||
console.print(
|
||||
DevtoolsLogMessage(
|
||||
segments=segments,
|
||||
path=path,
|
||||
line_number=line_number,
|
||||
unix_timestamp=timestamp,
|
||||
)
|
||||
)
|
||||
elif type == "client_spillover":
|
||||
spillover = int(message_json["payload"]["spillover"])
|
||||
info_renderable = DevtoolsInternalMessage(
|
||||
f"Discarded {spillover} messages", level="warning"
|
||||
)
|
||||
console.print(info_renderable)
|
||||
incoming_queue.task_done()
|
||||
|
||||
|
||||
async def _consume_outgoing(
|
||||
outgoing_queue: Queue[dict | None], websocket: WebSocketResponse
|
||||
) -> None:
|
||||
"""Consume messages from the outgoing (server -> client) Queue.
|
||||
|
||||
Args:
|
||||
outgoing_queue (Queue[dict | None]): The queue to consume from
|
||||
websocket (WebSocketResponse): The websocket to write to
|
||||
"""
|
||||
while True:
|
||||
message_json = await outgoing_queue.get()
|
||||
if message_json is None:
|
||||
outgoing_queue.task_done()
|
||||
break
|
||||
type = message_json["type"]
|
||||
if type == "server_info":
|
||||
await websocket.send_json(message_json)
|
||||
outgoing_queue.task_done()
|
||||
|
||||
|
||||
async def websocket_handler(request: Request) -> WebSocketResponse:
|
||||
@@ -151,115 +23,19 @@ async def websocket_handler(request: Request) -> WebSocketResponse:
|
||||
Returns:
|
||||
WebSocketResponse: The websocket response
|
||||
"""
|
||||
websocket = WebSocketResponse()
|
||||
await websocket.prepare(request)
|
||||
request.app["websockets"].add(websocket)
|
||||
|
||||
console = request.app["console"]
|
||||
|
||||
size_change_poll_delay = request.app["size_change_poll_delay_secs"]
|
||||
shutdown_event: asyncio.Event = request.app["shutdown_event"]
|
||||
|
||||
outgoing_queue: Queue[dict | None] = request.app["outgoing_queue"]
|
||||
incoming_queue: Queue[dict | None] = request.app["incoming_queue"]
|
||||
|
||||
size_change_task = asyncio.create_task(
|
||||
_enqueue_size_changes(
|
||||
console,
|
||||
outgoing_queue,
|
||||
poll_delay=size_change_poll_delay,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
)
|
||||
consume_outgoing_task = asyncio.create_task(
|
||||
_consume_outgoing(outgoing_queue, websocket)
|
||||
)
|
||||
consume_incoming_task = asyncio.create_task(
|
||||
_consume_incoming(console, incoming_queue)
|
||||
)
|
||||
|
||||
request.app["tasks"].update(
|
||||
{
|
||||
"consume_incoming_task": consume_incoming_task,
|
||||
"consume_outgoing_task": consume_outgoing_task,
|
||||
"size_change_task": size_change_task,
|
||||
}
|
||||
)
|
||||
|
||||
if request.remote:
|
||||
console.print(
|
||||
DevtoolsInternalMessage(f"Client '{escape(request.remote)}' connected")
|
||||
)
|
||||
|
||||
await _enqueue_server_info(
|
||||
outgoing_queue, width=console.width, height=console.height
|
||||
)
|
||||
try:
|
||||
async for message in websocket:
|
||||
message = cast(WSMessage, message)
|
||||
if message.type == WSMsgType.TEXT:
|
||||
try:
|
||||
message_json = json.loads(message.data)
|
||||
except JSONDecodeError:
|
||||
console.print(escape(str(message.data)))
|
||||
continue
|
||||
|
||||
type = message_json.get("type")
|
||||
if not type:
|
||||
continue
|
||||
if type in QUEUEABLE_TYPES and not shutdown_event.is_set():
|
||||
await incoming_queue.put(message_json)
|
||||
elif message.type == WSMsgType.ERROR:
|
||||
console.print(
|
||||
DevtoolsInternalMessage("Websocket error occurred", level="error")
|
||||
)
|
||||
break
|
||||
except Exception as error:
|
||||
console.print(DevtoolsInternalMessage(str(error), level="error"))
|
||||
finally:
|
||||
request.app["websockets"].discard(websocket)
|
||||
console.print()
|
||||
if request.remote:
|
||||
console.print(
|
||||
DevtoolsInternalMessage(
|
||||
f"Client '{escape(request.remote)}' disconnected"
|
||||
)
|
||||
)
|
||||
|
||||
return websocket
|
||||
service: DevtoolsService = request.app["service"]
|
||||
return await service.handle(request)
|
||||
|
||||
|
||||
async def _on_shutdown(app: Application) -> None:
|
||||
"""aiohttp shutdown handler, called when the aiohttp server is stopped"""
|
||||
tasks: dict[str, Task] = app["tasks"]
|
||||
service: DevtoolsService = app["service"]
|
||||
await service.shutdown()
|
||||
|
||||
# Close the websockets to stop most writes to the incoming queue
|
||||
for websocket in set(app["websockets"]):
|
||||
await websocket.close(
|
||||
code=WSCloseCode.GOING_AWAY, message="Shutting down server"
|
||||
)
|
||||
|
||||
# This task needs to shut down first as it writes to the outgoing queue
|
||||
shutdown_event: asyncio.Event = app["shutdown_event"]
|
||||
shutdown_event.set()
|
||||
size_change_task = tasks.get("size_change_task")
|
||||
if size_change_task:
|
||||
await size_change_task
|
||||
|
||||
# Now stop the tasks which read from the queues
|
||||
incoming_queue: Queue[dict | None] = app["incoming_queue"]
|
||||
await incoming_queue.put(None)
|
||||
|
||||
outgoing_queue: Queue[dict | None] = app["outgoing_queue"]
|
||||
await outgoing_queue.put(None)
|
||||
|
||||
consume_incoming_task = tasks.get("consume_incoming_task")
|
||||
if consume_incoming_task:
|
||||
await consume_incoming_task
|
||||
|
||||
consume_outgoing_task = tasks.get("consume_outgoing_task")
|
||||
if consume_outgoing_task:
|
||||
await consume_outgoing_task
|
||||
async def _on_startup(app: Application) -> None:
|
||||
service: DevtoolsService = app["service"]
|
||||
await service.start()
|
||||
|
||||
|
||||
def _run_devtools(port: int) -> None:
|
||||
@@ -271,19 +47,20 @@ def _make_devtools_aiohttp_app(
|
||||
size_change_poll_delay_secs: float = DEFAULT_SIZE_CHANGE_POLL_DELAY_SECONDS,
|
||||
) -> Application:
|
||||
app = Application()
|
||||
app["size_change_poll_delay_secs"] = size_change_poll_delay_secs
|
||||
app["shutdown_event"] = asyncio.Event()
|
||||
app["console"] = Console()
|
||||
app["incoming_queue"] = Queue()
|
||||
app["outgoing_queue"] = Queue()
|
||||
app["websockets"] = weakref.WeakSet()
|
||||
app["tasks"] = {}
|
||||
|
||||
app.on_shutdown.append(_on_shutdown)
|
||||
app.on_startup.append(_on_startup)
|
||||
|
||||
app["service"] = DevtoolsService(
|
||||
poll_delay_seconds=size_change_poll_delay_secs,
|
||||
)
|
||||
|
||||
app.add_routes(
|
||||
[
|
||||
get("/textual-devtools-websocket", websocket_handler),
|
||||
]
|
||||
)
|
||||
app.on_shutdown.append(_on_shutdown)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,214 @@
|
||||
"""Manages a running devtools instance"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import pickle
|
||||
from json import JSONDecodeError
|
||||
from typing import cast
|
||||
|
||||
from aiohttp import WSMessage, WSMsgType
|
||||
from aiohttp.abc import Request
|
||||
from aiohttp.web_ws import WebSocketResponse
|
||||
from rich.console import Console
|
||||
from rich.markup import escape
|
||||
|
||||
from textual.devtools.renderables import DevtoolsLogMessage, DevtoolsInternalMessage
|
||||
|
||||
QUEUEABLE_TYPES = {"client_log", "client_spillover"}
|
||||
|
||||
|
||||
class DevtoolsService:
|
||||
pass
|
||||
def __init__(self, poll_delay_seconds: float) -> None:
|
||||
self.clients: list[ClientHandler] = []
|
||||
self.poll_delay_seconds = poll_delay_seconds
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self.console = Console()
|
||||
|
||||
async def start(self):
|
||||
self.size_poll_task = asyncio.create_task(self._console_size_poller())
|
||||
|
||||
@property
|
||||
def clients_connected(self):
|
||||
return len(self.clients) > 0
|
||||
|
||||
async def _console_size_poller(self) -> None:
|
||||
"""Poll console dimensions, and add a `server_info` message to the Queue
|
||||
any time a change occurs. We only poll if there are clients connected,
|
||||
and if we're not shutting down the server.
|
||||
"""
|
||||
current_width = self.console.width
|
||||
current_height = self.console.height
|
||||
while not self.shutdown_event.is_set():
|
||||
width = self.console.width
|
||||
height = self.console.height
|
||||
dimensions_changed = width != current_width or height != current_height
|
||||
if dimensions_changed:
|
||||
await self._send_server_info_to_all()
|
||||
current_width = width
|
||||
current_height = height
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.shutdown_event.wait(), timeout=self.poll_delay_seconds
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
async def _send_server_info_to_all(self) -> None:
|
||||
"""Add `server_info` message to the queues of every client"""
|
||||
for client_handler in self.clients:
|
||||
await self.send_server_info(client_handler)
|
||||
|
||||
async def send_server_info(self, client_handler: ClientHandler) -> None:
|
||||
await client_handler.send_message(
|
||||
{
|
||||
"type": "server_info",
|
||||
"payload": {
|
||||
"width": self.console.width,
|
||||
"height": self.console.height,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
async def handle(self, request: Request) -> WebSocketResponse:
|
||||
client = ClientHandler(request, service=self)
|
||||
self.clients.append(client)
|
||||
websocket = await client.start()
|
||||
self.clients.remove(client)
|
||||
return websocket
|
||||
|
||||
async def shutdown(self):
|
||||
# Stop polling Console dimensions
|
||||
self.shutdown_event.set()
|
||||
await self.size_poll_task
|
||||
|
||||
# Close the websockets
|
||||
for client in self.clients:
|
||||
await client.close()
|
||||
self.clients.clear()
|
||||
|
||||
|
||||
class ClientHandler:
|
||||
"""Handles a single client connection to the devtools.
|
||||
A single DevtoolsService managers many ClientHandlers. A single ClientHandler
|
||||
corresponds to a single running Textual application instance.
|
||||
"""
|
||||
|
||||
def __init__(self, request: Request, service: DevtoolsService):
|
||||
self.request = request
|
||||
self.service = service
|
||||
self.websocket = WebSocketResponse()
|
||||
|
||||
async def send_message(self, message: dict[str, object]) -> None:
|
||||
await self.outgoing_queue.put(message)
|
||||
|
||||
async def _consume_outgoing(self) -> None:
|
||||
"""Consume messages from the outgoing (server -> client) Queue."""
|
||||
while True:
|
||||
message_json = await self.outgoing_queue.get()
|
||||
if message_json is None:
|
||||
self.outgoing_queue.task_done()
|
||||
break
|
||||
type = message_json["type"]
|
||||
if type == "server_info":
|
||||
await self.websocket.send_json(message_json)
|
||||
self.outgoing_queue.task_done()
|
||||
|
||||
async def _consume_incoming(self) -> None:
|
||||
"""Consume messages from the incoming (client -> server) Queue, and print
|
||||
the corresponding renderables to the console for each message.
|
||||
"""
|
||||
while True:
|
||||
message_json = await self.incoming_queue.get()
|
||||
if message_json is None:
|
||||
self.incoming_queue.task_done()
|
||||
break
|
||||
|
||||
type = message_json["type"]
|
||||
if type == "client_log":
|
||||
path = message_json["payload"]["path"]
|
||||
line_number = message_json["payload"]["line_number"]
|
||||
timestamp = message_json["payload"]["timestamp"]
|
||||
encoded_segments = message_json["payload"]["encoded_segments"]
|
||||
decoded_segments = base64.b64decode(encoded_segments)
|
||||
segments = pickle.loads(decoded_segments)
|
||||
self.service.console.print(
|
||||
DevtoolsLogMessage(
|
||||
segments=segments,
|
||||
path=path,
|
||||
line_number=line_number,
|
||||
unix_timestamp=timestamp,
|
||||
)
|
||||
)
|
||||
elif type == "client_spillover":
|
||||
spillover = int(message_json["payload"]["spillover"])
|
||||
info_renderable = DevtoolsInternalMessage(
|
||||
f"Discarded {spillover} messages", level="warning"
|
||||
)
|
||||
self.service.console.print(info_renderable)
|
||||
self.incoming_queue.task_done()
|
||||
|
||||
async def start(self) -> WebSocketResponse:
|
||||
await self.websocket.prepare(self.request)
|
||||
self.incoming_queue: asyncio.Queue[dict | None] = asyncio.Queue()
|
||||
self.outgoing_queue: asyncio.Queue[dict | None] = asyncio.Queue()
|
||||
self.outgoing_messages_task = asyncio.create_task(self._consume_outgoing())
|
||||
self.incoming_messages_task = asyncio.create_task(self._consume_incoming())
|
||||
|
||||
self.service.console.print(
|
||||
DevtoolsInternalMessage(f"Client '{escape(self.request.remote)}' connected")
|
||||
)
|
||||
try:
|
||||
await self.service.send_server_info(client_handler=self)
|
||||
async for message in self.websocket:
|
||||
message = cast(WSMessage, message)
|
||||
if message.type == WSMsgType.TEXT:
|
||||
try:
|
||||
message_json = json.loads(message.data)
|
||||
except JSONDecodeError:
|
||||
self.service.console.print(escape(str(message.data)))
|
||||
continue
|
||||
|
||||
type = message_json.get("type")
|
||||
if not type:
|
||||
continue
|
||||
if (
|
||||
type in QUEUEABLE_TYPES
|
||||
and not self.service.shutdown_event.is_set()
|
||||
):
|
||||
await self.incoming_queue.put(message_json)
|
||||
elif message.type == WSMsgType.ERROR:
|
||||
self.service.console.print(
|
||||
DevtoolsInternalMessage(
|
||||
"Websocket error occurred", level="error"
|
||||
)
|
||||
)
|
||||
break
|
||||
except Exception as error:
|
||||
self.service.console.print(
|
||||
DevtoolsInternalMessage(str(error), level="error")
|
||||
)
|
||||
finally:
|
||||
self.service.console.print()
|
||||
if self.request.remote:
|
||||
self.service.console.print(
|
||||
DevtoolsInternalMessage(
|
||||
f"Client '{escape(self.request.remote)}' disconnected"
|
||||
)
|
||||
)
|
||||
await self.close()
|
||||
|
||||
return self.websocket
|
||||
|
||||
async def close(self) -> None:
|
||||
# Stop any writes to the websocket first
|
||||
await self.outgoing_queue.put(None)
|
||||
await self.outgoing_messages_task
|
||||
|
||||
# Now we can shut the socket down
|
||||
await self.websocket.close()
|
||||
|
||||
# This task is independent of the websocket
|
||||
await self.incoming_queue.put(None)
|
||||
await self.incoming_messages_task
|
||||
|
||||
@@ -2,6 +2,7 @@ import pytest
|
||||
|
||||
from textual.devtools.server import _make_devtools_aiohttp_app
|
||||
from textual.devtools.client import DevtoolsClient
|
||||
from textual.devtools.service import DevtoolsService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -10,7 +11,9 @@ async def server(aiohttp_server, unused_tcp_port):
|
||||
size_change_poll_delay_secs=0.001,
|
||||
)
|
||||
server = await aiohttp_server(app, port=unused_tcp_port)
|
||||
service: DevtoolsService = app["service"]
|
||||
yield server
|
||||
await service.shutdown()
|
||||
await server.close()
|
||||
|
||||
|
||||
|
||||
@@ -92,6 +92,6 @@ async def test_devtools_spillover_message(devtools):
|
||||
|
||||
async def test_devtools_console_size_change(server, devtools):
|
||||
# Update the width of the console on the server-side
|
||||
server.app["console"].width = 124
|
||||
server.app["service"].console.width = 124
|
||||
# Wait for the client side to update the console on their end
|
||||
await wait_for_predicate(lambda: devtools.console.width == 124)
|
||||
|
||||
@@ -89,7 +89,7 @@ async def test_devtools_client_update_console_dimensions(devtools, server):
|
||||
"""Sending new server info through websocket from server to client should (eventually)
|
||||
result in the dimensions of the devtools client console being updated to match.
|
||||
"""
|
||||
server_to_client: WebSocketResponse = next(iter(server.app["websockets"]))
|
||||
server_to_client: WebSocketResponse = next(iter(server.app["service"].clients)).websocket
|
||||
server_info = {
|
||||
"type": "server_info",
|
||||
"payload": {
|
||||
|
||||
Reference in New Issue
Block a user