Refactor shutdown procedure

This commit is contained in:
Will McGugan
2021-06-20 11:50:44 +01:00
parent ab14b766d3
commit 0dd46641e4
12 changed files with 215 additions and 330 deletions

View File

@@ -4,6 +4,19 @@ version = "0.1.1"
description = "Text User Interface using Rich"
authors = ["Will McGugan <willmcgugan@gmail.com>"]
license = "MIT"
classifiers = [
"Development Status :: 1 - Planning",
"Environment :: Console",
"Intended Audience :: Developers",
"Operating System :: Microsoft :: Windows",
"Operating System :: MacOS",
"Operating System :: POSIX :: Linux",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
]
[tool.poetry.dependencies]
python = "^3.7"

View File

@@ -139,15 +139,21 @@ class LinuxDriver(Driver):
| termios.IGNCR
)
def stop_application_mode(self) -> None:
log.debug("stop_application_mode()")
def disable_input(self) -> None:
try:
if not self.exit_event.is_set():
signal.signal(signal.SIGWINCH, signal.SIG_DFL)
self._disable_mouse_support()
self.exit_event.set()
if self._key_thread is not None:
self._key_thread.join()
except Exception:
log.exception("error in disable_input")
def stop_application_mode(self) -> None:
log.debug("stop_application_mode()")
self.disable_input()
if self.attrs_before is not None:
try:
@@ -190,12 +196,18 @@ class LinuxDriver(Driver):
read = os.read
log.debug("started key thread")
try:
while not self.exit_event.is_set():
selector_events = selector.select(0.1)
for _selector_key, mask in selector_events:
if mask | selectors.EVENT_READ:
unicode_data = decode(read(fileno, 1024))
for event in parser.feed(unicode_data):
send_event(event)
except Exception:
log.exception("error running key thread")
finally:
selector.close()
if __name__ == "__main__":

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import weakref
from asyncio import Event, TimeoutError, wait_for
from asyncio import CancelledError, Event, TimeoutError, wait_for
from time import monotonic
from typing import Awaitable, Callable
@@ -62,6 +62,7 @@ class Timer:
_interval = self._interval
_wait = self._stop_event.wait
start = monotonic()
try:
while _repeat is None or count <= _repeat:
next_timer = start + (count * _interval)
try:
@@ -77,3 +78,5 @@ class Timer:
except EventTargetGone:
break
count += 1
except CancelledError:
pass

View File

@@ -16,20 +16,12 @@ Callback = Callable[[], None]
class MessageTarget(Protocol):
async def post_message(
self,
message: "Message",
priority: Optional[int] = None,
) -> bool:
async def post_message(self, message: "Message") -> bool:
...
class EventTarget(Protocol):
async def post_message(
self,
message: "Message",
priority: Optional[int] = None,
) -> bool:
async def post_message(self, message: "Message") -> bool:
...

View File

