Graceful shutdown on client-side of devtools

This commit is contained in:
Darren Burns
2022-04-08 11:15:51 +01:00
parent c567b9a47e
commit a5e4634e3e
3 changed files with 69 additions and 62 deletions

View File

@@ -6,7 +6,7 @@ import json
import pickle
import sys
import weakref
from asyncio import Queue
from asyncio import Queue, Task
from datetime import datetime, timezone
from json import JSONDecodeError
from pathlib import Path
@@ -121,7 +121,7 @@ async def _enqueue_size_changes(
outgoing_queue (Queue): The Queue to add to when a size change occurs
poll_delay (int): Time between polls
shutdown_event (asyncio.Event): When set, this coroutine will stop polling
and will eventually return (after the current poll)
and will eventually return (after the current poll completes)
"""
current_width = console.width
current_height = console.height
@@ -133,10 +133,10 @@ async def _enqueue_size_changes(
await _enqueue_server_info(outgoing_queue, width, height)
current_width = width
current_height = height
await asyncio.wait(
[shutdown_event.wait(), asyncio.sleep(poll_delay)],
return_when=asyncio.FIRST_COMPLETED,
)
try:
await asyncio.wait_for(shutdown_event.wait(), timeout=poll_delay)
except asyncio.TimeoutError:
pass
async def _enqueue_server_info(
@@ -173,6 +173,7 @@ async def _consume_incoming(
while True:
message_json = await incoming_queue.get()
if message_json is None:
incoming_queue.task_done()
break
type = message_json["type"]
@@ -212,10 +213,12 @@ async def _consume_outgoing(
while True:
message_json = await outgoing_queue.get()
if message_json is None:
outgoing_queue.task_done()
break
type = message_json["type"]
if type == "server_info":
await websocket.send_json(message_json)
outgoing_queue.task_done()
async def websocket_handler(request: Request) -> WebSocketResponse:
@@ -254,13 +257,13 @@ async def websocket_handler(request: Request) -> WebSocketResponse:
_consume_incoming(console, incoming_queue)
)
request.app["tasks"].extend(
(
consume_incoming_task,
consume_outgoing_task,
)
request.app["tasks"].update(
{
"consume_incoming_task": consume_incoming_task,
"consume_outgoing_task": consume_outgoing_task,
"size_change_task": size_change_task,
}
)
request.app["size_change_task"] = size_change_task
if request.remote:
console.print(
@@ -307,6 +310,7 @@ async def websocket_handler(request: Request) -> WebSocketResponse:
async def _on_shutdown(app: Application) -> None:
"""aiohttp shutdown handler, called when the aiohttp server is stopped"""
tasks: dict[str, Task] = app["tasks"]
# Close the websockets to stop most writes to the incoming queue
for websocket in set(app["websockets"]):
@@ -317,7 +321,7 @@ async def _on_shutdown(app: Application) -> None:
# This task needs to shut down first as it writes to the outgoing queue
shutdown_event: asyncio.Event = app["shutdown_event"]
shutdown_event.set()
size_change_task = app.get("size_change_task")
size_change_task = tasks.get("size_change_task")
if size_change_task:
await size_change_task
@@ -328,8 +332,13 @@ async def _on_shutdown(app: Application) -> None:
outgoing_queue: Queue[dict | None] = app["outgoing_queue"]
await outgoing_queue.put(None)
for task in app["tasks"]:
await task
consume_incoming_task = tasks.get("consume_incoming_task")
if consume_incoming_task:
await consume_incoming_task
consume_outgoing_task = tasks.get("consume_outgoing_task")
if consume_outgoing_task:
await consume_outgoing_task
def _run_devtools(port: int) -> None:
@@ -347,8 +356,7 @@ def _make_devtools_aiohttp_app(
app["incoming_queue"] = Queue()
app["outgoing_queue"] = Queue()
app["websockets"] = weakref.WeakSet()
app["tasks"] = []
app["size_change_task"] = None
app["tasks"] = {}
app.add_routes(
[
get("/textual-devtools-websocket", websocket_handler),

View File

@@ -6,7 +6,6 @@ import datetime
import json
import pickle
from asyncio import Queue, Task, QueueFull
from contextlib import suppress
from io import StringIO
from typing import Type, Any
@@ -44,6 +43,10 @@ class DevtoolsConnectionError(Exception):
pass
class ClientShutdown:
pass
class DevtoolsClient:
"""Client responsible for websocket communication with the devtools server.
Communicates using a simple JSON protocol.
@@ -79,7 +82,7 @@ class DevtoolsClient:
self.update_console_task: Task | None = None
self.console: DevtoolsConsole = DevtoolsConsole(file=StringIO())
self.websocket: ClientWebSocketResponse | None = None
self.log_queue: Queue | None = None
self.log_queue: Queue[str | Type[ClientShutdown]] | None = None
self.spillover: int = 0
async def connect(self) -> None:
@@ -90,7 +93,7 @@ class DevtoolsClient:
a connection to the server for any reason.
"""
self.session = aiohttp.ClientSession()
self.log_queue: Queue[str] = Queue(maxsize=LOG_QUEUE_MAXSIZE)
self.log_queue = Queue(maxsize=LOG_QUEUE_MAXSIZE)
try:
self.websocket = await self.session.ws_connect(
f"{self.url}/textual-devtools-websocket",
@@ -123,42 +126,42 @@ class DevtoolsClient:
"""
while True:
log = await log_queue.get()
if log is ClientShutdown:
log_queue.task_done()
break
await websocket.send_str(log)
log_queue.task_done()
self.log_queue_task = asyncio.create_task(send_queued_logs())
self.update_console_task = asyncio.create_task(update_console())
async def cancel_tasks(self) -> None:
"""Cancel client asyncio Tasks."""
await self._cancel_log_queue_processing()
await self._cancel_console_size_updates()
async def _cancel_log_queue_processing(self) -> None:
"""Cancel processing of the log queue, meaning that any messages a
async def _stop_log_queue_processing(self) -> None:
"""Schedule end of processing of the log queue, meaning that any messages a
user logs will be added to the queue, but not consumed and sent to
the server. Used for testing.
the server.
"""
if self.log_queue is not None:
await self.log_queue.put(ClientShutdown)
if self.log_queue_task:
self.log_queue_task.cancel()
with suppress(asyncio.CancelledError):
await self.log_queue_task
await self.log_queue_task
async def _cancel_console_size_updates(self) -> None:
"""Cancels the task which listens for incoming messages from the
server around changes in the server console size. Used for testing.
async def _stop_incoming_message_processing(self) -> None:
"""Schedule stop of the task which listens for incoming messages from the
server around changes in the server console size.
"""
if self.websocket:
await self.websocket.close()
if self.update_console_task:
self.update_console_task.cancel()
with suppress(asyncio.CancelledError):
await self.update_console_task
await self.update_console_task
if self.session:
await self.session.close()
async def disconnect(self) -> None:
"""Disconnect from the devtools server by cancelling tasks and
"""Disconnect from the devtools server by stopping tasks and
closing connections.
"""
await self.cancel_tasks()
await self._close_connections()
await self._stop_log_queue_processing()
await self._stop_incoming_message_processing()
@property
def is_connected(self) -> bool:
@@ -171,11 +174,6 @@ class DevtoolsClient:
return False
return not (self.session.closed or self.websocket.closed)
async def _close_connections(self) -> None:
"""Closes connect to the server"""
await self.websocket.close()
await self.session.close()
def log(self, *objects: Any, path: str = "", lineno: int = 0) -> None:
"""Queue a log to be sent to the devtools server for display.
@@ -201,20 +199,21 @@ class DevtoolsClient:
}
)
try:
self.log_queue.put_nowait(message)
if self.spillover > 0 and self.log_queue.qsize() < LOG_QUEUE_MAXSIZE:
# Tell the server how many messages we had to discard due
# to the log queue filling to capacity on the client.
spillover_message = json.dumps(
{
"type": "client_spillover",
"payload": {
"spillover": self.spillover,
},
}
)
self.log_queue.put_nowait(spillover_message)
self.spillover = 0
if self.log_queue:
self.log_queue.put_nowait(message)
if self.spillover > 0 and self.log_queue.qsize() < LOG_QUEUE_MAXSIZE:
# Tell the server how many messages we had to discard due
# to the log queue filling to capacity on the client.
spillover_message = json.dumps(
{
"type": "client_spillover",
"payload": {
"spillover": self.spillover,
},
}
)
self.log_queue.put_nowait(spillover_message)
self.spillover = 0
except QueueFull:
self.spillover += 1

