prevent implementation

This commit is contained in:
Will McGugan
2023-02-23 13:49:07 +00:00
parent e3cbaa8dca
commit 7d99d168ff
5 changed files with 75 additions and 4 deletions

View File

@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Added horizontal rule to Markdown https://github.com/Textualize/textual/pull/1832
- Added `Widget.disabled` https://github.com/Textualize/textual/pull/1785
- Added `DOMNode.notify_style_update` to replace `messages.StylesUpdated` message https://github.com/Textualize/textual/pull/1861
- Added `MessagePump.prevent` context manager to temporarily suppress a given message type
### Changed

View File

@@ -29,9 +29,13 @@ class Event(Message):
@rich.repr.auto
class Callback(Event, bubble=False, verbose=True):
def __init__(
self, sender: MessageTarget, callback: Callable[[], Awaitable[None]]
self,
sender: MessageTarget,
callback: Callable[[], Awaitable[None]],
prevent: set[type[Message]] | None = None,
) -> None:
self.callback = callback
self.prevent = frozenset(prevent) if prevent else None
super().__init__(sender)
def __rich_repr__(self) -> rich.repr.Result:

View File

@@ -10,8 +10,9 @@ from __future__ import annotations
import asyncio
import inspect
from asyncio import CancelledError, Queue, QueueEmpty, Task
from contextlib import contextmanager
from functools import partial
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generator, Iterable
from weakref import WeakSet
from . import Logger, events, log, messages
@@ -19,6 +20,7 @@ from ._asyncio import create_task
from ._callback import invoke
from ._context import NoActiveAppError, active_app, active_message_pump
from ._time import time
from ._types import CallbackType
from .case import camel_to_snake
from .errors import DuplicateKeyHandlers
from .events import Event
@@ -77,6 +79,7 @@ class MessagePump(metaclass=MessagePumpMeta):
self._max_idle: float | None = None
self._mounted_event = asyncio.Event()
self._next_callbacks: list[CallbackType] = []
self._prevent_events: list[set[type[Message]]] = []
@property
def task(self) -> Task:
@@ -149,6 +152,9 @@ class MessagePump(metaclass=MessagePumpMeta):
self._parent = None
def check_message_enabled(self, message: Message) -> bool:
message_type = type(message)
if self._prevent_events and message_type in self._prevent_events[-1]:
return False
return type(message) not in self._disabled_messages
def disable_messages(self, *messages: type[Message]) -> None:
@@ -527,6 +533,27 @@ class MessagePump(metaclass=MessagePumpMeta):
if self._message_queue.empty():
self.post_message_no_wait(messages.Prompt(sender=self))
@contextmanager
def prevent(self, *message_types: type[Message]) -> Generator[None, None, None]:
"""A context manager to *temporarily* prevent the given message types from being posted.
Example:
input = self.query_one(Input)
with self.prevent(Input.Changed):
input.value = "foo"
"""
if self._prevent_events:
self._prevent_events.append(self._prevent_events[-1].union(message_types))
else:
self._prevent_events.append(set(message_types))
try:
yield
finally:
self._prevent_events.pop()
async def post_message(self, message: Message) -> bool:
"""Post a message or an event to this message pump.
@@ -594,6 +621,10 @@ class MessagePump(metaclass=MessagePumpMeta):
return self.post_message_no_wait(message)
async def on_callback(self, event: events.Callback) -> None:
if event.prevent:
with self.prevent(*event.prevent):
await invoke(event.callback)
else:
await invoke(event.callback)
# TODO: Does dispatch_key belong on message pump?

View File

@@ -235,7 +235,9 @@ class Reactive(Generic[ReactiveType]):
# Result is awaitable, so we need to await it within an async context
obj.post_message_no_wait(
events.Callback(
sender=obj, callback=partial(await_watcher, watch_result)
sender=obj,
callback=partial(await_watcher, watch_result),
prevent=obj._prevent_events[0] if obj._prevent_events else None,
)
)

View File

@@ -1,8 +1,10 @@
import pytest
from textual.app import App, ComposeResult
from textual.errors import DuplicateKeyHandlers
from textual.events import Key
from textual.widget import Widget
from textual.widgets import Input
class ValidWidget(Widget):
@@ -54,3 +56,34 @@ async def test_dispatch_key_raises_when_conflicting_handler_aliases():
with pytest.raises(DuplicateKeyHandlers):
await widget.dispatch_key(Key(widget, key="tab", character="\t"))
assert widget.called_by == widget.key_tab
class PreventTestApp(App):
def __init__(self) -> None:
self.input_changed_events = []
super().__init__()
def compose(self) -> ComposeResult:
yield Input()
def on_input_changed(self, event: Input.Changed) -> None:
self.input_changed_events.append(event)
async def test_prevent() -> None:
app = PreventTestApp()
async with app.run_test() as pilot:
assert not app.input_changed_events
input = app.query_one(Input)
input.value = "foo"
await pilot.pause()
assert len(app.input_changed_events) == 1
assert app.input_changed_events[0].value == "foo"
with input.prevent(Input.Changed):
input.value = "bar"
await pilot.pause()
assert len(app.input_changed_events) == 1
assert app.input_changed_events[0].value == "foo"