@@ -56,17 +56,18 @@ class App(MessagePump):
self,
console: Console = None,
screen: bool = True,
driver: Type[Driver] = None,
driver_class: Type[Driver] = None,
view: View = None,
title: str = "Megasoma Application",
):
super().__init__()
self.console = console or get_console()
self._screen = screen
self.driver = driver or LinuxDriver
self.driver_class = driver_class or LinuxDriver
self.title = title
self.view = view or LayoutView()
self.children: set[MessagePump] = set()
self._driver: Driver | None = None
self._action_targets = {"app": self, "view": self.view}
@@ -78,7 +79,7 @@ class App(MessagePump):
cls, console: Console = None, screen: bool = True, driver: Type[Driver] = None
):
async def run_app() -> None:
app = cls(console=console, screen=screen, driver=driver)
app = cls(console=console, screen=screen, driver_class=driver)
await app.process_messages()
asyncio.run(run_app())
@@ -95,11 +96,11 @@ class App(MessagePump):
self.console.print_exception(show_locals=True)
async def _process_messages(self) -> None:
log.debug("driver=%r", self.driver)
log.debug("driver=%r", self.driver_class)
loop = asyncio.get_event_loop()
loop.add_signal_handler(signal.SIGINT, self.on_keyboard_interupt)
driver = self.driver(self.console, self)
driver = self._driver = self.driver_class(self.console, self)
active_app.set(self)
@@ -114,23 +115,6 @@ class App(MessagePump):
raise
try:
await super().process_messages()
finally:
try:
if self.children:
async def close_all() -> None:
for child in self.children:
await child.close_messages()
await asyncio.gather(*(child.task for child in self.children))
try:
await asyncio.wait_for(close_all(), timeout=5)
except asyncio.TimeoutError as error:
raise ShutdownError(
"Timeout closing messages pump(s)"
) from None
self.children.clear()
finally:
try:
driver.stop_application_mode()
@@ -142,6 +126,44 @@ class App(MessagePump):
child.start_messages()
await child.post_message(events.Created(sender=self))
async def remove(self, child: MessagePump) -> None:
self.children.remove(child)
async def shutdown(self):
driver = self._driver
driver.disable_input()
async def shutdown_procedure() -> None:
log.debug("1")
await self.stop_messages()
log.debug("2")
await self.view.stop_messages()
log.debug("3")
log.debug("4")
await self.remove(self.view)
if self.children:
log.debug("5")
async def close_all() -> None:
for child in self.children:
await child.close_messages(wait=False)
await asyncio.gather(*(child.task for child in self.children))
try:
await asyncio.wait_for(close_all(), timeout=5)
log.debug("6")
except asyncio.TimeoutError as error:
raise ShutdownError("Timeout closing messages pump(s)") from None
log.debug("7")
log.debug("8")
await self.view.close_messages()
log.debug("9")
await self.close_messages()
log.debug("10")
await asyncio.create_task(shutdown_procedure())
def refresh(self) -> None:
console = self.console
try:
@@ -150,7 +172,7 @@ class App(MessagePump):
except Exception:
log.exception("refresh failed")
async def on_event(self, event: events.Event, priority: int) -> None:
async def on_event(self, event: events.Event) -> None:
if isinstance(event, events.Key):
key_action = self.KEYS.get(event.key, None)
if key_action is not None:
@@ -160,7 +182,7 @@ class App(MessagePump):
if isinstance(event, events.InputEvent):
await self.view.forward_input_event(event)
else:
await super().on_event(event, priority)
await super().on_event(event)
async def on_idle(self, event: events.Idle) -> None:
await self.view.post_message(event)
@@ -215,7 +237,7 @@ class App(MessagePump):
await self.view.post_message(event)
async def action_quit(self) -> None:
await self.close_messages()
await self.shutdown()
async def action_bang(self) -> None:
1 / 0

View File