View File

@@ -26,7 +26,7 @@ async def test_devtools_client_is_connected(devtools):
@time_machine.travel(datetime.fromtimestamp(TIMESTAMP))
async def test_devtools_log_places_encodes_and_queues_message(devtools):
await devtools._cancel_log_queue_processing()
await devtools._stop_log_queue_processing()
devtools.log("Hello, world!")
queued_log = await devtools.log_queue.get()
queued_log_json = json.loads(queued_log)
@@ -43,7 +43,7 @@ async def test_devtools_log_places_encodes_and_queues_message(devtools):
@time_machine.travel(datetime.fromtimestamp(TIMESTAMP))
async def test_devtools_log_places_encodes_and_queues_many_logs_as_string(devtools):
await devtools._cancel_log_queue_processing()
await devtools._stop_log_queue_processing()
devtools.log("hello", "world")
queued_log = await devtools.log_queue.get()
queued_log_json = json.loads(queued_log)
@@ -60,8 +60,8 @@ async def test_devtools_log_places_encodes_and_queues_many_logs_as_string(devtoo
async def test_devtools_log_spillover(devtools):
# Give the devtools an intentionally small max queue size
await devtools._stop_log_queue_processing()
devtools.log_queue = Queue(maxsize=2)
await devtools._cancel_log_queue_processing()
# Force spillover of 2
devtools.log(Panel("hello, world"))