remove old hack

This commit is contained in:
Will McGugan
2024-09-07 17:52:25 +01:00
parent c581182a1c
commit 03ad3af83f
18 changed files with 712 additions and 682 deletions

View File

@@ -1,8 +1,7 @@
from __future__ import annotations
import weakref
from contextvars import ContextVar, Token
from typing import TYPE_CHECKING, Callable, Generic, TypeVar, overload
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any, Callable, TypeVar
if TYPE_CHECKING:
from .app import App
@@ -19,54 +18,8 @@ ContextVarType = TypeVar("ContextVarType")
DefaultType = TypeVar("DefaultType")
class ContextDefault:
pass
_context_default = ContextDefault()
class TextualContextVar(Generic[ContextVarType]):
"""Like ContextVar but doesn't hold on to references."""
def __init__(self, name: str) -> None:
self._context_var: ContextVar[weakref.ReferenceType[ContextVarType]] = (
ContextVar(name)
)
@overload
def get(self) -> ContextVarType: ...
@overload
def get(self, default: DefaultType) -> ContextVarType | DefaultType: ...
def get(
self, default: DefaultType | ContextDefault = _context_default
) -> ContextVarType | DefaultType:
try:
value_ref = self._context_var.get()
except LookupError:
if isinstance(default, ContextDefault):
raise
return default
value = value_ref()
if value is None:
if isinstance(default, ContextDefault):
raise LookupError(value)
return default
return value
def set(self, value: ContextVarType) -> object:
return self._context_var.set(weakref.ref(value))
def reset(self, token: Token[weakref.ReferenceType[ContextVarType]]) -> None:
self._context_var.reset(token)
active_app: TextualContextVar["App[object]"] = TextualContextVar("active_app")
active_message_pump: TextualContextVar["MessagePump"] = TextualContextVar(
"active_message_pump"
)
active_app: ContextVar["App[Any]"] = ContextVar("active_app")
active_message_pump: ContextVar["MessagePump"] = ContextVar("active_message_pump")
prevent_message_types_stack: ContextVar[list[set[type[Message]]]] = ContextVar(
"prevent_message_types_stack"
@@ -75,7 +28,5 @@ visible_screen_stack: ContextVar[list[Screen[object]]] = ContextVar(
"visible_screen_stack"
)
"""A stack of visible screens (with background alpha < 1), used in the screen render process."""
message_hook: TextualContextVar[Callable[[Message], None]] = TextualContextVar(
"message_hook"
)
message_hook: ContextVar[Callable[[Message], None]] = ContextVar("message_hook")
"""A callable that accepts a message. Used by App.run_test."""

View File

@@ -809,6 +809,17 @@ class App(Generic[ReturnType], DOMNode):
if not self._batch_count:
self.check_idle()
@contextmanager
def _context(self) -> Generator[None, None, None]:
"""Context manager to set ContextVars."""
app_reset_token = active_app.set(self)
message_pump_reset_token = active_message_pump.set(self)
try:
yield
finally:
active_message_pump.reset(message_pump_reset_token)
active_app.reset(app_reset_token)
def animate(
self,
attribute: str,
@@ -1046,10 +1057,6 @@ class App(Generic[ReturnType], DOMNode):
"""
return Screen(id="_default")
def _set_active(self) -> None:
"""Set this app to be the currently active app."""
active_app.set(self)
def compose(self) -> ComposeResult:
"""Yield child widgets for a container.
@@ -1355,8 +1362,8 @@ class App(Generic[ReturnType], DOMNode):
async def run_callback() -> CallThreadReturnType:
"""Run the callback, set the result or error on the future."""
self._set_active()
return await invoke(callback_with_args)
with self._context():
return await invoke(callback_with_args)
# Post the message to the main loop
future: Future[CallThreadReturnType] = asyncio.run_coroutine_threadsafe(
@@ -1667,41 +1674,39 @@ class App(Generic[ReturnType], DOMNode):
app: App to run.
"""
try:
if message_hook is not None:
message_hook_context_var.set(message_hook)
app._loop = asyncio.get_running_loop()
app._thread_id = threading.get_ident()
await app._process_messages(
ready_callback=on_app_ready,
headless=headless,
terminal_size=size,
)
finally:
app_ready_event.set()
with app._context():
try:
if message_hook is not None:
message_hook_context_var.set(message_hook)
app._loop = asyncio.get_running_loop()
app._thread_id = threading.get_ident()
await app._process_messages(
ready_callback=on_app_ready,
headless=headless,
terminal_size=size,
)
finally:
app_ready_event.set()
# Launch the app in the "background"
active_message_pump.set(app)
app_task = create_task(run_app(app), name=f"run_test {app}")
# Wait until the app has performed all startup routines.
await app_ready_event.wait()
# Get the app in an active state.
app._set_active()
# Context manager returns pilot object to manipulate the app
try:
pilot = Pilot(app)
await pilot._wait_for_screen()
yield pilot
finally:
# Shutdown the app cleanly
await app._shutdown()
await app_task
# Re-raise the exception which caused panic so test frameworks are aware
if self._exception:
raise self._exception
with app._context():
# Context manager returns pilot object to manipulate the app
try:
pilot = Pilot(app)
await pilot._wait_for_screen()
yield pilot
finally:
# Shutdown the app cleanly
await app._shutdown()
await app_task
# Re-raise the exception which caused panic so test frameworks are aware
if self._exception:
raise self._exception
async def run_async(
self,
@@ -1751,14 +1756,14 @@ class App(Generic[ReturnType], DOMNode):
async def run_auto_pilot(
auto_pilot: AutopilotCallbackType, pilot: Pilot
) -> None:
try:
await auto_pilot(pilot)
except Exception:
app.exit()
raise
with self._context():
try:
await auto_pilot(pilot)
except Exception:
app.exit()
raise
pilot = Pilot(app)
active_message_pump.set(self)
auto_pilot_task = create_task(
run_auto_pilot(auto_pilot, pilot), name=repr(pilot)
)
@@ -1816,18 +1821,19 @@ class App(Generic[ReturnType], DOMNode):
"""Run the app."""
self._loop = asyncio.get_running_loop()
self._thread_id = threading.get_ident()
try:
await self.run_async(
headless=headless,
inline=inline,
inline_no_clear=inline_no_clear,
mouse=mouse,
size=size,
auto_pilot=auto_pilot,
)
finally:
self._loop = None
self._thread_id = 0
with self._context():
try:
await self.run_async(
headless=headless,
inline=inline,
inline_no_clear=inline_no_clear,
mouse=mouse,
size=size,
auto_pilot=auto_pilot,
)
finally:
self._loop = None
self._thread_id = 0
if _ASYNCIO_GET_EVENT_LOOP_IS_DEPRECATED:
# N.B. This doesn't work with Python<3.10, as we end up with 2 event loops:
@@ -2680,6 +2686,16 @@ class App(Generic[ReturnType], DOMNode):
)
return driver
async def _init_devtools(self):
if self.devtools is not None:
from textual_dev.client import DevtoolsConnectionError
try:
await self.devtools.connect()
self.log.system(f"Connected to devtools ( {self.devtools.url} )")
except DevtoolsConnectionError:
self.log.system(f"Couldn't connect to devtools ( {self.devtools.url} )")
async def _process_messages(
self,
ready_callback: CallbackType | None = None,
@@ -2690,54 +2706,44 @@ class App(Generic[ReturnType], DOMNode):
terminal_size: tuple[int, int] | None = None,
message_hook: Callable[[Message], None] | None = None,
) -> None:
self._set_active()
active_message_pump.set(self)
if self.devtools is not None:
from textual_dev.client import DevtoolsConnectionError
async def app_prelude() -> bool:
await self._init_devtools()
self.log.system("---")
self.log.system(loop=asyncio.get_running_loop())
self.log.system(features=self.features)
if constants.LOG_FILE is not None:
_log_path = os.path.abspath(constants.LOG_FILE)
self.log.system(f"Writing logs to {_log_path!r}")
try:
await self.devtools.connect()
self.log.system(f"Connected to devtools ( {self.devtools.url} )")
except DevtoolsConnectionError:
self.log.system(f"Couldn't connect to devtools ( {self.devtools.url} )")
if self.css_path:
self.stylesheet.read_all(self.css_path)
for read_from, css, tie_breaker, scope in self._get_default_css():
self.stylesheet.add_source(
css,
read_from=read_from,
is_default_css=True,
tie_breaker=tie_breaker,
scope=scope,
)
if self.CSS:
try:
app_path = inspect.getfile(self.__class__)
except (TypeError, OSError):
app_path = ""
read_from = (app_path, f"{self.__class__.__name__}.CSS")
self.stylesheet.add_source(
self.CSS, read_from=read_from, is_default_css=False
)
except Exception as error:
self._handle_exception(error)
self._print_error_renderables()
return False
self.log.system("---")
self.log.system(loop=asyncio.get_running_loop())
self.log.system(features=self.features)
if constants.LOG_FILE is not None:
_log_path = os.path.abspath(constants.LOG_FILE)
self.log.system(f"Writing logs to {_log_path!r}")
try:
if self.css_path:
self.stylesheet.read_all(self.css_path)
for read_from, css, tie_breaker, scope in self._get_default_css():
self.stylesheet.add_source(
css,
read_from=read_from,
is_default_css=True,
tie_breaker=tie_breaker,
scope=scope,
)
if self.CSS:
try:
app_path = inspect.getfile(self.__class__)
except (TypeError, OSError):
app_path = ""
read_from = (app_path, f"{self.__class__.__name__}.CSS")
self.stylesheet.add_source(
self.CSS, read_from=read_from, is_default_css=False
)
except Exception as error:
self._handle_exception(error)
self._print_error_renderables()
return
if self.css_monitor:
self.set_interval(0.25, self.css_monitor, name="css monitor")
self.log.system("STARTED", self.css_monitor)
if self.css_monitor:
self.set_interval(0.25, self.css_monitor, name="css monitor")
self.log.system("STARTED", self.css_monitor)
return True
async def run_process_messages():
"""The main message loop, invoke below."""
@@ -2788,42 +2794,45 @@ class App(Generic[ReturnType], DOMNode):
finally:
await Timer._stop_all(self._timers)
self._running = True
try:
load_event = events.Load()
await self._dispatch_message(load_event)
with self._context():
if not await app_prelude():
return
self._running = True
try:
load_event = events.Load()
await self._dispatch_message(load_event)
driver = self._driver = self._build_driver(
headless=headless,
inline=inline,
mouse=mouse,
size=terminal_size,
)
self.log(driver=driver)
driver = self._driver = self._build_driver(
headless=headless,
inline=inline,
mouse=mouse,
size=terminal_size,
)
self.log(driver=driver)
if not self._exit:
driver.start_application_mode()
try:
with redirect_stdout(self._capture_stdout):
with redirect_stderr(self._capture_stderr):
await run_process_messages()
if not self._exit:
driver.start_application_mode()
try:
with redirect_stdout(self._capture_stdout):
with redirect_stderr(self._capture_stderr):
await run_process_messages()
finally:
if hasattr(self, "_watchers"):
self._watchers.clear()
if self._driver.is_inline:
cursor_x, cursor_y = self._previous_cursor_position
self._driver.write(
Control.move(-cursor_x, -cursor_y + 1).segment.text
)
if inline_no_clear and not not self.app._exit_renderables:
console = Console()
console.print(self.screen._compositor)
console.print()
finally:
if hasattr(self, "_watchers"):
self._watchers.clear()
if self._driver.is_inline:
cursor_x, cursor_y = self._previous_cursor_position
self._driver.write(
Control.move(-cursor_x, -cursor_y + 1).segment.text
)
if inline_no_clear and not not self.app._exit_renderables:
console = Console()
console.print(self.screen._compositor)
console.print()
driver.stop_application_mode()
except Exception as error:
self._handle_exception(error)
driver.stop_application_mode()
except Exception as error:
self._handle_exception(error)
async def _pre_process(self) -> bool:
"""Special case for the app, which doesn't need the functionality in MessagePump."""
@@ -3049,7 +3058,6 @@ class App(Generic[ReturnType], DOMNode):
await self._close_all()
await self._close_messages()
await self._dispatch_message(events.Unmount())
if self._driver is not None:

View File

@@ -241,6 +241,7 @@ class Provider(ABC):
"""Wait for initialization."""
if self._init_task is not None:
await self._init_task
self._init_task = None
async def startup(self) -> None:
"""Called after the Provider is initialized, but before any calls to `search`."""

View File

@@ -229,7 +229,7 @@ class MessagePump(metaclass=_MessagePumpMeta):
if node is None:
raise NoActiveAppError()
node = node._parent
active_app.set(node)
return node
@property
@@ -501,26 +501,27 @@ class MessagePump(metaclass=_MessagePumpMeta):
async def _process_messages(self) -> None:
self._running = True
active_message_pump.set(self)
if not await self._pre_process():
self._running = False
return
with self._context():
if not await self._pre_process():
self._running = False
return
try:
await self._process_messages_loop()
except CancelledError:
pass
finally:
self._running = False
try:
if self._timers:
await Timer._stop_all(self._timers)
self._timers.clear()
await self._process_messages_loop()
except CancelledError:
pass
finally:
if hasattr(self, "_watchers"):
self._watchers.clear()
await self._message_loop_exit()
self._running = False
try:
if self._timers:
await Timer._stop_all(self._timers)
self._timers.clear()
if hasattr(self, "_watchers"):
self._watchers.clear()
finally:
await self._message_loop_exit()
self._task = None
async def _message_loop_exit(self) -> None:
"""Called when the message loop has completed."""
@@ -560,6 +561,15 @@ class MessagePump(metaclass=_MessagePumpMeta):
"""Request the message queue to immediately exit."""
self._message_queue.put_nowait(messages.CloseMessages())
@contextmanager
def _context(self) -> Generator[None, None, None]:
"""Context manager to set ContextVars."""
reset_token = active_message_pump.set(self)
try:
yield
finally:
active_message_pump.reset(reset_token)
async def _on_close_messages(self, message: messages.CloseMessages) -> None:
await self._close_messages()

View File

@@ -24,7 +24,6 @@ import rich.repr
from . import events
from ._callback import count_parameters
from ._context import active_message_pump
from ._types import (
MessageTarget,
WatchCallbackBothValuesType,
@@ -82,8 +81,8 @@ def invoke_watcher(
_rich_traceback_omit = True
param_count = count_parameters(watch_function)
reset_token = active_message_pump.set(watcher_object)
try:
with watcher_object._context():
if param_count == 2:
watch_result = cast(WatchCallbackBothValuesType, watch_function)(
old_value, value
@@ -97,8 +96,6 @@ def invoke_watcher(
watcher_object.call_next(
partial(await_watcher, watcher_object, watch_result)
)
finally:
active_message_pump.reset(reset_token)
@rich.repr.auto
@@ -203,7 +200,7 @@ class Reactive(Generic[ReactiveType]):
Args:
obj: A reactive object.
"""
getattr(obj, "__watchers", {}).clear()
getattr(obj, "_watchers", {}).clear()
getattr(obj, "__computes", []).clear()
def __set_name__(self, owner: Type[MessageTarget], name: str) -> None:
@@ -351,7 +348,7 @@ class Reactive(Generic[ReactiveType]):
# Process "global" watchers
watchers: list[tuple[Reactable, WatchCallbackType]]
watchers = getattr(obj, "__watchers", {}).get(name, [])
watchers = getattr(obj, "_watchers", {}).get(name, [])
# Remove any watchers for reactables that have since closed
if watchers:
watchers[:] = [

View File

@@ -979,11 +979,8 @@ class Screen(Generic[ScreenResultType], Widget):
callbacks = self._callbacks[:]
self._callbacks.clear()
for callback, message_pump in callbacks:
reset_token = active_message_pump.set(message_pump)
try:
with message_pump._context():
await invoke(callback)
finally:
active_message_pump.reset(reset_token)
def _invoke_later(self, callback: CallbackType, sender: MessagePump) -> None:
"""Enqueue a callback to be invoked after the screen is repainted.
@@ -1014,12 +1011,14 @@ class Screen(Generic[ScreenResultType], Widget):
)
async def _message_loop_exit(self) -> None:
await super()._message_loop_exit()
self._compositor.clear()
self._dirty_widgets.clear()
self._dirty_regions.clear()
self._arrangement_cache.clear()
self.screen_layout_refresh_signal.unsubscribe(self)
self._nodes._clear()
self._task = None
def _pop_result_callback(self) -> None:
"""Remove the latest result callback from the stack."""

View File

@@ -97,7 +97,6 @@ class Timer:
self._active.set()
self._task.cancel()
self._task = None
return self._task
@classmethod
async def _stop_all(cls, timers: Iterable[Timer]) -> None:

View File

@@ -49,7 +49,7 @@ from . import constants, errors, events, messages
from ._animator import DEFAULT_EASING, Animatable, BoundAnimator, EasingFunction
from ._arrange import DockArrangeResult, arrange
from ._compose import compose
from ._context import NoActiveAppError, active_app
from ._context import NoActiveAppError
from ._debug import get_caller_file_and_line
from ._dispatch_key import dispatch_key
from ._easing import DEFAULT_SCROLL_EASING
@@ -1197,7 +1197,6 @@ class Widget(DOMNode):
await self.query_children("*").exclude(".-textual-system").remove()
if self.is_attached:
compose_nodes = compose(self)
print("COMPOSE", compose_nodes)
await self.mount_all(compose_nodes)
def _post_register(self, app: App) -> None:
@@ -1885,7 +1884,7 @@ class Widget(DOMNode):
Returns:
A Rich console object.
"""
return active_app.get().console
return self.app.console
@property
def _has_relative_children_width(self) -> bool:

View File

@@ -96,9 +96,6 @@ class FooterKey(Widget):
if tooltip:
self.tooltip = tooltip
def __repr__(self) -> str:
return f"FooterKey({self._parent!r})"
def render(self) -> Text:
key_style = self.get_component_rich_style("footer-key--key")
description_style = self.get_component_rich_style("footer-key--description")

View File

@@ -359,30 +359,30 @@ class Worker(Generic[ResultType]):
Args:
app: App instance.
"""
app._set_active()
active_worker.set(self)
with app._context():
active_worker.set(self)
self.state = WorkerState.RUNNING
app.log.worker(self)
try:
self._result = await self.run()
except asyncio.CancelledError as error:
self.state = WorkerState.CANCELLED
self._error = error
self.state = WorkerState.RUNNING
app.log.worker(self)
except Exception as error:
self.state = WorkerState.ERROR
self._error = error
app.log.worker(self, "failed", repr(error))
from rich.traceback import Traceback
try:
self._result = await self.run()
except asyncio.CancelledError as error:
self.state = WorkerState.CANCELLED
self._error = error
app.log.worker(self)
except Exception as error:
self.state = WorkerState.ERROR
self._error = error
app.log.worker(self, "failed", repr(error))
from rich.traceback import Traceback
app.log.worker(Traceback())
if self.exit_on_error:
worker_failed = WorkerFailed(self._error)
app._handle_exception(worker_failed)
else:
self.state = WorkerState.SUCCESS
app.log.worker(self)
app.log.worker(Traceback())
if self.exit_on_error:
worker_failed = WorkerFailed(self._error)
app._handle_exception(worker_failed)
else:
self.state = WorkerState.SUCCESS
app.log.worker(self)
def _start(
self, app: App, done_callback: Callable[[Worker], None] | None = None

View File

@@ -19,11 +19,11 @@ async def app():
yield self.horizontal
app = HorizontalAutoWidth()
async with app.run_test():
yield app
yield app
async def test_horizontal_get_content_width(app):
size = app.screen.size
width = app.horizontal.get_content_width(size, size)
assert width == 15
async with app.run_test():
size = app.screen.size
width = app.horizontal.get_content_width(size, size)
assert width == 15

View File

@@ -22,46 +22,47 @@ class ChildrenFocusableOnly(Widget, can_focus=False, can_focus_children=True):
@pytest.fixture
def screen() -> Screen:
app = App()
app._set_active()
app.push_screen(Screen())
screen = app.screen
with app._context():
app.push_screen(Screen())
# The classes even/odd alternate along the focus chain.
# The classes in/out identify nested widgets.
screen._add_children(
Focusable(id="foo", classes="a"),
NonFocusable(id="bar"),
Focusable(Focusable(id="Paul", classes="c"), id="container1", classes="b"),
NonFocusable(Focusable(id="Jessica", classes="a"), id="container2"),
Focusable(id="baz", classes="b"),
ChildrenFocusableOnly(Focusable(id="child", classes="c")),
)
screen = app.screen
return screen
# The classes even/odd alternate along the focus chain.
# The classes in/out identify nested widgets.
screen._add_children(
Focusable(id="foo", classes="a"),
NonFocusable(id="bar"),
Focusable(Focusable(id="Paul", classes="c"), id="container1", classes="b"),
NonFocusable(Focusable(id="Jessica", classes="a"), id="container2"),
Focusable(id="baz", classes="b"),
ChildrenFocusableOnly(Focusable(id="child", classes="c")),
)
return screen
def test_focus_chain():
app = App()
app._set_active()
app.push_screen(Screen())
with app._context():
app.push_screen(Screen())
screen = app.screen
screen = app.screen
# Check empty focus chain
assert not screen.focus_chain
# Check empty focus chain
assert not screen.focus_chain
app.screen._add_children(
Focusable(id="foo"),
NonFocusable(id="bar"),
Focusable(Focusable(id="Paul"), id="container1"),
NonFocusable(Focusable(id="Jessica"), id="container2"),
Focusable(id="baz"),
ChildrenFocusableOnly(Focusable(id="child")),
)
app.screen._add_children(
Focusable(id="foo"),
NonFocusable(id="bar"),
Focusable(Focusable(id="Paul"), id="container1"),
NonFocusable(Focusable(id="Jessica"), id="container2"),
Focusable(id="baz"),
ChildrenFocusableOnly(Focusable(id="child")),
)
focus_chain = [widget.id for widget in screen.focus_chain]
assert focus_chain == ["foo", "container1", "Paul", "baz", "child"]
focus_chain = [widget.id for widget in screen.focus_chain]
assert focus_chain == ["foo", "container1", "Paul", "baz", "child"]
def test_allow_focus():
@@ -90,18 +91,19 @@ def test_allow_focus():
return False
app = App()
app._set_active()
app.push_screen(Screen())
app.screen._add_children(
Focusable(id="foo"),
NonFocusable(id="bar"),
FocusableContainer(Button("egg", id="egg")),
NonFocusableContainer(Button("EGG", id="qux")),
)
assert [widget.id for widget in app.screen.focus_chain] == ["foo", "egg"]
assert focusable_allow_focus_called
assert non_focusable_allow_focus_called
with app._context():
app.push_screen(Screen())
app.screen._add_children(
Focusable(id="foo"),
NonFocusable(id="bar"),
FocusableContainer(Button("egg", id="egg")),
NonFocusableContainer(Button("EGG", id="qux")),
)
assert [widget.id for widget in app.screen.focus_chain] == ["foo", "egg"]
assert focusable_allow_focus_called
assert non_focusable_allow_focus_called
def test_focus_next_and_previous(screen: Screen):
@@ -188,47 +190,47 @@ def test_focus_next_and_previous_with_str_selector(screen: Screen):
def test_focus_next_and_previous_with_type_selector_without_self():
"""Test moving the focus with a selector that does not match the currently focused node."""
app = App()
app._set_active()
app.push_screen(Screen())
with app._context():
app.push_screen(Screen())
screen = app.screen
screen = app.screen
from textual.containers import Horizontal, VerticalScroll
from textual.widgets import Button, Input, Switch
from textual.containers import Horizontal, VerticalScroll
from textual.widgets import Button, Input, Switch
screen._add_children(
VerticalScroll(
Horizontal(
Input(id="w3"),
Switch(id="w4"),
Input(id="w5"),
Button(id="w6"),
Switch(id="w7"),
id="w2",
),
Horizontal(
Button(id="w9"),
Switch(id="w10"),
Button(id="w11"),
Input(id="w12"),
Input(id="w13"),
id="w8",
),
id="w1",
screen._add_children(
VerticalScroll(
Horizontal(
Input(id="w3"),
Switch(id="w4"),
Input(id="w5"),
Button(id="w6"),
Switch(id="w7"),
id="w2",
),
Horizontal(
Button(id="w9"),
Switch(id="w10"),
Button(id="w11"),
Input(id="w12"),
Input(id="w13"),
id="w8",
),
id="w1",
)
)
)
screen.set_focus(screen.query_one("#w3"))
assert screen.focused.id == "w3"
screen.set_focus(screen.query_one("#w3"))
assert screen.focused.id == "w3"
assert screen.focus_next(Button).id == "w6"
assert screen.focus_next(Switch).id == "w7"
assert screen.focus_next(Input).id == "w12"
assert screen.focus_next(Button).id == "w6"
assert screen.focus_next(Switch).id == "w7"
assert screen.focus_next(Input).id == "w12"
assert screen.focus_previous(Button).id == "w11"
assert screen.focus_previous(Switch).id == "w10"
assert screen.focus_previous(Button).id == "w9"
assert screen.focus_previous(Input).id == "w5"
assert screen.focus_previous(Button).id == "w11"
assert screen.focus_previous(Switch).id == "w10"
assert screen.focus_previous(Button).id == "w9"
assert screen.focus_previous(Input).id == "w5"
def test_focus_next_and_previous_with_str_selector_without_self(screen: Screen):

View File

@@ -30,14 +30,15 @@ class ListPathApp(App[None]):
@pytest.mark.parametrize(
"app,expected_css_path_attribute",
"app_class,expected_css_path_attribute",
[
(RelativePathObjectApp(), [APP_DIR / "test.tcss"]),
(RelativePathStrApp(), [APP_DIR / "test.tcss"]),
(AbsolutePathObjectApp(), [Path("/tmp/test.tcss")]),
(AbsolutePathStrApp(), [Path("/tmp/test.tcss")]),
(ListPathApp(), [APP_DIR / "test.tcss", Path("/another/path.tcss")]),
(RelativePathObjectApp, [APP_DIR / "test.tcss"]),
(RelativePathStrApp, [APP_DIR / "test.tcss"]),
(AbsolutePathObjectApp, [Path("/tmp/test.tcss")]),
(AbsolutePathStrApp, [Path("/tmp/test.tcss")]),
(ListPathApp, [APP_DIR / "test.tcss", Path("/another/path.tcss")]),
],
)
def test_css_paths_of_various_types(app, expected_css_path_attribute):
def test_css_paths_of_various_types(app_class, expected_css_path_attribute):
app = app_class()
assert app.css_path == [path.absolute() for path in expected_css_path_attribute]

View File

@@ -74,89 +74,88 @@ async def test_screens():
# There should be nothing in the children since the app hasn't run yet
assert not app._nodes
assert not app.children
app._set_active()
with app._context():
with pytest.raises(ScreenStackError):
app.screen
with pytest.raises(ScreenStackError):
app.screen
assert not app._installed_screens
assert not app._installed_screens
screen1 = Screen(name="screen1")
screen2 = Screen(name="screen2")
screen3 = Screen(name="screen3")
screen1 = Screen(name="screen1")
screen2 = Screen(name="screen2")
screen3 = Screen(name="screen3")
# installs screens
app.install_screen(screen1, "screen1")
app.install_screen(screen2, "screen2")
# installs screens
app.install_screen(screen1, "screen1")
app.install_screen(screen2, "screen2")
# Installing a screen does not add it to the DOM
assert not app._nodes
assert not app.children
# Installing a screen does not add it to the DOM
assert not app._nodes
assert not app.children
# Check they are installed
assert app.is_screen_installed("screen1")
assert app.is_screen_installed("screen2")
# Check they are installed
assert app.is_screen_installed("screen1")
assert app.is_screen_installed("screen2")
assert app.get_screen("screen1") is screen1
with pytest.raises(KeyError):
app.get_screen("foo")
assert app.get_screen("screen1") is screen1
with pytest.raises(KeyError):
app.get_screen("foo")
# Check screen3 is not installed
assert not app.is_screen_installed("screen3")
# Check screen3 is not installed
assert not app.is_screen_installed("screen3")
# Installs screen3
app.install_screen(screen3, "screen3")
# Confirm installed
assert app.is_screen_installed("screen3")
# Installs screen3
app.install_screen(screen3, "screen3")
# Confirm installed
assert app.is_screen_installed("screen3")
# Check screen stack is empty
assert app.screen_stack == []
# Push a screen
await app.push_screen("screen1")
# Check it is on the stack
assert app.screen_stack == [screen1]
# Check it is current
assert app.screen is screen1
# There should be one item in the children view
assert app.children == (screen1,)
# Check screen stack is empty
assert app.screen_stack == []
# Push a screen
await app.push_screen("screen1")
# Check it is on the stack
assert app.screen_stack == [screen1]
# Check it is current
assert app.screen is screen1
# There should be one item in the children view
assert app.children == (screen1,)
# Switch to another screen
await app.switch_screen("screen2")
# Check it has changed the stack and that it is current
assert app.screen_stack == [screen2]
assert app.screen is screen2
assert app.children == (screen2,)
# Switch to another screen
await app.switch_screen("screen2")
# Check it has changed the stack and that it is current
assert app.screen_stack == [screen2]
assert app.screen is screen2
assert app.children == (screen2,)
# Push another screen
await app.push_screen("screen3")
assert app.screen_stack == [screen2, screen3]
assert app.screen is screen3
# Only the current screen is in children
assert app.children == (screen3,)
# Push another screen
await app.push_screen("screen3")
assert app.screen_stack == [screen2, screen3]
assert app.screen is screen3
# Only the current screen is in children
assert app.children == (screen3,)
# Pop a screen
await app.pop_screen()
assert app.screen is screen2
assert app.screen_stack == [screen2]
# Pop a screen
await app.pop_screen()
assert app.screen is screen2
assert app.screen_stack == [screen2]
# Uninstall screens
app.uninstall_screen(screen1)
assert not app.is_screen_installed(screen1)
app.uninstall_screen("screen3")
assert not app.is_screen_installed(screen1)
# Uninstall screens
app.uninstall_screen(screen1)
assert not app.is_screen_installed(screen1)
app.uninstall_screen("screen3")
assert not app.is_screen_installed(screen1)
# Check we can't uninstall a screen on the stack
with pytest.raises(ScreenStackError):
app.uninstall_screen(screen2)
# Check we can't uninstall a screen on the stack
with pytest.raises(ScreenStackError):
app.uninstall_screen(screen2)
# Check we can't pop last screen
with pytest.raises(ScreenStackError):
app.pop_screen()
# Check we can't pop last screen
with pytest.raises(ScreenStackError):
app.pop_screen()
screen1.remove()
screen2.remove()
screen3.remove()
await app._shutdown()
screen1.remove()
screen2.remove()
screen3.remove()
await app._shutdown()
async def test_auto_focus_on_screen_if_app_auto_focus_is_none():

View File

@@ -50,4 +50,7 @@ async def test_unmount() -> None:
"MyScreen#main",
]
print(unmount_ids)
print(expected)
assert unmount_ids == expected

View File

@@ -57,22 +57,21 @@ def test_widget_content_width():
widget3 = TextWidget("foo\nbar\nbaz", id="widget3")
app = App()
app._set_active()
with app._context():
width = widget1.get_content_width(Size(20, 20), Size(80, 24))
height = widget1.get_content_height(Size(20, 20), Size(80, 24), width)
assert width == 3
assert height == 1
width = widget1.get_content_width(Size(20, 20), Size(80, 24))
height = widget1.get_content_height(Size(20, 20), Size(80, 24), width)
assert width == 3
assert height == 1
width = widget2.get_content_width(Size(20, 20), Size(80, 24))
height = widget2.get_content_height(Size(20, 20), Size(80, 24), width)
assert width == 3
assert height == 2
width = widget2.get_content_width(Size(20, 20), Size(80, 24))
height = widget2.get_content_height(Size(20, 20), Size(80, 24), width)
assert width == 3
assert height == 2
width = widget3.get_content_width(Size(20, 20), Size(80, 24))
height = widget3.get_content_height(Size(20, 20), Size(80, 24), width)
assert width == 3
assert height == 3
width = widget3.get_content_width(Size(20, 20), Size(80, 24))
height = widget3.get_content_height(Size(20, 20), Size(80, 24), width)
assert width == 3
assert height == 3
class GetByIdApp(App):
@@ -87,34 +86,38 @@ class GetByIdApp(App):
id="parent",
)
@property
def parent(self) -> Widget:
return self.query_one("#parent")
@pytest.fixture
async def hierarchy_app():
app = GetByIdApp()
async with app.run_test():
yield app
yield app
@pytest.fixture
async def parent(hierarchy_app):
yield hierarchy_app.get_widget_by_id("parent")
async def test_get_child_by_id_gets_first_child(hierarchy_app):
async with hierarchy_app.run_test():
parent = hierarchy_app.parent
child = parent.get_child_by_id(id="child1")
assert child.id == "child1"
assert child.get_child_by_id(id="grandchild1").id == "grandchild1"
assert parent.get_child_by_id(id="child2").id == "child2"
def test_get_child_by_id_gets_first_child(parent):
child = parent.get_child_by_id(id="child1")
assert child.id == "child1"
assert child.get_child_by_id(id="grandchild1").id == "grandchild1"
assert parent.get_child_by_id(id="child2").id == "child2"
async def test_get_child_by_id_no_matching_child(hierarchy_app):
async with hierarchy_app.run_test() as pilot:
parent = pilot.app.parent
with pytest.raises(NoMatches):
parent.get_child_by_id(id="doesnt-exist")
def test_get_child_by_id_no_matching_child(parent):
with pytest.raises(NoMatches):
parent.get_child_by_id(id="doesnt-exist")
def test_get_child_by_id_only_immediate_descendents(parent):
with pytest.raises(NoMatches):
parent.get_child_by_id(id="grandchild1")
async def test_get_child_by_id_only_immediate_descendents(hierarchy_app):
async with hierarchy_app.run_test() as pilot:
parent = pilot.app.parent
with pytest.raises(NoMatches):
parent.get_child_by_id(id="grandchild1")
async def test_get_child_by_type():
@@ -135,51 +138,65 @@ async def test_get_child_by_type():
app.get_child_by_type(Label)
def test_get_widget_by_id_no_matching_child(parent):
with pytest.raises(NoMatches):
parent.get_widget_by_id(id="i-dont-exist")
async def test_get_widget_by_id_no_matching_child(hierarchy_app):
async with hierarchy_app.run_test() as pilot:
parent = pilot.app.parent
with pytest.raises(NoMatches):
parent.get_widget_by_id(id="i-dont-exist")
def test_get_widget_by_id_non_immediate_descendants(parent):
result = parent.get_widget_by_id("grandchild1")
assert result.id == "grandchild1"
async def test_get_widget_by_id_non_immediate_descendants(hierarchy_app):
async with hierarchy_app.run_test() as pilot:
parent = pilot.app.parent
result = parent.get_widget_by_id("grandchild1")
assert result.id == "grandchild1"
def test_get_widget_by_id_immediate_descendants(parent):
result = parent.get_widget_by_id("child1")
assert result.id == "child1"
async def test_get_widget_by_id_immediate_descendants(hierarchy_app):
async with hierarchy_app.run_test() as pilot:
parent = pilot.app.parent
result = parent.get_widget_by_id("child1")
assert result.id == "child1"
def test_get_widget_by_id_doesnt_return_self(parent):
with pytest.raises(NoMatches):
parent.get_widget_by_id("parent")
async def test_get_widget_by_id_doesnt_return_self(hierarchy_app):
async with hierarchy_app.run_test() as pilot:
parent = pilot.app.parent
with pytest.raises(NoMatches):
parent.get_widget_by_id("parent")
def test_get_widgets_app_delegated(hierarchy_app, parent):
async def test_get_widgets_app_delegated(hierarchy_app):
# Check that get_child_by_id finds the parent, which is a child of the default Screen
queried_parent = hierarchy_app.get_child_by_id("parent")
assert queried_parent is parent
async with hierarchy_app.run_test() as pilot:
parent = pilot.app.parent
queried_parent = hierarchy_app.get_child_by_id("parent")
assert queried_parent is parent
# Check that the grandchild (descendant of the default screen) is found
grandchild = hierarchy_app.get_widget_by_id("grandchild1")
assert grandchild.id == "grandchild1"
# Check that the grandchild (descendant of the default screen) is found
grandchild = hierarchy_app.get_widget_by_id("grandchild1")
assert grandchild.id == "grandchild1"
def test_widget_mount_ids_must_be_unique_mounting_all_in_one_go(parent):
widget1 = Widget(id="hello")
widget2 = Widget(id="hello")
async def test_widget_mount_ids_must_be_unique_mounting_all_in_one_go(hierarchy_app):
async with hierarchy_app.run_test() as pilot:
parent = pilot.app.parent
widget1 = Widget(id="hello")
widget2 = Widget(id="hello")
with pytest.raises(MountError):
parent.mount(widget1, widget2)
with pytest.raises(MountError):
parent.mount(widget1, widget2)
def test_widget_mount_ids_must_be_unique_mounting_multiple_calls(parent):
widget1 = Widget(id="hello")
widget2 = Widget(id="hello")
async def test_widget_mount_ids_must_be_unique_mounting_multiple_calls(hierarchy_app):
async with hierarchy_app.run_test() as pilot:
parent = pilot.app.parent
widget1 = Widget(id="hello")
widget2 = Widget(id="hello")
parent.mount(widget1)
with pytest.raises(DuplicateIds):
parent.mount(widget2)
parent.mount(widget1)
with pytest.raises(DuplicateIds):
parent.mount(widget2)
def test_get_pseudo_class_state():

View File

@@ -116,8 +116,10 @@ async def test_mount_via_app() -> None:
await pilot.app.mount(Static(), before="Static")
def test_mount_error() -> None:
async def test_mount_error() -> None:
"""Mounting a widget on an un-mounted widget should raise an error."""
with pytest.raises(MountError):
widget = Widget()
widget.mount(Static())
app = App()
async with app.run_test():
with pytest.raises(MountError):
widget = Widget()
widget.mount(Static())

View File

@@ -6,7 +6,6 @@ import pytest
from textual.app import App, ComposeResult
from textual.events import Paste
from textual.pilot import Pilot
from textual.widgets import TextArea
from textual.widgets.text_area import EditHistory, Selection
@@ -57,300 +56,346 @@ async def text_area(pilot):
return pilot.app.text_area
async def test_simple_undo_redo(pilot, text_area: TextArea):
text_area.insert("123", (0, 0))
async def test_simple_undo_redo():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
text_area.insert("123", (0, 0))
assert text_area.text == "123"
text_area.undo()
assert text_area.text == ""
text_area.redo()
assert text_area.text == "123"
assert text_area.text == "123"
text_area.undo()
assert text_area.text == ""
text_area.redo()
assert text_area.text == "123"
async def test_undo_selection_retained(pilot: Pilot, text_area: TextArea):
async def test_undo_selection_retained():
# Select a range of text and press backspace.
text_area.text = SIMPLE_TEXT
text_area.selection = Selection((0, 0), (2, 3))
await pilot.press("backspace")
assert text_area.text == "NO\nPQRST\nUVWXY\nZ\n"
assert text_area.selection == Selection.cursor((0, 0))
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
text_area.text = SIMPLE_TEXT
text_area.selection = Selection((0, 0), (2, 3))
await pilot.press("backspace")
assert text_area.text == "NO\nPQRST\nUVWXY\nZ\n"
assert text_area.selection == Selection.cursor((0, 0))
# Undo the deletion - the text comes back, and the selection is restored.
text_area.undo()
assert text_area.selection == Selection((0, 0), (2, 3))
assert text_area.text == SIMPLE_TEXT
# Undo the deletion - the text comes back, and the selection is restored.
text_area.undo()
assert text_area.selection == Selection((0, 0), (2, 3))
assert text_area.text == SIMPLE_TEXT
# Redo the deletion - the text is gone again. The selection goes to the post-delete location.
text_area.redo()
assert text_area.text == "NO\nPQRST\nUVWXY\nZ\n"
assert text_area.selection == Selection.cursor((0, 0))
# Redo the deletion - the text is gone again. The selection goes to the post-delete location.
text_area.redo()
assert text_area.text == "NO\nPQRST\nUVWXY\nZ\n"
assert text_area.selection == Selection.cursor((0, 0))
async def test_undo_checkpoint_created_on_cursor_move(
pilot: Pilot, text_area: TextArea
):
text_area.text = SIMPLE_TEXT
# Characters are inserted on line 0 and 1.
checkpoint_one = text_area.text
checkpoint_one_selection = text_area.selection
await pilot.press("1") # Added to initial batch.
async def test_undo_checkpoint_created_on_cursor_move():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
text_area.text = SIMPLE_TEXT
# Characters are inserted on line 0 and 1.
checkpoint_one = text_area.text
checkpoint_one_selection = text_area.selection
await pilot.press("1") # Added to initial batch.
# This cursor movement ensures a new checkpoint is created.
post_insert_one_location = text_area.selection
await pilot.press("down")
# This cursor movement ensures a new checkpoint is created.
post_insert_one_location = text_area.selection
await pilot.press("down")
checkpoint_two = text_area.text
checkpoint_two_selection = text_area.selection
await pilot.press("2") # Added to new batch.
checkpoint_two = text_area.text
checkpoint_two_selection = text_area.selection
await pilot.press("2") # Added to new batch.
checkpoint_three = text_area.text
checkpoint_three_selection = text_area.selection
checkpoint_three = text_area.text
checkpoint_three_selection = text_area.selection
# Going back to checkpoint two
text_area.undo()
assert text_area.text == checkpoint_two
assert text_area.selection == checkpoint_two_selection
# Going back to checkpoint two
text_area.undo()
assert text_area.text == checkpoint_two
assert text_area.selection == checkpoint_two_selection
# Back again to checkpoint one (initial state)
text_area.undo()
assert text_area.text == checkpoint_one
assert text_area.selection == checkpoint_one_selection
# Back again to checkpoint one (initial state)
text_area.undo()
assert text_area.text == checkpoint_one
assert text_area.selection == checkpoint_one_selection
# Redo to move forward to checkpoint two.
text_area.redo()
assert text_area.text == checkpoint_two
assert text_area.selection == post_insert_one_location
# Redo to move forward to checkpoint two.
text_area.redo()
assert text_area.text == checkpoint_two
assert text_area.selection == post_insert_one_location
# Redo to move forward to checkpoint three.
text_area.redo()
assert text_area.text == checkpoint_three
assert text_area.selection == checkpoint_three_selection
# Redo to move forward to checkpoint three.
text_area.redo()
assert text_area.text == checkpoint_three
assert text_area.selection == checkpoint_three_selection
async def test_setting_text_property_resets_history(pilot: Pilot, text_area: TextArea):
await pilot.press("1")
async def test_setting_text_property_resets_history():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
await pilot.press("1")
# Programmatically setting text, which should invalidate the history
text = "Hello, world!"
text_area.text = text
# Programmatically setting text, which should invalidate the history
text = "Hello, world!"
text_area.text = text
# The undo doesn't do anything, since we set the `text` property.
text_area.undo()
assert text_area.text == text
# The undo doesn't do anything, since we set the `text` property.
text_area.undo()
assert text_area.text == text
async def test_edits_batched_by_time(pilot: Pilot, text_area: TextArea):
# The first "12" is batched since they happen within 2 seconds.
text_area.history.mock_time = 0
await pilot.press("1")
async def test_edits_batched_by_time():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
# The first "12" is batched since they happen within 2 seconds.
text_area.history.mock_time = 0
await pilot.press("1")
text_area.history.mock_time = 1.0
await pilot.press("2")
text_area.history.mock_time = 1.0
await pilot.press("2")
# Since "3" appears 10 seconds later, it's in a separate batch.
text_area.history.mock_time += 10.0
await pilot.press("3")
# Since "3" appears 10 seconds later, it's in a separate batch.
text_area.history.mock_time += 10.0
await pilot.press("3")
assert text_area.text == "123"
assert text_area.text == "123"
text_area.undo()
assert text_area.text == "12"
text_area.undo()
assert text_area.text == "12"
text_area.undo()
assert text_area.text == ""
text_area.undo()
assert text_area.text == ""
async def test_undo_checkpoint_character_limit_reached(
pilot: Pilot, text_area: TextArea
):
await pilot.press("1")
# Since the insertion below is > 100 characters it goes to a new batch.
text_area.insert("2" * 120)
async def test_undo_checkpoint_character_limit_reached():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
await pilot.press("1")
# Since the insertion below is > 100 characters it goes to a new batch.
text_area.insert("2" * 120)
text_area.undo()
assert text_area.text == "1"
text_area.undo()
assert text_area.text == ""
text_area.undo()
assert text_area.text == "1"
text_area.undo()
assert text_area.text == ""
async def test_redo_with_no_undo_is_noop(text_area: TextArea):
text_area.text = SIMPLE_TEXT
text_area.redo()
assert text_area.text == SIMPLE_TEXT
async def test_redo_with_no_undo_is_noop():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
text_area.text = SIMPLE_TEXT
text_area.redo()
assert text_area.text == SIMPLE_TEXT
async def test_undo_with_empty_undo_stack_is_noop(text_area: TextArea):
text_area.text = SIMPLE_TEXT
text_area.undo()
assert text_area.text == SIMPLE_TEXT
async def test_undo_with_empty_undo_stack_is_noop():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
text_area.text = SIMPLE_TEXT
text_area.undo()
assert text_area.text == SIMPLE_TEXT
async def test_redo_stack_cleared_on_edit(pilot: Pilot, text_area: TextArea):
text_area.text = ""
await pilot.press("1")
text_area.history.checkpoint()
await pilot.press("2")
text_area.history.checkpoint()
await pilot.press("3")
async def test_redo_stack_cleared_on_edit():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
text_area.text = ""
await pilot.press("1")
text_area.history.checkpoint()
await pilot.press("2")
text_area.history.checkpoint()
await pilot.press("3")
text_area.undo()
text_area.undo()
text_area.undo()
assert text_area.text == ""
assert text_area.selection == Selection.cursor((0, 0))
text_area.undo()
text_area.undo()
text_area.undo()
assert text_area.text == ""
assert text_area.selection == Selection.cursor((0, 0))
# Redo stack has 3 edits in it now.
await pilot.press("f")
assert text_area.text == "f"
assert text_area.selection == Selection.cursor((0, 1))
# Redo stack has 3 edits in it now.
await pilot.press("f")
assert text_area.text == "f"
assert text_area.selection == Selection.cursor((0, 1))
# Redo stack is cleared because of the edit, so redo has no effect.
text_area.redo()
assert text_area.text == "f"
assert text_area.selection == Selection.cursor((0, 1))
text_area.redo()
assert text_area.text == "f"
assert text_area.selection == Selection.cursor((0, 1))
# Redo stack is cleared because of the edit, so redo has no effect.
text_area.redo()
assert text_area.text == "f"
assert text_area.selection == Selection.cursor((0, 1))
text_area.redo()
assert text_area.text == "f"
assert text_area.selection == Selection.cursor((0, 1))
async def test_inserts_not_batched_with_deletes(pilot: Pilot, text_area: TextArea):
async def test_inserts_not_batched_with_deletes():
# 3 batches here: __1___ ___________2____________ __3__
await pilot.press(*"123", "backspace", "backspace", *"23")
assert text_area.text == "123"
app = TextAreaApp()
# Undo batch 1: the "23" insertion.
text_area.undo()
assert text_area.text == "1"
async with app.run_test() as pilot:
text_area = app.text_area
await pilot.press(*"123", "backspace", "backspace", *"23")
# Undo batch 2: the double backspace.
text_area.undo()
assert text_area.text == "123"
assert text_area.text == "123"
# Undo batch 3: the "123" insertion.
text_area.undo()
assert text_area.text == ""
# Undo batch 1: the "23" insertion.
text_area.undo()
assert text_area.text == "1"
# Undo batch 2: the double backspace.
text_area.undo()
assert text_area.text == "123"
# Undo batch 3: the "123" insertion.
text_area.undo()
assert text_area.text == ""
async def test_paste_is_an_isolated_batch(pilot: Pilot, text_area: TextArea):
pilot.app.post_message(Paste("hello "))
pilot.app.post_message(Paste("world"))
await pilot.pause()
async def test_paste_is_an_isolated_batch():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
pilot.app.post_message(Paste("hello "))
pilot.app.post_message(Paste("world"))
await pilot.pause()
assert text_area.text == "hello world"
assert text_area.text == "hello world"
await pilot.press("!")
await pilot.press("!")
# The insertion of "!" does not get batched with the paste of "world".
text_area.undo()
assert text_area.text == "hello world"
# The insertion of "!" does not get batched with the paste of "world".
text_area.undo()
assert text_area.text == "hello world"
text_area.undo()
assert text_area.text == "hello "
text_area.undo()
assert text_area.text == "hello "
text_area.undo()
assert text_area.text == ""
text_area.undo()
assert text_area.text == ""
async def test_focus_creates_checkpoint(pilot: Pilot, text_area: TextArea):
await pilot.press(*"123")
text_area.has_focus = False
text_area.has_focus = True
await pilot.press(*"456")
assert text_area.text == "123456"
async def test_focus_creates_checkpoint():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
await pilot.press(*"123")
text_area.has_focus = False
text_area.has_focus = True
await pilot.press(*"456")
assert text_area.text == "123456"
# Since we re-focused, a checkpoint exists between 123 and 456,
# so when we use undo, only the 456 is removed.
text_area.undo()
assert text_area.text == "123"
# Since we re-focused, a checkpoint exists between 123 and 456,
# so when we use undo, only the 456 is removed.
text_area.undo()
assert text_area.text == "123"
async def test_undo_redo_deletions_batched(pilot: Pilot, text_area: TextArea):
text_area.text = SIMPLE_TEXT
text_area.selection = Selection((0, 2), (1, 2))
async def test_undo_redo_deletions_batched():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
text_area.text = SIMPLE_TEXT
text_area.selection = Selection((0, 2), (1, 2))
# Perform a single delete of some selected text. It'll live in it's own
# batch since it's a multi-line operation.
await pilot.press("backspace")
checkpoint_one = "ABHIJ\nKLMNO\nPQRST\nUVWXY\nZ\n"
assert text_area.text == checkpoint_one
assert text_area.selection == Selection.cursor((0, 2))
# Perform a single delete of some selected text. It'll live in it's own
# batch since it's a multi-line operation.
await pilot.press("backspace")
checkpoint_one = "ABHIJ\nKLMNO\nPQRST\nUVWXY\nZ\n"
assert text_area.text == checkpoint_one
assert text_area.selection == Selection.cursor((0, 2))
# Pressing backspace a few times to delete more characters.
await pilot.press("backspace", "backspace", "backspace")
checkpoint_two = "HIJ\nKLMNO\nPQRST\nUVWXY\nZ\n"
assert text_area.text == checkpoint_two
assert text_area.selection == Selection.cursor((0, 0))
# Pressing backspace a few times to delete more characters.
await pilot.press("backspace", "backspace", "backspace")
checkpoint_two = "HIJ\nKLMNO\nPQRST\nUVWXY\nZ\n"
assert text_area.text == checkpoint_two
assert text_area.selection == Selection.cursor((0, 0))
# When we undo, the 3 deletions above should be batched, but not
# the original deletion since it contains a newline character.
text_area.undo()
assert text_area.text == checkpoint_one
assert text_area.selection == Selection.cursor((0, 2))
# When we undo, the 3 deletions above should be batched, but not
# the original deletion since it contains a newline character.
text_area.undo()
assert text_area.text == checkpoint_one
assert text_area.selection == Selection.cursor((0, 2))
# Undoing again restores us back to our initial text and selection.
text_area.undo()
assert text_area.text == SIMPLE_TEXT
assert text_area.selection == Selection((0, 2), (1, 2))
# Undoing again restores us back to our initial text and selection.
text_area.undo()
assert text_area.text == SIMPLE_TEXT
assert text_area.selection == Selection((0, 2), (1, 2))
# At this point, the undo stack contains two items, so we can redo twice.
# At this point, the undo stack contains two items, so we can redo twice.
# Redo to go back to checkpoint one.
text_area.redo()
assert text_area.text == checkpoint_one
assert text_area.selection == Selection.cursor((0, 2))
# Redo to go back to checkpoint one.
text_area.redo()
assert text_area.text == checkpoint_one
assert text_area.selection == Selection.cursor((0, 2))
# Redo again to go back to checkpoint two
text_area.redo()
assert text_area.text == checkpoint_two
assert text_area.selection == Selection.cursor((0, 0))
# Redo again to go back to checkpoint two
text_area.redo()
assert text_area.text == checkpoint_two
assert text_area.selection == Selection.cursor((0, 0))
# Redo again does nothing.
text_area.redo()
assert text_area.text == checkpoint_two
assert text_area.selection == Selection.cursor((0, 0))
# Redo again does nothing.
text_area.redo()
assert text_area.text == checkpoint_two
assert text_area.selection == Selection.cursor((0, 0))
async def test_max_checkpoints(pilot: Pilot, text_area: TextArea):
assert len(text_area.history.undo_stack) == 0
for index in range(MAX_CHECKPOINTS):
# Press enter since that will ensure a checkpoint is created.
async def test_max_checkpoints():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
assert len(text_area.history.undo_stack) == 0
for index in range(MAX_CHECKPOINTS):
# Press enter since that will ensure a checkpoint is created.
await pilot.press("enter")
assert len(text_area.history.undo_stack) == MAX_CHECKPOINTS
await pilot.press("enter")
assert len(text_area.history.undo_stack) == MAX_CHECKPOINTS
await pilot.press("enter")
# Ensure we don't go over the limit.
assert len(text_area.history.undo_stack) == MAX_CHECKPOINTS
# Ensure we don't go over the limit.
assert len(text_area.history.undo_stack) == MAX_CHECKPOINTS
async def test_redo_stack(pilot: Pilot, text_area: TextArea):
assert len(text_area.history.redo_stack) == 0
await pilot.press("enter")
await pilot.press(*"123")
assert len(text_area.history.undo_stack) == 2
assert len(text_area.history.redo_stack) == 0
text_area.undo()
assert len(text_area.history.undo_stack) == 1
assert len(text_area.history.redo_stack) == 1
text_area.undo()
assert len(text_area.history.undo_stack) == 0
assert len(text_area.history.redo_stack) == 2
text_area.redo()
assert len(text_area.history.undo_stack) == 1
assert len(text_area.history.redo_stack) == 1
text_area.redo()
assert len(text_area.history.undo_stack) == 2
assert len(text_area.history.redo_stack) == 0
async def test_redo_stack():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
assert len(text_area.history.redo_stack) == 0
await pilot.press("enter")
await pilot.press(*"123")
assert len(text_area.history.undo_stack) == 2
assert len(text_area.history.redo_stack) == 0
text_area.undo()
assert len(text_area.history.undo_stack) == 1
assert len(text_area.history.redo_stack) == 1
text_area.undo()
assert len(text_area.history.undo_stack) == 0
assert len(text_area.history.redo_stack) == 2
text_area.redo()
assert len(text_area.history.undo_stack) == 1
assert len(text_area.history.redo_stack) == 1
text_area.redo()
assert len(text_area.history.undo_stack) == 2
assert len(text_area.history.redo_stack) == 0
async def test_backward_selection_undo_redo(pilot: Pilot, text_area: TextArea):
# Failed prior to https://github.com/Textualize/textual/pull/4352
text_area.text = SIMPLE_TEXT
text_area.selection = Selection((3, 2), (0, 0))
async def test_backward_selection_undo_redo():
app = TextAreaApp()
async with app.run_test() as pilot:
text_area = app.text_area
# Failed prior to https://github.com/Textualize/textual/pull/4352
text_area.text = SIMPLE_TEXT
text_area.selection = Selection((3, 2), (0, 0))
await pilot.press("a")
await pilot.press("a")
text_area.undo()
await pilot.press("down", "down", "down", "down")
text_area.undo()
await pilot.press("down", "down", "down", "down")
assert text_area.text == SIMPLE_TEXT
assert text_area.text == SIMPLE_TEXT