diff --git a/src/textual/app.py b/src/textual/app.py index 1afe30322..5edaf9e40 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -1802,11 +1802,19 @@ class App(Generic[ReturnType], DOMNode): event.stop() await self.screen.post_message(event) - async def _on_remove(self, event: events.Remove) -> None: - """Handle a remove event. + def _detach_from_dom(self, widgets: list[Widget]) -> list[Widget]: + """Detach a list of widgets from the DOM. Args: - event (events.Remove): The remove event. + widgets (list[Widget]): The list of widgets to detach from the DOM. + + Returns: + list[Widget]: The list of widgets that should be pruned. + + Note: + A side-effect of calling this function is that each parent of + each affected widget will be made to forget about the affected + child. """ # We've been given a list of widgets to remove, but removing those @@ -1815,7 +1823,7 @@ class App(Generic[ReturnType], DOMNode): # be in the DOM by the time we've finished. Note that, at this # point, it's entirely possible that there will be duplicates. everything_to_remove: list[Widget] = [] - for widget in event.widgets: + for widget in widgets: everything_to_remove.extend( widget.walk_children( Widget, with_self=True, method="depth", reverse=True @@ -1838,11 +1846,9 @@ class App(Generic[ReturnType], DOMNode): # In other words: find the smallest set of ancestors in the DOM that # will remove the widgets requested for removal, and also ensure # that all knock-on effects happen too. - request_remove = set(event.widgets) + request_remove = set(widgets) pruned_remove = [ - widget - for widget in event.widgets - if request_remove.isdisjoint(widget.ancestors) + widget for widget in widgets if request_remove.isdisjoint(widget.ancestors) ] # Now that we know that minimal set of widgets, we go through them @@ -1852,14 +1858,26 @@ class App(Generic[ReturnType], DOMNode): if widget.parent is not None: widget.parent.children._remove(widget) - # Having done that, it's now safe for us to start the process of - # winding down all of the affected widgets. We do that by pruning - # just the roots of each affected branch, and letting the normal - # prune process take care of all the offspring. - for widget in pruned_remove: - await self._prune_node(widget) + # Return the list of widgets that should end up being sent off in a + # prune event. + return pruned_remove - # And finally, redraw all the things! + async def _on_prune(self, event: events.Prune) -> None: + """Handle a prune event. + + Args: + event (events.Prune): The prune event. + """ + + try: + # Prune all the widgets. + for widget in event.widgets: + await self._prune_node(widget) + finally: + # Finally, flag that we're done. + event.finished_flag.set() + + # Flag that the layout needs refreshing. self.refresh(layout=True) def _walk_children(self, root: Widget) -> Iterable[list[Widget]]: diff --git a/src/textual/await_remove.py b/src/textual/await_remove.py new file mode 100644 index 000000000..d93c21cf8 --- /dev/null +++ b/src/textual/await_remove.py @@ -0,0 +1,13 @@ +from asyncio import Event +from typing import Generator + + +class AwaitRemove: + def __init__(self, finished_flag: Event) -> None: + self.finished_flag = finished_flag + + def __await__(self) -> Generator[None, None, None]: + async def await_prune() -> None: + await self.finished_flag.wait() + + return await_prune().__await__() diff --git a/src/textual/css/query.py b/src/textual/css/query.py index 007daf507..50119411a 100644 --- a/src/textual/css/query.py +++ b/src/textual/css/query.py @@ -17,11 +17,13 @@ a method which evaluates the query, such as first() and last(). from __future__ import annotations from typing import cast, Generic, TYPE_CHECKING, Iterator, TypeVar, overload +import asyncio import rich.repr from .. import events from .._context import active_app +from ..await_remove import AwaitRemove from .errors import DeclarationError, TokenError from .match import match from .model import SelectorSet @@ -348,11 +350,18 @@ class DOMQuery(Generic[QueryType]): node.toggle_class(*class_names) return self - def remove(self) -> DOMQuery[QueryType]: + def remove(self) -> AwaitRemove: """Remove matched nodes from the DOM""" + prune_finished_event = asyncio.Event() app = active_app.get() - app.post_message_no_wait(events.Remove(app, widgets=list(self))) - return self + app.post_message_no_wait( + events.Prune( + app, + widgets=app._detach_from_dom(list(self)), + finished_flag=prune_finished_event, + ) + ) + return AwaitRemove(prune_finished_event) def set_styles( self, css: str | None = None, **update_styles diff --git a/src/textual/events.py b/src/textual/events.py index d34b206ad..a2cd27190 100644 --- a/src/textual/events.py +++ b/src/textual/events.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from .timer import Timer as TimerClass from .timer import TimerCallback from .widget import Widget + import asyncio @rich.repr.auto @@ -126,12 +127,21 @@ class Unmount(Mount, bubble=False, verbose=False): """Sent when a widget is unmounted and may not longer receive messages.""" -class Remove(Event, bubble=False): - """Sent to the app to ask it to remove one or more widgets from the DOM.""" +class Prune(Event, bubble=False): + """Sent to the app to ask it to prune one or more widgets from the DOM.""" - def __init__(self, sender: MessageTarget, widgets: list[Widget]) -> None: - self.widgets = widgets + def __init__( + self, sender: MessageTarget, widgets: list[Widget], finished_flag: asyncio.Event + ) -> None: + """Initialise the event. + + Args: + widgets (list[Widgets]): The list of widgets to prune. + finished_flag (asyncio.Event): An asyncio Event to that will be flagged when the prune is done. + """ super().__init__(sender) + self.finished_flag = finished_flag + self.widgets = widgets class Show(Event, bubble=False): diff --git a/src/textual/widget.py b/src/textual/widget.py index b1eb32cb9..214a3d30d 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -1,6 +1,6 @@ from __future__ import annotations -from asyncio import Lock, wait, create_task +from asyncio import Lock, wait, create_task, Event as AsyncEvent from fractions import Fraction from itertools import islice from operator import attrgetter @@ -49,6 +49,7 @@ from .message import Message from .messages import CallbackType from .reactive import Reactive from .render import measure +from .await_remove import AwaitRemove if TYPE_CHECKING: from .app import App, ComposeResult @@ -1990,9 +1991,17 @@ class Widget(DOMNode): self.check_idle() - def remove(self) -> None: + def remove(self) -> AwaitRemove: """Remove the Widget from the DOM (effectively deleting it)""" - self.app.post_message_no_wait(events.Remove(self, widgets=[self])) + prune_finished_event = AsyncEvent() + self.app.post_message_no_wait( + events.Prune( + self, + widgets=self.app._detach_from_dom([self]), + finished_flag=prune_finished_event, + ) + ) + return AwaitRemove(prune_finished_event) def render(self) -> RenderableType: """Get renderable for widget. diff --git a/tests/test_widget_removing.py b/tests/test_widget_removing.py index 822695073..5391fe7a0 100644 --- a/tests/test_widget_removing.py +++ b/tests/test_widget_removing.py @@ -4,24 +4,12 @@ from textual.widget import Widget from textual.widgets import Static, Button from textual.containers import Container -async def await_remove_standin(): - """Standin function for awaiting removal. - - These tests are being written so that we can go on and make remove - awaitable, but it would be good to have some tests in place *before* we - make that change, but the tests need to await remove to be useful tests. - So to get around that bootstrap issue, we just use this function as a - standin until we can swap over. - """ - await asyncio.sleep(0) # Until we can await remove. - async def test_remove_single_widget(): """It should be possible to the only widget on a screen.""" async with App().run_test() as pilot: await pilot.app.mount(Static()) assert len(pilot.app.screen.children) == 1 - pilot.app.query_one(Static).remove() - await await_remove_standin() + await pilot.app.query_one(Static).remove() assert len(pilot.app.screen.children) == 0 async def test_many_remove_all_widgets(): @@ -29,8 +17,7 @@ async def test_many_remove_all_widgets(): async with App().run_test() as pilot: await pilot.app.mount(*[Static() for _ in range(1000)]) assert len(pilot.app.screen.children) == 1000 - pilot.app.query(Static).remove() - await await_remove_standin() + await pilot.app.query(Static).remove() assert len(pilot.app.screen.children) == 0 async def test_many_remove_some_widgets(): @@ -38,8 +25,7 @@ async def test_many_remove_some_widgets(): async with App().run_test() as pilot: await pilot.app.mount(*[Static(id=f"is-{n%2}") for n in range(1000)]) assert len(pilot.app.screen.children) == 1000 - pilot.app.query("#is-0").remove() - await await_remove_standin() + await pilot.app.query("#is-0").remove() assert len(pilot.app.screen.children) == 500 async def test_remove_branch(): @@ -71,8 +57,7 @@ async def test_remove_branch(): ), ) assert len(pilot.app.screen.walk_children(with_self=False)) == 13 - pilot.app.screen.children[0].remove() - await await_remove_standin() + await pilot.app.screen.children[0].remove() assert len(pilot.app.screen.walk_children(with_self=False)) == 7 async def test_remove_overlap(): @@ -104,8 +89,7 @@ async def test_remove_overlap(): ), ) assert len(pilot.app.screen.walk_children(with_self=False)) == 13 - pilot.app.query(Container).remove() - await await_remove_standin() + await pilot.app.query(Container).remove() assert len(pilot.app.screen.walk_children(with_self=False)) == 1 async def test_remove_move_focus(): @@ -119,8 +103,7 @@ async def test_remove_move_focus(): await pilot.press( "tab" ) assert pilot.app.focused is not None assert pilot.app.focused == buttons[0] - pilot.app.screen.children[0].remove() - await await_remove_standin() + await pilot.app.screen.children[0].remove() assert len(pilot.app.screen.children) == 1 assert len(pilot.app.screen.walk_children(with_self=False)) == 6 assert pilot.app.focused is not None