mirror of
https://github.com/Textualize/textual.git
synced 2025-10-17 02:38:12 +03:00
Graceful shutdown on client-side of devtools
This commit is contained in:
@@ -6,7 +6,7 @@ import json
|
||||
import pickle
|
||||
import sys
|
||||
import weakref
|
||||
from asyncio import Queue
|
||||
from asyncio import Queue, Task
|
||||
from datetime import datetime, timezone
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
@@ -121,7 +121,7 @@ async def _enqueue_size_changes(
|
||||
outgoing_queue (Queue): The Queue to add to when a size change occurs
|
||||
poll_delay (int): Time between polls
|
||||
shutdown_event (asyncio.Event): When set, this coroutine will stop polling
|
||||
and will eventually return (after the current poll)
|
||||
and will eventually return (after the current poll completes)
|
||||
"""
|
||||
current_width = console.width
|
||||
current_height = console.height
|
||||
@@ -133,10 +133,10 @@ async def _enqueue_size_changes(
|
||||
await _enqueue_server_info(outgoing_queue, width, height)
|
||||
current_width = width
|
||||
current_height = height
|
||||
await asyncio.wait(
|
||||
[shutdown_event.wait(), asyncio.sleep(poll_delay)],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(shutdown_event.wait(), timeout=poll_delay)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
|
||||
async def _enqueue_server_info(
|
||||
@@ -173,6 +173,7 @@ async def _consume_incoming(
|
||||
while True:
|
||||
message_json = await incoming_queue.get()
|
||||
if message_json is None:
|
||||
incoming_queue.task_done()
|
||||
break
|
||||
|
||||
type = message_json["type"]
|
||||
@@ -212,10 +213,12 @@ async def _consume_outgoing(
|
||||
while True:
|
||||
message_json = await outgoing_queue.get()
|
||||
if message_json is None:
|
||||
outgoing_queue.task_done()
|
||||
break
|
||||
type = message_json["type"]
|
||||
if type == "server_info":
|
||||
await websocket.send_json(message_json)
|
||||
outgoing_queue.task_done()
|
||||
|
||||
|
||||
async def websocket_handler(request: Request) -> WebSocketResponse:
|
||||
@@ -254,13 +257,13 @@ async def websocket_handler(request: Request) -> WebSocketResponse:
|
||||
_consume_incoming(console, incoming_queue)
|
||||
)
|
||||
|
||||
request.app["tasks"].extend(
|
||||
(
|
||||
consume_incoming_task,
|
||||
consume_outgoing_task,
|
||||
)
|
||||
request.app["tasks"].update(
|
||||
{
|
||||
"consume_incoming_task": consume_incoming_task,
|
||||
"consume_outgoing_task": consume_outgoing_task,
|
||||
"size_change_task": size_change_task,
|
||||
}
|
||||
)
|
||||
request.app["size_change_task"] = size_change_task
|
||||
|
||||
if request.remote:
|
||||
console.print(
|
||||
@@ -307,6 +310,7 @@ async def websocket_handler(request: Request) -> WebSocketResponse:
|
||||
|
||||
async def _on_shutdown(app: Application) -> None:
|
||||
"""aiohttp shutdown handler, called when the aiohttp server is stopped"""
|
||||
tasks: dict[str, Task] = app["tasks"]
|
||||
|
||||
# Close the websockets to stop most writes to the incoming queue
|
||||
for websocket in set(app["websockets"]):
|
||||
@@ -317,7 +321,7 @@ async def _on_shutdown(app: Application) -> None:
|
||||
# This task needs to shut down first as it writes to the outgoing queue
|
||||
shutdown_event: asyncio.Event = app["shutdown_event"]
|
||||
shutdown_event.set()
|
||||
size_change_task = app.get("size_change_task")
|
||||
size_change_task = tasks.get("size_change_task")
|
||||
if size_change_task:
|
||||
await size_change_task
|
||||
|
||||
@@ -328,8 +332,13 @@ async def _on_shutdown(app: Application) -> None:
|
||||
outgoing_queue: Queue[dict | None] = app["outgoing_queue"]
|
||||
await outgoing_queue.put(None)
|
||||
|
||||
for task in app["tasks"]:
|
||||
await task
|
||||
consume_incoming_task = tasks.get("consume_incoming_task")
|
||||
if consume_incoming_task:
|
||||
await consume_incoming_task
|
||||
|
||||
consume_outgoing_task = tasks.get("consume_outgoing_task")
|
||||
if consume_outgoing_task:
|
||||
await consume_outgoing_task
|
||||
|
||||
|
||||
def _run_devtools(port: int) -> None:
|
||||
@@ -347,8 +356,7 @@ def _make_devtools_aiohttp_app(
|
||||
app["incoming_queue"] = Queue()
|
||||
app["outgoing_queue"] = Queue()
|
||||
app["websockets"] = weakref.WeakSet()
|
||||
app["tasks"] = []
|
||||
app["size_change_task"] = None
|
||||
app["tasks"] = {}
|
||||
app.add_routes(
|
||||
[
|
||||
get("/textual-devtools-websocket", websocket_handler),
|
||||
|
||||
@@ -6,7 +6,6 @@ import datetime
|
||||
import json
|
||||
import pickle
|
||||
from asyncio import Queue, Task, QueueFull
|
||||
from contextlib import suppress
|
||||
from io import StringIO
|
||||
from typing import Type, Any
|
||||
|
||||
@@ -44,6 +43,10 @@ class DevtoolsConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ClientShutdown:
|
||||
pass
|
||||
|
||||
|
||||
class DevtoolsClient:
|
||||
"""Client responsible for websocket communication with the devtools server.
|
||||
Communicates using a simple JSON protocol.
|
||||
@@ -79,7 +82,7 @@ class DevtoolsClient:
|
||||
self.update_console_task: Task | None = None
|
||||
self.console: DevtoolsConsole = DevtoolsConsole(file=StringIO())
|
||||
self.websocket: ClientWebSocketResponse | None = None
|
||||
self.log_queue: Queue | None = None
|
||||
self.log_queue: Queue[str | Type[ClientShutdown]] | None = None
|
||||
self.spillover: int = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
@@ -90,7 +93,7 @@ class DevtoolsClient:
|
||||
a connection to the server for any reason.
|
||||
"""
|
||||
self.session = aiohttp.ClientSession()
|
||||
self.log_queue: Queue[str] = Queue(maxsize=LOG_QUEUE_MAXSIZE)
|
||||
self.log_queue = Queue(maxsize=LOG_QUEUE_MAXSIZE)
|
||||
try:
|
||||
self.websocket = await self.session.ws_connect(
|
||||
f"{self.url}/textual-devtools-websocket",
|
||||
@@ -123,42 +126,42 @@ class DevtoolsClient:
|
||||
"""
|
||||
while True:
|
||||
log = await log_queue.get()
|
||||
if log is ClientShutdown:
|
||||
log_queue.task_done()
|
||||
break
|
||||
await websocket.send_str(log)
|
||||
log_queue.task_done()
|
||||
|
||||
self.log_queue_task = asyncio.create_task(send_queued_logs())
|
||||
self.update_console_task = asyncio.create_task(update_console())
|
||||
|
||||
async def cancel_tasks(self) -> None:
|
||||
"""Cancel client asyncio Tasks."""
|
||||
await self._cancel_log_queue_processing()
|
||||
await self._cancel_console_size_updates()
|
||||
|
||||
async def _cancel_log_queue_processing(self) -> None:
|
||||
"""Cancel processing of the log queue, meaning that any messages a
|
||||
async def _stop_log_queue_processing(self) -> None:
|
||||
"""Schedule end of processing of the log queue, meaning that any messages a
|
||||
user logs will be added to the queue, but not consumed and sent to
|
||||
the server. Used for testing.
|
||||
the server.
|
||||
"""
|
||||
if self.log_queue is not None:
|
||||
await self.log_queue.put(ClientShutdown)
|
||||
if self.log_queue_task:
|
||||
self.log_queue_task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await self.log_queue_task
|
||||
await self.log_queue_task
|
||||
|
||||
async def _cancel_console_size_updates(self) -> None:
|
||||
"""Cancels the task which listens for incoming messages from the
|
||||
server around changes in the server console size. Used for testing.
|
||||
async def _stop_incoming_message_processing(self) -> None:
|
||||
"""Schedule stop of the task which listens for incoming messages from the
|
||||
server around changes in the server console size.
|
||||
"""
|
||||
if self.websocket:
|
||||
await self.websocket.close()
|
||||
if self.update_console_task:
|
||||
self.update_console_task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await self.update_console_task
|
||||
await self.update_console_task
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the devtools server by cancelling tasks and
|
||||
"""Disconnect from the devtools server by stopping tasks and
|
||||
closing connections.
|
||||
"""
|
||||
await self.cancel_tasks()
|
||||
await self._close_connections()
|
||||
await self._stop_log_queue_processing()
|
||||
await self._stop_incoming_message_processing()
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
@@ -171,11 +174,6 @@ class DevtoolsClient:
|
||||
return False
|
||||
return not (self.session.closed or self.websocket.closed)
|
||||
|
||||
async def _close_connections(self) -> None:
|
||||
"""Closes connect to the server"""
|
||||
await self.websocket.close()
|
||||
await self.session.close()
|
||||
|
||||
def log(self, *objects: Any, path: str = "", lineno: int = 0) -> None:
|
||||
"""Queue a log to be sent to the devtools server for display.
|
||||
|
||||
@@ -201,20 +199,21 @@ class DevtoolsClient:
|
||||
}
|
||||
)
|
||||
try:
|
||||
self.log_queue.put_nowait(message)
|
||||
if self.spillover > 0 and self.log_queue.qsize() < LOG_QUEUE_MAXSIZE:
|
||||
# Tell the server how many messages we had to discard due
|
||||
# to the log queue filling to capacity on the client.
|
||||
spillover_message = json.dumps(
|
||||
{
|
||||
"type": "client_spillover",
|
||||
"payload": {
|
||||
"spillover": self.spillover,
|
||||
},
|
||||
}
|
||||
)
|
||||
self.log_queue.put_nowait(spillover_message)
|
||||
self.spillover = 0
|
||||
if self.log_queue:
|
||||
self.log_queue.put_nowait(message)
|
||||
if self.spillover > 0 and self.log_queue.qsize() < LOG_QUEUE_MAXSIZE:
|
||||
# Tell the server how many messages we had to discard due
|
||||
# to the log queue filling to capacity on the client.
|
||||
spillover_message = json.dumps(
|
||||
{
|
||||
"type": "client_spillover",
|
||||
"payload": {
|
||||
"spillover": self.spillover,
|
||||
},
|
||||
}
|
||||
)
|
||||
self.log_queue.put_nowait(spillover_message)
|
||||
self.spillover = 0
|
||||
except QueueFull:
|
||||
self.spillover += 1
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ async def test_devtools_client_is_connected(devtools):
|
||||
|
||||
@time_machine.travel(datetime.fromtimestamp(TIMESTAMP))
|
||||
async def test_devtools_log_places_encodes_and_queues_message(devtools):
|
||||
await devtools._cancel_log_queue_processing()
|
||||
await devtools._stop_log_queue_processing()
|
||||
devtools.log("Hello, world!")
|
||||
queued_log = await devtools.log_queue.get()
|
||||
queued_log_json = json.loads(queued_log)
|
||||
@@ -43,7 +43,7 @@ async def test_devtools_log_places_encodes_and_queues_message(devtools):
|
||||
|
||||
@time_machine.travel(datetime.fromtimestamp(TIMESTAMP))
|
||||
async def test_devtools_log_places_encodes_and_queues_many_logs_as_string(devtools):
|
||||
await devtools._cancel_log_queue_processing()
|
||||
await devtools._stop_log_queue_processing()
|
||||
devtools.log("hello", "world")
|
||||
queued_log = await devtools.log_queue.get()
|
||||
queued_log_json = json.loads(queued_log)
|
||||
@@ -60,8 +60,8 @@ async def test_devtools_log_places_encodes_and_queues_many_logs_as_string(devtoo
|
||||
|
||||
async def test_devtools_log_spillover(devtools):
|
||||
# Give the devtools an intentionally small max queue size
|
||||
await devtools._stop_log_queue_processing()
|
||||
devtools.log_queue = Queue(maxsize=2)
|
||||
await devtools._cancel_log_queue_processing()
|
||||
|
||||
# Force spillover of 2
|
||||
devtools.log(Panel("hello, world"))
|
||||
|
||||
Reference in New Issue
Block a user