diff --git a/docs/examples/widgets/custom.py b/docs/examples/widgets/custom.py index b1fdbe8ec..7be00ed8a 100644 --- a/docs/examples/widgets/custom.py +++ b/docs/examples/widgets/custom.py @@ -1,6 +1,6 @@ +from rich.console import RenderableType from rich.panel import Panel -from textual import events from textual.app import App from textual.reactive import Reactive from textual.widget import Widget @@ -8,22 +8,22 @@ from textual.widget import Widget class Hover(Widget): - mouse_over: Reactive[bool] = Reactive(False) + mouse_over = Reactive(False) - def render(self) -> Panel: + def render(self) -> RenderableType: return Panel("Hello [b]World[/b]", style=("on red" if self.mouse_over else "")) - async def on_enter(self, event: events.Enter) -> None: + def on_enter(self) -> None: self.mouse_over = True - async def on_leave(self, event: events.Leave) -> None: + def on_leave(self) -> None: self.mouse_over = False class HoverApp(App): """Demonstrates smooth animation""" - async def on_mount(self, event: events.Mount) -> None: + async def on_mount(self) -> None: """Build layout here.""" hovers = (Hover() for _ in range(10)) diff --git a/src/textual/_callback.py b/src/textual/_callback.py new file mode 100644 index 000000000..de001e09c --- /dev/null +++ b/src/textual/_callback.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from functools import lru_cache + +from inspect import signature, isawaitable +from typing import Any, Callable + + +@lru_cache(maxsize=2048) +def count_parameters(func: Callable) -> int: + """Count the number of parameters in a callable""" + return len(signature(func).parameters) + + +async def invoke(callback: Callable, *params: object) -> Any: + """Invoke a callback with an arbitrary number of parameters. + + Args: + callback (Callable): [description] + + Returns: + Any: [description] + """ + parameter_count = count_parameters(callback) + + result = callback(*params[:parameter_count]) + if isawaitable(result): + await result diff --git a/src/textual/message_pump.py b/src/textual/message_pump.py index 5653885c8..16a06669b 100644 --- a/src/textual/message_pump.py +++ b/src/textual/message_pump.py @@ -12,6 +12,7 @@ from rich.traceback import Traceback from . import events from . import log from ._timer import Timer, TimerCallback +from ._callback import invoke from ._context import active_app from .message import Message from .reactive import Reactive @@ -241,7 +242,7 @@ class MessagePump: for method in self._get_dispatch_methods(f"on_{event.name}", event): log(event, ">>>", self, verbosity=event.verbosity) - await method(event) + await invoke(method, event) if event.bubble and self._parent and not event._stop_propagation: if event.sender != self._parent and self.is_parent_active: @@ -308,9 +309,7 @@ class MessagePump: event.stop() if event.callback is not None: try: - callback_result = event.callback() - if inspect.isawaitable(callback_result): - await callback_result + await invoke(event.callback) except Exception as error: raise CallbackError( f"unable to run callback {event.callback!r}; {error}" diff --git a/src/textual/reactive.py b/src/textual/reactive.py index 7ae603207..e3a9cf6ba 100644 --- a/src/textual/reactive.py +++ b/src/textual/reactive.py @@ -16,6 +16,7 @@ from typing import ( from . import log from . import events +from ._callback import count_parameters from ._types import MessageTarget if TYPE_CHECKING: @@ -28,10 +29,6 @@ if TYPE_CHECKING: ReactiveType = TypeVar("ReactiveType") -def count_params(func: Callable) -> int: - return len(inspect.signature(func).parameters) - - class Reactive(Generic[ReactiveType]): """Reactive descriptor.""" @@ -92,7 +89,7 @@ class Reactive(Generic[ReactiveType]): obj: Reactable, watch_function: Callable, old_value: Any, value: Any ) -> None: _rich_traceback_guard = True - if count_params(watch_function) == 2: + if count_parameters(watch_function) == 2: await watch_function(old_value, value) else: await watch_function(value)