Refactor shutdown procedure

This commit is contained in:
Will McGugan
2021-06-20 11:50:44 +01:00
parent ab14b766d3
commit 0dd46641e4
12 changed files with 215 additions and 330 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Any, Coroutine, Awaitable, NamedTuple
import asyncio
from asyncio import Event, PriorityQueue, Task, QueueEmpty
from asyncio import Event, Queue, Task, QueueEmpty
import logging
@@ -14,50 +14,21 @@ from ._types import MessageHandler
log = logging.getLogger("rich")
class MessageQueueItem(NamedTuple):
priority: int
message: Message
def __lt__(self, other: object) -> bool:
other_priority = other.priority if isinstance(other, MessageQueueItem) else 0
return self.priority < other_priority
def __le__(self, other: object) -> bool:
other_priority = other.priority if isinstance(other, MessageQueueItem) else 0
return self.priority <= other_priority
def __gt__(self, other: object) -> bool:
other_priority = other.priority if isinstance(other, MessageQueueItem) else 0
return self.priority > other_priority
def __ge__(self, other: object) -> bool:
other_priority = other.priority if isinstance(other, MessageQueueItem) else 0
return self.priority >= other_priority
def __eq__(self, other: object) -> bool:
other_priority = other.priority if isinstance(other, MessageQueueItem) else 0
return self.priority == other_priority
def __ne__(self, other: object) -> bool:
other_priority = other.priority if isinstance(other, MessageQueueItem) else 0
return self.priority != other_priority
class MessagePumpClosed(Exception):
pass
class MessagePump:
def __init__(self, queue_size: int = 10, parent: MessagePump | None = None) -> None:
self._message_queue: PriorityQueue[MessageQueueItem | None] = PriorityQueue(
queue_size
)
self._message_queue: Queue[Message | None] = Queue()
self._parent = parent
self._closing: bool = False
self._closed: bool = False
self._disabled_messages: set[type[Message]] = set()
self._pending_message: MessageQueueItem | None = None
self._pending_message: Message | None = None
self._task: Task | None = None
self._child_tasks: set[Task] = set()
self._queue_empty_event = Event()
@property
def task(self) -> Task:
@@ -78,26 +49,26 @@ class MessagePump:
"""Enable processing of messages types."""
self._disabled_messages.difference_update(messages)
async def get_message(self) -> MessageQueueItem:
async def get_message(self) -> Message:
"""Get the next event on the queue, or None if queue is closed.
Returns:
Optional[Event]: Event object or None.
"""
if self._closed:
raise MessagePumpClosed("The message pump is closed")
if self._pending_message is not None:
try:
return self._pending_message
finally:
self._pending_message = None
if self._closed:
raise MessagePumpClosed("The message pump is closed")
queue_item = await self._message_queue.get()
if queue_item is None:
message = await self._message_queue.get()
if message is None:
self._closed = True
raise MessagePumpClosed("The message pump is now closed")
return queue_item
return message
def peek_message(self) -> MessageQueueItem | None:
def peek_message(self) -> Message | None:
"""Peek the message at the head of the queue (does not remove it from the queue),
or return None if the queue is empty.
@@ -122,7 +93,8 @@ class MessagePump:
callback: TimerCallback = None,
) -> Timer:
timer = Timer(self, delay, self, name=name, callback=callback, repeat=0)
asyncio.get_event_loop().create_task(timer.run())
timer_task = asyncio.get_event_loop().create_task(timer.run())
self._child_tasks.add(timer_task)
return timer
def set_interval(
@@ -139,91 +111,131 @@ class MessagePump:
asyncio.get_event_loop().create_task(timer.run())
return timer
async def close_messages(self, wait: bool = False) -> None:
async def stop_messages(self) -> None:
if not self._closing:
await self.post_message(events.NoneEvent(self))
self._closing = True
return
if not (self._closing or self._closed):
self._queue_empty_event.clear()
await self.post_message(events.NoneEvent(self))
self._closing = True
await self._queue_empty_event.wait()
self._queue_empty_event.clear()
async def close_messages(self, wait: bool = True) -> None:
"""Close message queue, and optionally wait for queue to finish processing."""
if self._closed:
return
log.debug("close_messages %r wait=%r", self, wait)
self._closing = True
log.debug("close 1 %r", self)
for task in self._child_tasks:
task.cancel()
log.debug("close 2 %r", self)
await self._message_queue.put(None)
log.debug("close 3 %r", self)
if wait and self._task is not None:
await self._task
self._task = None
log.debug("close 4 %r", self)
def start_messages(self) -> None:
task = asyncio.create_task(self.process_messages())
self._task = task
self._task = asyncio.create_task(self.process_messages())
async def process_messages(self) -> None:
"""Process messages until the queue is closed."""
while not self._closed:
try:
priority, message = await self.get_message()
message = await self.get_message()
except MessagePumpClosed:
log.debug("CLOSED %r", self)
break
except Exception as error:
log.exception("error in get_message()")
raise error from None
log.debug("%r -> %r", message, self)
# Combine any pending messages that may supersede this one
while True:
pending = self.peek_message()
if pending is None or not message.can_batch(pending.message):
if pending is None or not message.can_batch(pending):
break
try:
message = await self.get_message()
except MessagePumpClosed:
break
priority, message = await self.get_message()
try:
await self.dispatch_message(message, priority)
await self.dispatch_message(message)
except Exception as error:
log.exception("error in dispatch_message")
raise
finally:
log.debug("a")
if self._message_queue.empty():
idle_handler = getattr(self, "on_idle", None)
if idle_handler is not None:
await idle_handler(events.Idle(self))
log.debug("b")
self._queue_empty_event.set()
if not self._closed:
idle_handler = getattr(self, "on_idle", None)
log.debug("c %r", idle_handler)
if idle_handler is not None and not self._closed:
log.debug("d")
await idle_handler(events.Idle(self))
log.debug("e")
self._queue_empty_event.set()
async def dispatch_message(
self, message: Message, priority: int = 0
) -> bool | None:
async def dispatch_message(self, message: Message) -> bool | None:
log.debug("dispatch_message %r", message)
if isinstance(message, events.Event):
await self.on_event(message, priority)
await self.on_event(message)
else:
return await self.on_message(message)
return False
async def on_event(self, event: events.Event, priority: int) -> None:
async def on_event(self, event: events.Event) -> None:
method_name = f"on_{event.name}"
dispatch_function: MessageHandler = getattr(self, method_name, None)
log.debug("dispatching to %r", dispatch_function)
if dispatch_function is not None:
await dispatch_function(event)
if event.bubble and self._parent and not event._stop_propagaton:
if event.sender == self._parent:
log.debug("bubbled event abandoned; %r", event)
else:
await self._parent.post_message(event, priority)
elif not self._parent._closed and not self._parent._closing:
await self._parent.post_message(event)
async def on_message(self, message: Message) -> None:
pass
async def post_message(
self,
message: Message,
priority: int | None = None,
) -> bool:
def post_message_no_wait(self, message: Message) -> bool:
if self._closing or self._closed:
return False
if not self.check_message_enabled(message):
return True
event_priority = priority if priority is not None else message.default_priority
item = MessageQueueItem(event_priority, message)
await self._message_queue.put(item)
self._message_queue.put_nowait(message)
return True
async def post_message_from_child(
self, message: Message, priority: int | None = None
) -> None:
await self.post_message(message, priority=priority)
async def post_message(self, message: Message) -> bool:
log.debug("%r post_message 1", self)
if self._closing or self._closed:
return False
log.debug("%r post_message 2", self)
if not self.check_message_enabled(message):
return True
log.debug("%r post_message 3", self)
await self._message_queue.put(message)
log.debug("%r post_message 4", self)
return True
async def emit(self, message: Message, priority: int | None = None) -> bool:
async def post_message_from_child(self, message: Message) -> bool:
if self._closing or self._closed:
return False
return await self.post_message(message)
async def emit(self, message: Message) -> bool:
if self._parent:
await self._parent.post_message_from_child(message, priority=priority)
await self._parent.post_message_from_child(message)
return True
else:
return False