diff --git a/src/textual/app.py b/src/textual/app.py index b207bbd02..df38c6091 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -1,6 +1,8 @@ from __future__ import annotations import asyncio +from concurrent.futures import Future +from functools import partial import inspect import io import os @@ -18,6 +20,7 @@ from time import perf_counter from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Generic, Iterable, @@ -206,6 +209,8 @@ class _WriterThread(threading.Thread): CSSPathType = Union[str, PurePath, List[Union[str, PurePath]], None] +CallThreadReturnType = TypeVar("CallThreadReturnType") + @rich.repr.auto class App(Generic[ReturnType], DOMNode): @@ -353,6 +358,8 @@ class App(Generic[ReturnType], DOMNode): else: self.devtools = DevtoolsClient() + self._loop: asyncio.AbstractEventLoop | None = None + self._thread_id: int = 0 self._return_value: ReturnType | None = None self._exit = False @@ -604,6 +611,51 @@ class App(Generic[ReturnType], DOMNode): except Exception as error: self._handle_exception(error) + def call_from_thread( + self, + callback: Callable[..., CallThreadReturnType | Awaitable[CallThreadReturnType]], + *args, + **kwargs, + ) -> CallThreadReturnType: + """Run a callback from another thread. + + Like asyncio apps in general, Textual apps are not thread-safe. If you call methods + or set attributes on Textual objects from a thread, you may get unpredictable results. + + This method will ensure that your code is ran within the correct context. + + Args: + callback (Callable): A callable to run. + *args: Arguments to the callback. + **kwargs: Keyword arguments for the callback. + + Raises: + RuntimeError: If the app isn't running or if this method is called from the same + thread where the app is running. + """ + + if self._loop is None: + raise RuntimeError("App is not running") + + if self._thread_id == threading.get_ident(): + raise RuntimeError( + "The `call_from_thread` method must run in a different thread from the app" + ) + + callback_with_args = partial(callback, *args, **kwargs) + + async def run_callback() -> CallThreadReturnType: + """Run the callback, set the result or error on the future.""" + self._set_active() + return await invoke(callback_with_args) + + # Post the message to the main loop + future: Future[Any] = asyncio.run_coroutine_threadsafe( + run_callback(), loop=self._loop + ) + result = future.result() + return result + def action_toggle_dark(self) -> None: """Action to toggle dark mode.""" self.dark = not self.dark @@ -874,11 +926,17 @@ class App(Generic[ReturnType], DOMNode): async def run_app() -> None: """Run the app.""" - await self.run_async( - headless=headless, - size=size, - auto_pilot=auto_pilot, - ) + self._loop = asyncio.get_running_loop() + self._thread_id = threading.get_ident() + try: + await self.run_async( + headless=headless, + size=size, + auto_pilot=auto_pilot, + ) + finally: + self._loop = None + self._thread_id = 0 if _ASYNCIO_GET_EVENT_LOOP_IS_DEPRECATED: # N.B. This doesn't work with Python<3.10, as we end up with 2 event loops: diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 000000000..88a3ccf4b --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,50 @@ +import pytest + +from threading import Thread +from textual.app import App, ComposeResult +from textual.widgets import TextLog + + +def test_call_from_thread_app_not_running(): + app = App() + + # Should fail if app is not running + with pytest.raises(RuntimeError): + app.call_from_thread(print) + + +def test_call_from_thread(): + class BackgroundThread(Thread): + """A background thread which will modify app in some way.""" + + def __init__(self, app: App) -> None: + self.app = app + super().__init__() + + def run(self) -> None: + def write_stuff(text: str) -> None: + """Write stuff to a widget.""" + self.app.query_one(TextLog).write(text) + + self.app.call_from_thread(write_stuff, "Hello") + # Exit the app with a code we can assert + self.app.call_from_thread(self.app.exit, 123) + + class ThreadTestApp(App): + """Trivial app with a single widget.""" + + def compose(self) -> ComposeResult: + yield TextLog() + + def on_ready(self) -> None: + """Launch a thread which will modify the app.""" + try: + self.call_from_thread(print) + except RuntimeError as error: + self._runtime_error = error + BackgroundThread(self).start() + + app = ThreadTestApp() + result = app.run(headless=True, size=(80, 24)) + assert isinstance(app._runtime_error, RuntimeError) + assert result == 123