@@ -1,18 +1,10 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
import logging
import os
import signal
import curses
import platform
import sys
import shutil
from threading import Event, Thread
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from . import events
from ._types import MessageTarget
if TYPE_CHECKING:
@@ -34,159 +26,9 @@ class Driver(ABC):
...
@abstractmethod
def stop_application_mode(self) -> None:
def disable_input(self) -> None:
...
# class CursesDriver(Driver):
# _MOUSE_PRESSED = [
# curses.BUTTON1_PRESSED,
# curses.BUTTON2_PRESSED,
# curses.BUTTON3_PRESSED,
# curses.BUTTON4_PRESSED,
# ]
# _MOUSE_RELEASED = [
# curses.BUTTON1_RELEASED,
# curses.BUTTON2_RELEASED,
# curses.BUTTON3_RELEASED,
# curses.BUTTON4_RELEASED,
# ]
# _MOUSE_CLICKED = [
# curses.BUTTON1_CLICKED,
# curses.BUTTON2_CLICKED,
# curses.BUTTON3_CLICKED,
# curses.BUTTON4_CLICKED,
# ]
# _MOUSE_DOUBLE_CLICKED = [
# curses.BUTTON1_DOUBLE_CLICKED,
# curses.BUTTON2_DOUBLE_CLICKED,
# curses.BUTTON3_DOUBLE_CLICKED,
# curses.BUTTON4_DOUBLE_CLICKED,
# ]
# _MOUSE = [
# (events.MouseDown, _MOUSE_PRESSED),
# (events.MouseUp, _MOUSE_RELEASED),
# (events.Click, _MOUSE_CLICKED),
# (events.DoubleClick, _MOUSE_DOUBLE_CLICKED),
# ]
# def __init__(self, console: "Console", target: "MessageTarget") -> None:
# super().__init__(console, target)
# self._stdscr = None
# self._exit_event = Event()
# self._key_thread: Thread | None = None
# def _get_terminal_size(self) -> tuple[int, int]:
# width: int | None = 80
# height: int | None = 25
# if WINDOWS: # pragma: no cover
# width, height = shutil.get_terminal_size()
# else:
# try:
# width, height = os.get_terminal_size(sys.stdin.fileno())
# except (AttributeError, ValueError, OSError):
# try:
# width, height = os.get_terminal_size(sys.stdout.fileno())
# except (AttributeError, ValueError, OSError):
# pass
# width = width or 80
# height = height or 25
# return width, height
# def start_application_mode(self):
# loop = asyncio.get_event_loop()
# def on_terminal_resize(signum, stack) -> None:
# terminal_size = self._get_terminal_size()
# width, height = terminal_size
# event = events.Resize(self._target, width, height)
# self.console.size = terminal_size
# asyncio.run_coroutine_threadsafe(
# self._target.post_message(event),
# loop=loop,
# )
# signal.signal(signal.SIGWINCH, on_terminal_resize)
# self._stdscr = curses.initscr()
# curses.noecho()
# curses.cbreak()
# curses.halfdelay(1)
# curses.mousemask(curses.REPORT_MOUSE_POSITION | curses.ALL_MOUSE_EVENTS)
# # curses.mousemask(-1)
# self._stdscr.keypad(True)
# self.console.show_cursor(False)
# self.console.file.write("\033[?1003h\n")
# self._key_thread = Thread(
# target=self.run_key_thread, args=(asyncio.get_event_loop(),)
# )
# width, height = self.console.size = self._get_terminal_size()
# asyncio.run_coroutine_threadsafe(
# self._target.post_message(events.Resize(self._target, width, height)),
# loop=loop,
# )
# self._key_thread.start()
# def stop_application_mode(self):
# signal.signal(signal.SIGWINCH, signal.SIG_DFL)
# self._exit_event.set()
# self._key_thread.join()
# curses.nocbreak()
# self._stdscr.keypad(False)
# curses.echo()
# curses.endwin()
# self.console.show_cursor(True)
# def run_key_thread(self, loop) -> None:
# stdscr = self._stdscr
# assert stdscr is not None
# exit_event = self._exit_event
# def send_event(event: events.Event) -> None:
# asyncio.run_coroutine_threadsafe(
# self._target.post_message(event),
# loop=loop,
# )
# while not exit_event.is_set():
# code = stdscr.getch()
# if code == -1:
# continue
# if code == curses.KEY_MOUSE:
# try:
# _id, x, y, _z, button_state = curses.getmouse()
# except Exception:
# log.exception("error in curses.getmouse")
# else:
# if button_state & curses.REPORT_MOUSE_POSITION:
# send_event(events.MouseMove(self._target, x, y))
# alt = bool(button_state & curses.BUTTON_ALT)
# ctrl = bool(button_state & curses.BUTTON_CTRL)
# shift = bool(button_state & curses.BUTTON_SHIFT)
# for event_type, masks in self._MOUSE:
# for button, mask in enumerate(masks, 1):
# if button_state & mask:
# send_event(
# event_type(
# self._target,
# x,
# y,
# button,
# alt=alt,
# ctrl=ctrl,
# shift=shift,
# )
# )
# else:
# send_event(events.Key(self._target, code=code))
@abstractmethod
def stop_application_mode(self) -> None:
...

View File

