prevent implementation

This commit is contained in:
Will McGugan
2023-02-23 13:49:07 +00:00
parent e3cbaa8dca
commit 7d99d168ff
5 changed files with 75 additions and 4 deletions

View File

@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Added horizontal rule to Markdown https://github.com/Textualize/textual/pull/1832 - Added horizontal rule to Markdown https://github.com/Textualize/textual/pull/1832
- Added `Widget.disabled` https://github.com/Textualize/textual/pull/1785 - Added `Widget.disabled` https://github.com/Textualize/textual/pull/1785
- Added `DOMNode.notify_style_update` to replace `messages.StylesUpdated` message https://github.com/Textualize/textual/pull/1861 - Added `DOMNode.notify_style_update` to replace `messages.StylesUpdated` message https://github.com/Textualize/textual/pull/1861
- Added `MessagePump.prevent` context manager to temporarily suppress a given message type
### Changed ### Changed

View File

@@ -29,9 +29,13 @@ class Event(Message):
@rich.repr.auto @rich.repr.auto
class Callback(Event, bubble=False, verbose=True): class Callback(Event, bubble=False, verbose=True):
def __init__( def __init__(
self, sender: MessageTarget, callback: Callable[[], Awaitable[None]] self,
sender: MessageTarget,
callback: Callable[[], Awaitable[None]],
prevent: set[type[Message]] | None = None,
) -> None: ) -> None:
self.callback = callback self.callback = callback
self.prevent = frozenset(prevent) if prevent else None
super().__init__(sender) super().__init__(sender)
def __rich_repr__(self) -> rich.repr.Result: def __rich_repr__(self) -> rich.repr.Result:

View File

@@ -10,8 +10,9 @@ from __future__ import annotations
import asyncio import asyncio
import inspect import inspect
from asyncio import CancelledError, Queue, QueueEmpty, Task from asyncio import CancelledError, Queue, QueueEmpty, Task
from contextlib import contextmanager
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generator, Iterable
from weakref import WeakSet from weakref import WeakSet
from . import Logger, events, log, messages from . import Logger, events, log, messages
@@ -19,6 +20,7 @@ from ._asyncio import create_task
from ._callback import invoke from ._callback import invoke
from ._context import NoActiveAppError, active_app, active_message_pump from ._context import NoActiveAppError, active_app, active_message_pump
from ._time import time from ._time import time
from ._types import CallbackType
from .case import camel_to_snake from .case import camel_to_snake
from .errors import DuplicateKeyHandlers from .errors import DuplicateKeyHandlers
from .events import Event from .events import Event
@@ -77,6 +79,7 @@ class MessagePump(metaclass=MessagePumpMeta):
self._max_idle: float | None = None self._max_idle: float | None = None
self._mounted_event = asyncio.Event() self._mounted_event = asyncio.Event()
self._next_callbacks: list[CallbackType] = [] self._next_callbacks: list[CallbackType] = []
self._prevent_events: list[set[type[Message]]] = []
@property @property
def task(self) -> Task: def task(self) -> Task:
@@ -149,6 +152,9 @@ class MessagePump(metaclass=MessagePumpMeta):
self._parent = None self._parent = None
def check_message_enabled(self, message: Message) -> bool: def check_message_enabled(self, message: Message) -> bool:
message_type = type(message)
if self._prevent_events and message_type in self._prevent_events[-1]:
return False
return type(message) not in self._disabled_messages return type(message) not in self._disabled_messages
def disable_messages(self, *messages: type[Message]) -> None: def disable_messages(self, *messages: type[Message]) -> None:
@@ -527,6 +533,27 @@ class MessagePump(metaclass=MessagePumpMeta):
if self._message_queue.empty(): if self._message_queue.empty():
self.post_message_no_wait(messages.Prompt(sender=self)) self.post_message_no_wait(messages.Prompt(sender=self))
@contextmanager
def prevent(self, *message_types: type[Message]) -> Generator[None, None, None]:
"""A context manager to *temporarily* prevent the given message types from being posted.
Example:
input = self.query_one(Input)
with self.prevent(Input.Changed):
input.value = "foo"
"""
if self._prevent_events:
self._prevent_events.append(self._prevent_events[-1].union(message_types))
else:
self._prevent_events.append(set(message_types))
try:
yield
finally:
self._prevent_events.pop()
async def post_message(self, message: Message) -> bool: async def post_message(self, message: Message) -> bool:
"""Post a message or an event to this message pump. """Post a message or an event to this message pump.
@@ -594,7 +621,11 @@ class MessagePump(metaclass=MessagePumpMeta):
return self.post_message_no_wait(message) return self.post_message_no_wait(message)
async def on_callback(self, event: events.Callback) -> None: async def on_callback(self, event: events.Callback) -> None:
await invoke(event.callback) if event.prevent:
with self.prevent(*event.prevent):
await invoke(event.callback)
else:
await invoke(event.callback)
# TODO: Does dispatch_key belong on message pump? # TODO: Does dispatch_key belong on message pump?
async def dispatch_key(self, event: events.Key) -> bool: async def dispatch_key(self, event: events.Key) -> bool:

View File

@@ -235,7 +235,9 @@ class Reactive(Generic[ReactiveType]):
# Result is awaitable, so we need to await it within an async context # Result is awaitable, so we need to await it within an async context
obj.post_message_no_wait( obj.post_message_no_wait(
events.Callback( events.Callback(
sender=obj, callback=partial(await_watcher, watch_result) sender=obj,
callback=partial(await_watcher, watch_result),
prevent=obj._prevent_events[0] if obj._prevent_events else None,
) )
) )

View File

@@ -1,8 +1,10 @@
import pytest import pytest
from textual.app import App, ComposeResult
from textual.errors import DuplicateKeyHandlers from textual.errors import DuplicateKeyHandlers
from textual.events import Key from textual.events import Key
from textual.widget import Widget from textual.widget import Widget
from textual.widgets import Input
class ValidWidget(Widget): class ValidWidget(Widget):
@@ -54,3 +56,34 @@ async def test_dispatch_key_raises_when_conflicting_handler_aliases():
with pytest.raises(DuplicateKeyHandlers): with pytest.raises(DuplicateKeyHandlers):
await widget.dispatch_key(Key(widget, key="tab", character="\t")) await widget.dispatch_key(Key(widget, key="tab", character="\t"))
assert widget.called_by == widget.key_tab assert widget.called_by == widget.key_tab
class PreventTestApp(App):
def __init__(self) -> None:
self.input_changed_events = []
super().__init__()
def compose(self) -> ComposeResult:
yield Input()
def on_input_changed(self, event: Input.Changed) -> None:
self.input_changed_events.append(event)
async def test_prevent() -> None:
app = PreventTestApp()
async with app.run_test() as pilot:
assert not app.input_changed_events
input = app.query_one(Input)
input.value = "foo"
await pilot.pause()
assert len(app.input_changed_events) == 1
assert app.input_changed_events[0].value == "foo"
with input.prevent(Input.Changed):
input.value = "bar"
await pilot.pause()
assert len(app.input_changed_events) == 1
assert app.input_changed_events[0].value == "foo"