diff --git a/src/textual/app.py b/src/textual/app.py index 05ae29922..7e02572d4 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -159,6 +159,38 @@ class ScreenStackError(ScreenError): """Raised when trying to manipulate the screen stack incorrectly.""" +class ModeError(Exception): + """Base class for exceptions related to modes.""" + + +class InvalidModeError(ModeError): + """Raised if there is an issue with a mode name.""" + + +class UnknownModeError(ModeError): + """Raised when attempting to use a mode that is not known.""" + + +class ActiveModeError(ModeError): + """Raised when attempting to remove the currently active mode.""" + + +class ModeError(Exception): + """Base class for exceptions related to modes.""" + + +class InvalidModeError(ModeError): + """Raised if there is an issue with a mode name.""" + + +class UnknownModeError(ModeError): + """Raised when attempting to use a mode that is not known.""" + + +class ActiveModeError(ModeError): + """Raised when attempting to remove the currently active mode.""" + + class CssPathError(Exception): """Raised when supplied CSS path(s) are invalid.""" @@ -212,6 +244,35 @@ class App(Generic[ReturnType], DOMNode): } """ + MODES: ClassVar[dict[str, str | Screen | Callable[[], Screen]]] = {} + """Modes associated with the app and their base screens. + + The base screen is the screen at the bottom of the mode stack. You can think of + it as the default screen for that stack. + The base screens can be names of screens listed in [SCREENS][textual.app.App.SCREENS], + [`Screen`][textual.screen.Screen] instances, or callables that return screens. + + Example: + ```py + class HelpScreen(Screen[None]): + ... + + class MainAppScreen(Screen[None]): + ... + + class MyApp(App[None]): + MODES = { + "default": "main", + "help": HelpScreen, + } + + SCREENS = { + "main": MainAppScreen, + } + + ... + ``` + """ SCREENS: ClassVar[dict[str, Screen | Callable[[], Screen]]] = {} """Screens associated with the app for the lifetime of the app.""" _BASE_PATH: str | None = None @@ -296,7 +357,10 @@ class App(Generic[ReturnType], DOMNode): self._workers = WorkerManager(self) self.error_console = Console(markup=False, stderr=True) self.driver_class = driver_class or self.get_driver_class() - self._screen_stack: list[Screen] = [] + self._screen_stacks: dict[str, list[Screen]] = {"_default": []} + """A stack of screens per mode.""" + self._current_mode: str = "_default" + """The current mode the app is in.""" self._sync_available = False self.mouse_over: Widget | None = None @@ -528,7 +592,19 @@ class App(Generic[ReturnType], DOMNode): Returns: A snapshot of the current state of the screen stack. """ - return self._screen_stack.copy() + return self._screen_stacks[self._current_mode].copy() + + @property + def _screen_stack(self) -> list[Screen]: + """A reference to the current screen stack. + + Note: + Consider using [`screen_stack`][textual.app.App.screen_stack] instead. + + Returns: + A reference to the current screen stack. + """ + return self._screen_stacks[self._current_mode] def exit( self, result: ReturnType | None = None, message: RenderableType | None = None @@ -676,6 +752,8 @@ class App(Generic[ReturnType], DOMNode): """ try: return self._screen_stack[-1] + except KeyError: + raise UnknownModeError(f"No known mode {self._current_mode!r}") from None except IndexError: raise ScreenStackError("No screens on stack") from None @@ -1321,6 +1399,88 @@ class App(Generic[ReturnType], DOMNode): """ return self.mount(*widgets, before=before, after=after) + def _init_mode(self, mode: str) -> None: + """Do internal initialisation of a new screen stack mode.""" + + stack = self._screen_stacks.get(mode, []) + if not stack: + _screen = self.MODES[mode] + if callable(_screen): + screen, _ = self._get_screen(_screen()) + else: + screen, _ = self._get_screen(self.MODES[mode]) + stack.append(screen) + self._screen_stacks[mode] = [screen] + + def switch_mode(self, mode: str) -> None: + """Switch to a given mode. + + Args: + mode: The mode to switch to. + + Raises: + UnknownModeError: If trying to switch to an unknown mode. + """ + if mode not in self.MODES: + raise UnknownModeError(f"No known mode {mode!r}") + + self.screen.post_message(events.ScreenSuspend()) + self.screen.refresh() + + if mode not in self._screen_stacks: + self._init_mode(mode) + self._current_mode = mode + self.screen._screen_resized(self.size) + self.screen.post_message(events.ScreenResume()) + self.log.system(f"{self._current_mode!r} is the current mode") + self.log.system(f"{self.screen} is active") + + def add_mode( + self, mode: str, base_screen: str | Screen | Callable[[], Screen] + ) -> None: + """Adds a mode and its corresponding base screen to the app. + + Args: + mode: The new mode. + base_screen: The base screen associated with the given mode. + + Raises: + InvalidModeError: If the name of the mode is not valid/duplicated. + """ + if mode == "_default": + raise InvalidModeError("Cannot use '_default' as a custom mode.") + elif mode in self.MODES: + raise InvalidModeError(f"Duplicated mode name {mode!r}.") + + self.MODES[mode] = base_screen + + def remove_mode(self, mode: str) -> None: + """Removes a mode from the app. + + Screens that are running in the stack of that mode are scheduled for pruning. + + Args: + mode: The mode to remove. It can't be the active mode. + + Raises: + ActiveModeError: If trying to remove the active mode. + UnknownModeError: If trying to remove an unknown mode. + """ + if mode == self._current_mode: + raise ActiveModeError(f"Can't remove active mode {mode!r}") + elif mode not in self.MODES: + raise UnknownModeError(f"Unknown mode {mode!r}") + else: + del self.MODES[mode] + + if mode not in self._screen_stacks: + return + + stack = self._screen_stacks[mode] + del self._screen_stacks[mode] + for screen in reversed(stack): + self._replace_screen(screen) + def is_screen_installed(self, screen: Screen | str) -> bool: """Check if a given screen has been installed. @@ -1397,7 +1557,9 @@ class App(Generic[ReturnType], DOMNode): self.screen.refresh() screen.post_message(events.ScreenSuspend()) self.log.system(f"{screen} SUSPENDED") - if not self.is_screen_installed(screen) and screen not in self._screen_stack: + if not self.is_screen_installed(screen) and all( + screen not in stack for stack in self._screen_stacks.values() + ): screen.remove() self.log.system(f"{screen} REMOVED") return screen @@ -1498,13 +1660,13 @@ class App(Generic[ReturnType], DOMNode): if screen not in self._installed_screens: return None uninstall_screen = self._installed_screens[screen] - if uninstall_screen in self._screen_stack: + if any(uninstall_screen in stack for stack in self._screen_stacks.values()): raise ScreenStackError("Can't uninstall screen in screen stack") del self._installed_screens[screen] self.log.system(f"{uninstall_screen} UNINSTALLED name={screen!r}") return screen else: - if screen in self._screen_stack: + if any(screen in stack for stack in self._screen_stacks.values()): raise ScreenStackError("Can't uninstall screen in screen stack") for name, installed_screen in self._installed_screens.items(): if installed_screen is screen: @@ -1949,12 +2111,12 @@ class App(Generic[ReturnType], DOMNode): async def _close_all(self) -> None: """Close all message pumps.""" - # Close all screens on the stack. - for stack_screen in reversed(self._screen_stack): - if stack_screen._running: - await self._prune_node(stack_screen) - - self._screen_stack.clear() + # Close all screens on all stacks: + for stack in self._screen_stacks.values(): + for stack_screen in reversed(stack): + if stack_screen._running: + await self._prune_node(stack_screen) + stack.clear() # Close pre-defined screens. for screen in self.SCREENS.values(): @@ -2139,7 +2301,7 @@ class App(Generic[ReturnType], DOMNode): # Handle input events that haven't been forwarded # If the event has been forwarded it may have bubbled up back to the App if isinstance(event, events.Compose): - screen = Screen(id="_default") + screen = Screen(id=f"_default") self._register(self, screen) self._screen_stack.append(screen) screen.post_message(events.ScreenResume()) diff --git a/tests/test_screen_modes.py b/tests/test_screen_modes.py new file mode 100644 index 000000000..6fd5c185d --- /dev/null +++ b/tests/test_screen_modes.py @@ -0,0 +1,277 @@ +from functools import partial +from itertools import cycle +from typing import Type + +import pytest + +from textual.app import ( + ActiveModeError, + App, + ComposeResult, + InvalidModeError, + UnknownModeError, +) +from textual.screen import ModalScreen, Screen +from textual.widgets import Footer, Header, Label, TextLog + +FRUITS = cycle("apple mango strawberry banana peach pear melon watermelon".split()) + + +class ScreenBindingsMixin(Screen[None]): + BINDINGS = [ + ("1", "one", "Mode 1"), + ("2", "two", "Mode 2"), + ("p", "push", "Push rnd scrn"), + ("o", "pop_screen", "Pop"), + ("r", "remove", "Remove mode 1"), + ] + + def action_one(self) -> None: + self.app.switch_mode("one") + + def action_two(self) -> None: + self.app.switch_mode("two") + + def action_fruits(self) -> None: + self.app.switch_mode("fruits") + + def action_push(self) -> None: + self.app.push_screen(FruitModal()) + + +class BaseScreen(ScreenBindingsMixin): + def __init__(self, label): + super().__init__() + self.label = label + + def compose(self) -> ComposeResult: + yield Header() + yield Label(self.label) + yield Footer() + + def action_remove(self) -> None: + self.app.remove_mode("one") + + +class FruitModal(ModalScreen[str], ScreenBindingsMixin): + BINDINGS = [("d", "dismiss_fruit", "Dismiss")] + + def compose(self) -> ComposeResult: + yield Label(next(FRUITS)) + + +class FruitsScreen(ScreenBindingsMixin): + def compose(self) -> ComposeResult: + yield TextLog() + + +@pytest.fixture +def ModesApp(): + class ModesApp(App[None]): + MODES = { + "one": lambda: BaseScreen("one"), + "two": "screen_two", + } + + SCREENS = { + "screen_two": lambda: BaseScreen("two"), + } + + def on_mount(self): + self.switch_mode("one") + + return ModesApp + + +async def test_mode_setup(ModesApp: Type[App]): + app = ModesApp() + async with app.run_test(): + assert isinstance(app.screen, BaseScreen) + assert str(app.screen.query_one(Label).renderable) == "one" + + +async def test_switch_mode(ModesApp: Type[App]): + app = ModesApp() + async with app.run_test() as pilot: + await pilot.press("2") + assert str(app.screen.query_one(Label).renderable) == "two" + await pilot.press("1") + assert str(app.screen.query_one(Label).renderable) == "one" + + +async def test_switch_same_mode(ModesApp: Type[App]): + app = ModesApp() + async with app.run_test() as pilot: + await pilot.press("1") + assert str(app.screen.query_one(Label).renderable) == "one" + await pilot.press("1") + assert str(app.screen.query_one(Label).renderable) == "one" + + +async def test_switch_unknown_mode(ModesApp: Type[App]): + app = ModesApp() + async with app.run_test(): + with pytest.raises(UnknownModeError): + app.switch_mode("unknown mode here") + + +async def test_remove_mode(ModesApp: Type[App]): + app = ModesApp() + async with app.run_test() as pilot: + app.switch_mode("two") + await pilot.pause() + assert str(app.screen.query_one(Label).renderable) == "two" + app.remove_mode("one") + assert "one" not in app.MODES + + +async def test_remove_active_mode(ModesApp: Type[App]): + app = ModesApp() + async with app.run_test(): + with pytest.raises(ActiveModeError): + app.remove_mode("one") + + +async def test_add_mode(ModesApp: Type[App]): + app = ModesApp() + async with app.run_test() as pilot: + app.add_mode("three", BaseScreen("three")) + app.switch_mode("three") + await pilot.pause() + assert str(app.screen.query_one(Label).renderable) == "three" + + +async def test_add_mode_duplicated(ModesApp: Type[App]): + app = ModesApp() + async with app.run_test(): + with pytest.raises(InvalidModeError): + app.add_mode("one", BaseScreen("one")) + + +async def test_screen_stack_preserved(ModesApp: Type[App]): + fruits = [] + N = 5 + + app = ModesApp() + async with app.run_test() as pilot: + # Build the stack up. + for _ in range(N): + await pilot.press("p") + fruits.append(str(app.query_one(Label).renderable)) + + assert len(app.screen_stack) == N + 1 + + # Switch out and back. + await pilot.press("2") + assert len(app.screen_stack) == 1 + await pilot.press("1") + + # Check the stack. + assert len(app.screen_stack) == N + 1 + for _ in range(N): + assert str(app.query_one(Label).renderable) == fruits.pop() + await pilot.press("o") + + +async def test_inactive_stack_is_alive(): + """This tests that timers in screens outside the active stack keep going.""" + pings = [] + + class FastCounter(Screen[None]): + def compose(self) -> ComposeResult: + yield Label("fast") + + def on_mount(self) -> None: + self.set_interval(0.01, self.ping) + + def ping(self) -> None: + pings.append(str(self.app.query_one(Label).renderable)) + + def key_s(self): + self.app.switch_mode("smile") + + class SmileScreen(Screen[None]): + def compose(self) -> ComposeResult: + yield Label(":)") + + def key_s(self): + self.app.switch_mode("fast") + + class ModesApp(App[None]): + MODES = { + "fast": FastCounter, + "smile": SmileScreen, + } + + def on_mount(self) -> None: + self.switch_mode("fast") + + app = ModesApp() + async with app.run_test() as pilot: + await pilot.press("s") + assert str(app.query_one(Label).renderable) == ":)" + await pilot.press("s") + assert ":)" in pings + + +async def test_multiple_mode_callbacks(): + written = [] + + class LogScreen(Screen[None]): + def __init__(self, value): + super().__init__() + self.value = value + + def key_p(self) -> None: + self.app.push_screen(ResultScreen(self.value), written.append) + + class ResultScreen(Screen[str]): + def __init__(self, value): + super().__init__() + self.value = value + + def key_p(self) -> None: + self.dismiss(self.value) + + def key_f(self) -> None: + self.app.switch_mode("first") + + def key_o(self) -> None: + self.app.switch_mode("other") + + class ModesApp(App[None]): + MODES = { + "first": lambda: LogScreen("first"), + "other": lambda: LogScreen("other"), + } + + def on_mount(self) -> None: + self.switch_mode("first") + + def key_f(self) -> None: + self.switch_mode("first") + + def key_o(self) -> None: + self.switch_mode("other") + + app = ModesApp() + async with app.run_test() as pilot: + # Push and dismiss ResultScreen("first") + await pilot.press("p") + await pilot.press("p") + assert written == ["first"] + + # Push ResultScreen("first") + await pilot.press("p") + # Switch to LogScreen("other") + await pilot.press("o") + # Push and dismiss ResultScreen("other") + await pilot.press("p") + await pilot.press("p") + assert written == ["first", "other"] + + # Go back to ResultScreen("first") + await pilot.press("f") + # Dismiss ResultScreen("first") + await pilot.press("p") + assert written == ["first", "other", "first"]