Merge pull request #1157 from davep/awaitable-remove

Awaitable remove
This commit is contained in:
Will McGugan
2022-11-11 18:06:41 +00:00
committed by GitHub
7 changed files with 156 additions and 50 deletions

View File

@@ -17,6 +17,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
`DOMNode.ancestors`.
- Improved the speed of `DOMQuery.remove`.
- Added DataTable.clear
- It is now possible to `await` a `Widget.remove`.
https://github.com/Textualize/textual/issues/1094
- It is now possible to `await` a `DOMQuery.remove`. Note that this changes
the return value of `DOMQuery.remove`, which uses to return `self`.
https://github.com/Textualize/textual/issues/1094
### Changed

View File

@@ -1810,11 +1810,19 @@ class App(Generic[ReturnType], DOMNode):
event.stop()
await self.screen.post_message(event)
async def _on_remove(self, event: events.Remove) -> None:
"""Handle a remove event.
def _detach_from_dom(self, widgets: list[Widget]) -> list[Widget]:
"""Detach a list of widgets from the DOM.
Args:
event (events.Remove): The remove event.
widgets (list[Widget]): The list of widgets to detach from the DOM.
Returns:
list[Widget]: The list of widgets that should be pruned.
Note:
A side-effect of calling this function is that each parent of
each affected widget will be made to forget about the affected
child.
"""
# We've been given a list of widgets to remove, but removing those
@@ -1823,7 +1831,7 @@ class App(Generic[ReturnType], DOMNode):
# be in the DOM by the time we've finished. Note that, at this
# point, it's entirely possible that there will be duplicates.
everything_to_remove: list[Widget] = []
for widget in event.widgets:
for widget in widgets:
everything_to_remove.extend(
widget.walk_children(
Widget, with_self=True, method="depth", reverse=True
@@ -1846,11 +1854,9 @@ class App(Generic[ReturnType], DOMNode):
# In other words: find the smallest set of ancestors in the DOM that
# will remove the widgets requested for removal, and also ensure
# that all knock-on effects happen too.
request_remove = set(event.widgets)
request_remove = set(widgets)
pruned_remove = [
widget
for widget in event.widgets
if request_remove.isdisjoint(widget.ancestors)
widget for widget in widgets if request_remove.isdisjoint(widget.ancestors)
]
# Now that we know that minimal set of widgets, we go through them
@@ -1860,14 +1866,26 @@ class App(Generic[ReturnType], DOMNode):
if widget.parent is not None:
widget.parent.children._remove(widget)
# Having done that, it's now safe for us to start the process of
# winding down all of the affected widgets. We do that by pruning
# just the roots of each affected branch, and letting the normal
# prune process take care of all the offspring.
for widget in pruned_remove:
await self._prune_node(widget)
# Return the list of widgets that should end up being sent off in a
# prune event.
return pruned_remove
# And finally, redraw all the things!
async def _on_prune(self, event: events.Prune) -> None:
"""Handle a prune event.
Args:
event (events.Prune): The prune event.
"""
try:
# Prune all the widgets.
for widget in event.widgets:
await self._prune_node(widget)
finally:
# Finally, flag that we're done.
event.finished_flag.set()
# Flag that the layout needs refreshing.
self.refresh(layout=True)
def _walk_children(self, root: Widget) -> Iterable[list[Widget]]:

View File

@@ -0,0 +1,23 @@
"""Provides the type of an awaitable remove."""
from asyncio import Event
from typing import Generator
class AwaitRemove:
"""An awaitable returned by App.remove and DOMQuery.remove."""
def __init__(self, finished_flag: Event) -> None:
"""Initialise the instance of ``AwaitRemove``.
Args:
finished_flag (asyncio.Event): The asyncio event to wait on.
"""
self.finished_flag = finished_flag
def __await__(self) -> Generator[None, None, None]:
async def await_prune() -> None:
"""Wait for the prune operation to finish."""
await self.finished_flag.wait()
return await_prune().__await__()

View File

@@ -17,11 +17,13 @@ a method which evaluates the query, such as first() and last().
from __future__ import annotations
from typing import cast, Generic, TYPE_CHECKING, Iterator, TypeVar, overload
import asyncio
import rich.repr
from .. import events
from .._context import active_app
from ..await_remove import AwaitRemove
from .errors import DeclarationError, TokenError
from .match import match
from .model import SelectorSet
@@ -348,11 +350,22 @@ class DOMQuery(Generic[QueryType]):
node.toggle_class(*class_names)
return self
def remove(self) -> DOMQuery[QueryType]:
"""Remove matched nodes from the DOM"""
def remove(self) -> AwaitRemove:
"""Remove matched nodes from the DOM.
Returns:
AwaitRemove: An awaitable object that waits for the widgets to be removed.
"""
prune_finished_event = asyncio.Event()
app = active_app.get()
app.post_message_no_wait(events.Remove(app, widgets=list(self)))
return self
app.post_message_no_wait(
events.Prune(
app,
widgets=app._detach_from_dom(list(self)),
finished_flag=prune_finished_event,
)
)
return AwaitRemove(prune_finished_event)
def set_styles(
self, css: str | None = None, **update_styles

View File

@@ -16,6 +16,7 @@ if TYPE_CHECKING:
from .timer import Timer as TimerClass
from .timer import TimerCallback
from .widget import Widget
import asyncio
@rich.repr.auto
@@ -126,12 +127,26 @@ class Unmount(Mount, bubble=False, verbose=False):
"""Sent when a widget is unmounted and may not longer receive messages."""
class Remove(Event, bubble=False):
"""Sent to the app to ask it to remove one or more widgets from the DOM."""
class Prune(Event, bubble=False):
"""Sent to the app to ask it to prune one or more widgets from the DOM.
def __init__(self, sender: MessageTarget, widgets: list[Widget]) -> None:
self.widgets = widgets
Attributes:
widgets (list[Widgets]): The list of widgets to prune.
finished_flag (asyncio.Event): An asyncio Event to that will be flagged when the prune is done.
"""
def __init__(
self, sender: MessageTarget, widgets: list[Widget], finished_flag: asyncio.Event
) -> None:
"""Initialise the event.
Args:
widgets (list[Widgets]): The list of widgets to prune.
finished_flag (asyncio.Event): An asyncio Event to that will be flagged when the prune is done.
"""
super().__init__(sender)
self.finished_flag = finished_flag
self.widgets = widgets
class Show(Event, bubble=False):

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from asyncio import Lock, wait, create_task
from asyncio import Lock, wait, create_task, Event as AsyncEvent
from fractions import Fraction
from itertools import islice
from operator import attrgetter
@@ -49,6 +49,7 @@ from .message import Message
from .messages import CallbackType
from .reactive import Reactive
from .render import measure
from .await_remove import AwaitRemove
if TYPE_CHECKING:
from .app import App, ComposeResult
@@ -1990,9 +1991,21 @@ class Widget(DOMNode):
self.check_idle()
def remove(self) -> None:
"""Remove the Widget from the DOM (effectively deleting it)"""
self.app.post_message_no_wait(events.Remove(self, widgets=[self]))
def remove(self) -> AwaitRemove:
"""Remove the Widget from the DOM (effectively deleting it)
Returns:
AwaitRemove: An awaitable object that waits for the widget to be removed.
"""
prune_finished_event = AsyncEvent()
self.app.post_message_no_wait(
events.Prune(
self,
widgets=self.app._detach_from_dom([self]),
finished_flag=prune_finished_event,
)
)
return AwaitRemove(prune_finished_event)
def render(self) -> RenderableType:
"""Get renderable for widget.