@@ -50,7 +50,6 @@ class EventType(Enum):
CUSTOM = 1000
@rich_repr
class Event(Message):
type: ClassVar[EventType]
@@ -58,30 +57,28 @@ class Event(Message):
return
yield
def __init_subclass__(
cls, type: EventType, priority: int = 0, bubble: bool = False
) -> None:
def __init_subclass__(cls, type: EventType, bubble: bool = False) -> None:
cls.type = type
super().__init_subclass__(priority=priority, bubble=bubble)
super().__init_subclass__(bubble=bubble)
# def __enter__(self) -> "Event":
# return self
# def __exit__(self, exc_type, exc_value, exc_tb) -> bool | None:
# if exc_type is not None:
# # Log and suppress exception
# return True
class NoneEvent(Event, type=EventType.NONE):
pass
class ShutdownRequest(Event, type=EventType.SHUTDOWN_REQUEST):
pass
class Load(Event, type=EventType.SHUTDOWN_REQUEST):
class Shutdown(Event, type=EventType.SHUTDOWN):
pass
class Startup(Event, type=EventType.SHUTDOWN_REQUEST):
class Load(Event, type=EventType.LOAD):
pass
class Startup(Event, type=EventType.STARTUP):
pass
@@ -120,10 +117,6 @@ class Unmount(Event, type=EventType.UNMOUNT):
pass
class Shutdown(Event, type=EventType.SHUTDOWN):
pass
class InputEvent(Event, type=EventType.NONE, bubble=True):
pass
@@ -205,7 +198,7 @@ class DoubleClick(MouseEvent, type=EventType.DOUBLE_CLICK):
@rich_repr
class Timer(Event, type=EventType.TIMER, priority=10):
class Timer(Event, type=EventType.TIMER):
__slots__ = ["time", "count", "callback"]
def __init__(

View File

@@ -7,7 +7,6 @@ from .case import camel_to_snake
from ._types import MessageTarget
@rich_repr
class Message:
"""Base class for a message."""
@@ -21,7 +20,6 @@ class Message:
sender: MessageTarget
bubble: ClassVar[bool] = False
default_priority: ClassVar[int] = 0
def __init__(self, sender: MessageTarget) -> None:
self.sender = sender
@@ -35,10 +33,9 @@ class Message:
return
yield
def __init_subclass__(cls, bubble: bool = False, priority: int = 0) -> None:
def __init_subclass__(cls, bubble: bool = False) -> None:
super().__init_subclass__()
cls.bubble = bubble
cls.default_priority = priority
def can_batch(self, message: "Message") -> bool:
"""Check if another message may supersede this one.

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Any, Coroutine, Awaitable, NamedTuple
import asyncio
from asyncio import Event, PriorityQueue, Task, QueueEmpty
from asyncio import Event, Queue, Task, QueueEmpty
import logging
@@ -14,50 +14,21 @@ from ._types import MessageHandler
log = logging.getLogger("rich")
class MessageQueueItem(NamedTuple):
priority: int
message: Message
def __lt__(self, other: object) -> bool:
other_priority = other.priority if isinstance(other, MessageQueueItem) else 0
return self.priority < other_priority
def __le__(self, other: object) -> bool:
other_priority = other.priority if isinstance(other, MessageQueueItem) else 0
return self.priority <= other_priority
def __gt__(self, other: object) -> bool:
other_priority = other.priority if isinstance(other, MessageQueueItem) else 0
return self.priority > other_priority
def __ge__(self, other: object) -> bool:
other_priority = other.priority if isinstance(other, MessageQueueItem) else 0
return self.priority >= other_priority
def __eq__(self, other: object) -> bool:
other_priority = other.priority if isinstance(other, MessageQueueItem) else 0
return self.priority == other_priority
def __ne__(self, other: object) -> bool:
other_priority = other.priority if isinstance(other, MessageQueueItem) else 0
return self.priority != other_priority
class MessagePumpClosed(Exception):
pass
class MessagePump:
def __init__(self, queue_size: int = 10, parent: MessagePump | None = None) -> None:
self._message_queue: PriorityQueue[MessageQueueItem | None] = PriorityQueue(
queue_size
)
self._message_queue: Queue[Message | None] = Queue()
self._parent = parent
self._closing: bool = False
self._closed: bool = False
self._disabled_messages: set[type[Message]] = set()
self._pending_message: MessageQueueItem | None = None
self._pending_message: Message | None = None
self._task: Task | None = None
self._child_tasks: set[Task] = set()
self._queue_empty_event = Event()
@property
def task(self) -> Task:
@@ -78,26 +49,26 @@ class MessagePump:
"""Enable processing of messages types."""
self._disabled_messages.difference_update(messages)
async def get_message(self) -> MessageQueueItem:
async def get_message(self) -> Message:
"""Get the next event on the queue, or None if queue is closed.
Returns:
Optional[Event]: Event object or None.
"""
if self._closed:
raise MessagePumpClosed("The message pump is closed")
if self._pending_message is not None:
try:
return self._pending_message
finally:
self._pending_message = None
if self._closed:
raise MessagePumpClosed("The message pump is closed")
queue_item = await self._message_queue.get()
if queue_item is None:
message = await self._message_queue.get()
if message is None:
self._closed = True
raise MessagePumpClosed("The message pump is now closed")
return queue_item
return message
def peek_message(self) -> MessageQueueItem | None:
def peek_message(self) -> Message | None:
"""Peek the message at the head of the queue (does not remove it from the queue),
or return None if the queue is empty.
@@ -122,7 +93,8 @@ class MessagePump:
callback: TimerCallback = None,
) -> Timer:
timer = Timer(self, delay, self, name=name, callback=callback, repeat=0)
asyncio.get_event_loop().create_task(timer.run())
timer_task = asyncio.get_event_loop().create_task(timer.run())
self._child_tasks.add(timer_task)
return timer
def set_interval(
@@ -139,91 +111,131 @@ class MessagePump:
asyncio.get_event_loop().create_task(timer.run())
return timer
async def close_messages(self, wait: bool = False) -> None:
"""Close message queue, and optionally wait for queue to finish processing."""
async def stop_messages(self) -> None:
if not self._closing:
await self.post_message(events.NoneEvent(self))
self._closing = True
return
if not (self._closing or self._closed):
self._queue_empty_event.clear()
await self.post_message(events.NoneEvent(self))
self._closing = True
await self._queue_empty_event.wait()
self._queue_empty_event.clear()
async def close_messages(self, wait: bool = True) -> None:
"""Close message queue, and optionally wait for queue to finish processing."""
if self._closed:
return
log.debug("close_messages %r wait=%r", self, wait)
self._closing = True
log.debug("close 1 %r", self)
for task in self._child_tasks:
task.cancel()
log.debug("close 2 %r", self)
await self._message_queue.put(None)
log.debug("close 3 %r", self)
if wait and self._task is not None:
await self._task
self._task = None
log.debug("close 4 %r", self)
def start_messages(self) -> None:
task = asyncio.create_task(self.process_messages())
self._task = task
self._task = asyncio.create_task(self.process_messages())
async def process_messages(self) -> None:
"""Process messages until the queue is closed."""
while not self._closed:
try:
priority, message = await self.get_message()
message = await self.get_message()
except MessagePumpClosed:
log.debug("CLOSED %r", self)
break
except Exception as error:
log.exception("error in get_message()")
raise error from None
log.debug("%r -> %r", message, self)
# Combine any pending messages that may supersede this one
while True:
pending = self.peek_message()
if pending is None or not message.can_batch(pending.message):
if pending is None or not message.can_batch(pending):
break
try:
message = await self.get_message()
except MessagePumpClosed:
break
priority, message = await self.get_message()
try:
await self.dispatch_message(message, priority)
await self.dispatch_message(message)
except Exception as error:
log.exception("error in dispatch_message")
raise
finally:
log.debug("a")
if self._message_queue.empty():
log.debug("b")
self._queue_empty_event.set()
if not self._closed:
idle_handler = getattr(self, "on_idle", None)
if idle_handler is not None:
log.debug("c %r", idle_handler)
if idle_handler is not None and not self._closed:
log.debug("d")
await idle_handler(events.Idle(self))
log.debug("e")
self._queue_empty_event.set()
async def dispatch_message(
self, message: Message, priority: int = 0
) -> bool | None:
async def dispatch_message(self, message: Message) -> bool | None:
log.debug("dispatch_message %r", message)
if isinstance(message, events.Event):
await self.on_event(message, priority)
await self.on_event(message)
else:
return await self.on_message(message)
return False
async def on_event(self, event: events.Event, priority: int) -> None:
async def on_event(self, event: events.Event) -> None:
method_name = f"on_{event.name}"
dispatch_function: MessageHandler = getattr(self, method_name, None)
log.debug("dispatching to %r", dispatch_function)
if dispatch_function is not None:
await dispatch_function(event)
if event.bubble and self._parent and not event._stop_propagaton:
if event.sender == self._parent:
log.debug("bubbled event abandoned; %r", event)
else:
await self._parent.post_message(event, priority)
elif not self._parent._closed and not self._parent._closing:
await self._parent.post_message(event)
async def on_message(self, message: Message) -> None:
pass
async def post_message(
self,
message: Message,
priority: int | None = None,
) -> bool:
def post_message_no_wait(self, message: Message) -> bool:
if self._closing or self._closed:
return False
if not self.check_message_enabled(message):
return True
event_priority = priority if priority is not None else message.default_priority
item = MessageQueueItem(event_priority, message)
await self._message_queue.put(item)
self._message_queue.put_nowait(message)
return True
async def post_message_from_child(
self, message: Message, priority: int | None = None
) -> None:
await self.post_message(message, priority=priority)
async def post_message(self, message: Message) -> bool:
log.debug("%r post_message 1", self)
if self._closing or self._closed:
return False
log.debug("%r post_message 2", self)
if not self.check_message_enabled(message):
return True
log.debug("%r post_message 3", self)
await self._message_queue.put(message)
log.debug("%r post_message 4", self)
return True
async def emit(self, message: Message, priority: int | None = None) -> bool:
async def post_message_from_child(self, message: Message) -> bool:
if self._closing or self._closed:
return False
return await self.post_message(message)
async def emit(self, message: Message) -> bool:
if self._parent:
await self._parent.post_message_from_child(message, priority=priority)
await self._parent.post_message_from_child(message)
return True
else:
return False

View File

@@ -108,6 +108,7 @@ class LayoutView(View):
raise NoWidget(f"No widget at ${x}, ${y}")
async def on_message(self, message: Message) -> None:
log.debug("on_message %r", repr(message))
if isinstance(message, UpdateMessage):
widget = message.sender
if widget in self._widgets:

View File

@@ -37,8 +37,6 @@ T = TypeVar("T")
class UpdateMessage(Message):
default_priority = 10
def can_batch(self, message: Message) -> bool:
return isinstance(message, UpdateMessage) and message.sender == self.sender
@@ -139,7 +137,6 @@ class Widget(MessagePump):
def render_update(self, x: int, y: int) -> Iterable[Segment]:
width, height = self.size
log.debug("widget size = %r", self.size)
yield from self.line_cache.render(x, y, width, height)
def render(self, console: Console, options: ConsoleOptions) -> RenderableType:
@@ -147,19 +144,19 @@ class Widget(MessagePump):
Align.center(Pretty(self), vertical="middle"), title=self.__class__.__name__
)
async def post_message(self, message: Message, priority: int | None = None) -> bool:
async def post_message(self, message: Message) -> bool:
if not self.check_message_enabled(message):
return True
return await super().post_message(message, priority)
return await super().post_message(message)
async def on_event(self, event: events.Event, priority: int) -> None:
async def on_event(self, event: events.Event) -> None:
if isinstance(event, events.Resize):
new_size = Dimensions(event.width, event.height)
if self.size != new_size:
self.size = new_size
self.require_repaint()
await super().on_event(event, priority)
await super().on_event(event)
async def on_idle(self, event: events.Idle) -> None:
if self.line_cache is None or self.line_cache.dirty:

View File

@@ -51,4 +51,5 @@ class Header(Widget):
return header
async def on_mount(self, event: events.Mount) -> None:
return
self.set_interval(1.0, callback=self.refresh)