From 7d99d168ff907955bda8c1da609ebc9093f1a188 Mon Sep 17 00:00:00 2001 From: Will McGugan Date: Thu, 23 Feb 2023 13:49:07 +0000 Subject: [PATCH] prevent implementation --- CHANGELOG.md | 1 + src/textual/events.py | 6 +++++- src/textual/message_pump.py | 35 +++++++++++++++++++++++++++++++++-- src/textual/reactive.py | 4 +++- tests/test_message_pump.py | 33 +++++++++++++++++++++++++++++++++ 5 files changed, 75 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 042796b7b..7e9eb0a49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Added horizontal rule to Markdown https://github.com/Textualize/textual/pull/1832 - Added `Widget.disabled` https://github.com/Textualize/textual/pull/1785 - Added `DOMNode.notify_style_update` to replace `messages.StylesUpdated` message https://github.com/Textualize/textual/pull/1861 +- Added `MessagePump.prevent` context manager to temporarily suppress a given message type ### Changed diff --git a/src/textual/events.py b/src/textual/events.py index 9914c03d5..34f302fe2 100644 --- a/src/textual/events.py +++ b/src/textual/events.py @@ -29,9 +29,13 @@ class Event(Message): @rich.repr.auto class Callback(Event, bubble=False, verbose=True): def __init__( - self, sender: MessageTarget, callback: Callable[[], Awaitable[None]] + self, + sender: MessageTarget, + callback: Callable[[], Awaitable[None]], + prevent: set[type[Message]] | None = None, ) -> None: self.callback = callback + self.prevent = frozenset(prevent) if prevent else None super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: diff --git a/src/textual/message_pump.py b/src/textual/message_pump.py index d17bd8249..4391103cc 100644 --- a/src/textual/message_pump.py +++ b/src/textual/message_pump.py @@ -10,8 +10,9 @@ from __future__ import annotations import asyncio import inspect from asyncio import CancelledError, Queue, QueueEmpty, Task +from contextlib import contextmanager from functools import partial -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generator, Iterable from weakref import WeakSet from . import Logger, events, log, messages @@ -19,6 +20,7 @@ from ._asyncio import create_task from ._callback import invoke from ._context import NoActiveAppError, active_app, active_message_pump from ._time import time +from ._types import CallbackType from .case import camel_to_snake from .errors import DuplicateKeyHandlers from .events import Event @@ -77,6 +79,7 @@ class MessagePump(metaclass=MessagePumpMeta): self._max_idle: float | None = None self._mounted_event = asyncio.Event() self._next_callbacks: list[CallbackType] = [] + self._prevent_events: list[set[type[Message]]] = [] @property def task(self) -> Task: @@ -149,6 +152,9 @@ class MessagePump(metaclass=MessagePumpMeta): self._parent = None def check_message_enabled(self, message: Message) -> bool: + message_type = type(message) + if self._prevent_events and message_type in self._prevent_events[-1]: + return False return type(message) not in self._disabled_messages def disable_messages(self, *messages: type[Message]) -> None: @@ -527,6 +533,27 @@ class MessagePump(metaclass=MessagePumpMeta): if self._message_queue.empty(): self.post_message_no_wait(messages.Prompt(sender=self)) + @contextmanager + def prevent(self, *message_types: type[Message]) -> Generator[None, None, None]: + """A context manager to *temporarily* prevent the given message types from being posted. + + Example: + + input = self.query_one(Input) + with self.prevent(Input.Changed): + input.value = "foo" + + + """ + if self._prevent_events: + self._prevent_events.append(self._prevent_events[-1].union(message_types)) + else: + self._prevent_events.append(set(message_types)) + try: + yield + finally: + self._prevent_events.pop() + async def post_message(self, message: Message) -> bool: """Post a message or an event to this message pump. @@ -594,7 +621,11 @@ class MessagePump(metaclass=MessagePumpMeta): return self.post_message_no_wait(message) async def on_callback(self, event: events.Callback) -> None: - await invoke(event.callback) + if event.prevent: + with self.prevent(*event.prevent): + await invoke(event.callback) + else: + await invoke(event.callback) # TODO: Does dispatch_key belong on message pump? async def dispatch_key(self, event: events.Key) -> bool: diff --git a/src/textual/reactive.py b/src/textual/reactive.py index 0553d076c..140cb0ff1 100644 --- a/src/textual/reactive.py +++ b/src/textual/reactive.py @@ -235,7 +235,9 @@ class Reactive(Generic[ReactiveType]): # Result is awaitable, so we need to await it within an async context obj.post_message_no_wait( events.Callback( - sender=obj, callback=partial(await_watcher, watch_result) + sender=obj, + callback=partial(await_watcher, watch_result), + prevent=obj._prevent_events[0] if obj._prevent_events else None, ) ) diff --git a/tests/test_message_pump.py b/tests/test_message_pump.py index ce665f7c9..9667b4cbc 100644 --- a/tests/test_message_pump.py +++ b/tests/test_message_pump.py @@ -1,8 +1,10 @@ import pytest +from textual.app import App, ComposeResult from textual.errors import DuplicateKeyHandlers from textual.events import Key from textual.widget import Widget +from textual.widgets import Input class ValidWidget(Widget): @@ -54,3 +56,34 @@ async def test_dispatch_key_raises_when_conflicting_handler_aliases(): with pytest.raises(DuplicateKeyHandlers): await widget.dispatch_key(Key(widget, key="tab", character="\t")) assert widget.called_by == widget.key_tab + + +class PreventTestApp(App): + def __init__(self) -> None: + self.input_changed_events = [] + super().__init__() + + def compose(self) -> ComposeResult: + yield Input() + + def on_input_changed(self, event: Input.Changed) -> None: + self.input_changed_events.append(event) + + +async def test_prevent() -> None: + app = PreventTestApp() + + async with app.run_test() as pilot: + assert not app.input_changed_events + input = app.query_one(Input) + input.value = "foo" + await pilot.pause() + assert len(app.input_changed_events) == 1 + assert app.input_changed_events[0].value == "foo" + + with input.prevent(Input.Changed): + input.value = "bar" + + await pilot.pause() + assert len(app.input_changed_events) == 1 + assert app.input_changed_events[0].value == "foo"