View File

@@ -4,24 +4,12 @@ from textual.widget import Widget
from textual.widgets import Static, Button
from textual.containers import Container
async def await_remove_standin():
"""Standin function for awaiting removal.
These tests are being written so that we can go on and make remove
awaitable, but it would be good to have some tests in place *before* we
make that change, but the tests need to await remove to be useful tests.
So to get around that bootstrap issue, we just use this function as a
standin until we can swap over.
"""
await asyncio.sleep(0) # Until we can await remove.
async def test_remove_single_widget():
"""It should be possible to the only widget on a screen."""
async with App().run_test() as pilot:
await pilot.app.mount(Static())
assert len(pilot.app.screen.children) == 1
pilot.app.query_one(Static).remove()
await await_remove_standin()
await pilot.app.query_one(Static).remove()
assert len(pilot.app.screen.children) == 0
async def test_many_remove_all_widgets():
@@ -29,8 +17,7 @@ async def test_many_remove_all_widgets():
async with App().run_test() as pilot:
await pilot.app.mount(*[Static() for _ in range(1000)])
assert len(pilot.app.screen.children) == 1000
pilot.app.query(Static).remove()
await await_remove_standin()
await pilot.app.query(Static).remove()
assert len(pilot.app.screen.children) == 0
async def test_many_remove_some_widgets():
@@ -38,8 +25,7 @@ async def test_many_remove_some_widgets():
async with App().run_test() as pilot:
await pilot.app.mount(*[Static(id=f"is-{n%2}") for n in range(1000)])
assert len(pilot.app.screen.children) == 1000
pilot.app.query("#is-0").remove()
await await_remove_standin()
await pilot.app.query("#is-0").remove()
assert len(pilot.app.screen.children) == 500
async def test_remove_branch():
@@ -71,8 +57,7 @@ async def test_remove_branch():
),
)
assert len(pilot.app.screen.walk_children(with_self=False)) == 13
pilot.app.screen.children[0].remove()
await await_remove_standin()
await pilot.app.screen.children[0].remove()
assert len(pilot.app.screen.walk_children(with_self=False)) == 7
async def test_remove_overlap():
@@ -104,8 +89,7 @@ async def test_remove_overlap():
),
)
assert len(pilot.app.screen.walk_children(with_self=False)) == 13
pilot.app.query(Container).remove()
await await_remove_standin()
await pilot.app.query(Container).remove()
assert len(pilot.app.screen.walk_children(with_self=False)) == 1
async def test_remove_move_focus():
@@ -119,9 +103,44 @@ async def test_remove_move_focus():
await pilot.press( "tab" )
assert pilot.app.focused is not None
assert pilot.app.focused == buttons[0]
pilot.app.screen.children[0].remove()
await await_remove_standin()
await pilot.app.screen.children[0].remove()
assert len(pilot.app.screen.children) == 1
assert len(pilot.app.screen.walk_children(with_self=False)) == 6
assert pilot.app.focused is not None
assert pilot.app.focused == buttons[9]
async def test_widget_remove_order():
"""A Widget.remove of a top-level widget should cause bottom-first removal."""
removals: list[str] = []
class Removable(Container):
def on_unmount( self, _ ):
removals.append(self.id if self.id is not None else "unknown")
async with App().run_test() as pilot:
await pilot.app.mount(
Removable(Removable(Removable(id="grandchild"), id="child"), id="parent")
)
assert len(pilot.app.screen.walk_children(with_self=False)) == 3
await pilot.app.screen.children[0].remove()
assert len(pilot.app.screen.walk_children(with_self=False)) == 0
assert removals == ["grandchild", "child", "parent"]
async def test_query_remove_order():
"""A DOMQuery.remove of a top-level widget should cause bottom-first removal."""
removals: list[str] = []
class Removable(Container):
def on_unmount( self, _ ):
removals.append(self.id if self.id is not None else "unknown")
async with App().run_test() as pilot:
await pilot.app.mount(
Removable(Removable(Removable(id="grandchild"), id="child"), id="parent")
)
assert len(pilot.app.screen.walk_children(with_self=False)) == 3
await pilot.app.query(Removable).remove()
assert len(pilot.app.screen.walk_children(with_self=False)) == 0
assert removals == ["grandchild", "child", "parent"]