diff --git a/docs/examples/timers/clock.py b/docs/examples/timers/clock.py index b13287748..53f412da8 100644 --- a/docs/examples/timers/clock.py +++ b/docs/examples/timers/clock.py @@ -1,6 +1,7 @@ from datetime import datetime from rich.align import Align +from rich.style import Style from textual.app import App from textual.widget import Widget @@ -10,7 +11,7 @@ class Clock(Widget): def on_mount(self): self.set_interval(1, self.refresh) - def render(self): + def render(self, style: Style): time = datetime.now().strftime("%c") return Align.center(time, vertical="middle") diff --git a/docs/examples/widgets/custom.py b/docs/examples/widgets/custom.py index 8fda42589..f35e82f41 100644 --- a/docs/examples/widgets/custom.py +++ b/docs/examples/widgets/custom.py @@ -1,4 +1,5 @@ from rich.panel import Panel +from rich.style import Style from textual.app import App from textual.reactive import Reactive @@ -9,7 +10,7 @@ class Hover(Widget): mouse_over = Reactive(False) - def render(self) -> Panel: + def render(self, style: Style) -> Panel: return Panel("Hello [b]World[/b]", style=("on red" if self.mouse_over else "")) def on_enter(self) -> None: diff --git a/e2e_tests/test_apps/basic.py b/e2e_tests/test_apps/basic.py index 7a833d843..70b1b6fa0 100644 --- a/e2e_tests/test_apps/basic.py +++ b/e2e_tests/test_apps/basic.py @@ -1,7 +1,7 @@ from pathlib import Path -from rich.align import Align from rich.console import RenderableType +from rich.style import Style from rich.syntax import Syntax from rich.text import Text @@ -53,12 +53,12 @@ lorem = Text.from_markup( class TweetHeader(Widget): - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: return Text("Lorem Impsum", justify="center") class TweetBody(Widget): - def render(self) -> Text: + def render(self, style: Style) -> Text: return lorem @@ -67,22 +67,22 @@ class Tweet(Widget): class OptionItem(Widget): - def render(self) -> Text: + def render(self, style: Style) -> Text: return Text("Option") class Error(Widget): - def render(self) -> Text: + def render(self, style: Style) -> Text: return Text("This is an error message", justify="center") class Warning(Widget): - def render(self) -> Text: + def render(self, style: Style) -> Text: return Text("This is a warning message", justify="center") class Success(Widget): - def render(self) -> Text: + def render(self, style: Style) -> Text: return Text("This is a success message", justify="center") diff --git a/examples/borders.py b/examples/borders.py index cedfc76aa..62e4ec7e5 100644 --- a/examples/borders.py +++ b/examples/borders.py @@ -1,11 +1,9 @@ -from rich.console import Group from rich.padding import Padding +from rich.style import Style from rich.text import Text from textual.app import App from textual.renderables.gradient import VerticalGradient -from textual import events -from textual.widgets import Placeholder from textual.widget import Widget lorem = Text.from_markup( @@ -15,12 +13,12 @@ lorem = Text.from_markup( class Lorem(Widget): - def render(self) -> Text: + def render(self, style: Style) -> Text: return Padding(lorem, 1) class Background(Widget): - def render(self): + def render(self, style: Style): return VerticalGradient("#212121", "#212121") diff --git a/examples/calculator.py b/examples/calculator.py index bbc28badf..8cb5dd3f8 100644 --- a/examples/calculator.py +++ b/examples/calculator.py @@ -9,13 +9,14 @@ from decimal import Decimal from rich.align import Align from rich.console import Console, ConsoleOptions, RenderResult, RenderableType from rich.padding import Padding +from rich.style import Style from rich.text import Text from textual.app import App from textual.reactive import Reactive from textual.views import GridView from textual.widget import Widget -from textual.widgets import Button, ButtonPressed +from textual.widgets import Button try: from pyfiglet import Figlet @@ -55,7 +56,7 @@ class Numbers(Widget): value = Reactive("0") - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: """Build a Rich renderable to render the calculator display.""" return Padding( Align.right(FigletText(self.value), vertical="middle"), diff --git a/pyproject.toml b/pyproject.toml index c0737d65f..a30ec2f1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,9 @@ includes = "src" [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] +markers = [ + "integration_test: marks tests as slow integration tests(deselect with '-m \"not integration_test\"')", +] [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/sandbox/align.py b/sandbox/align.py index f3f18c93d..76c3b588c 100644 --- a/sandbox/align.py +++ b/sandbox/align.py @@ -1,4 +1,4 @@ -from rich.text import Text +from rich.style import Style from textual.app import App, ComposeResult from textual.widget import Widget @@ -6,7 +6,7 @@ from textual.widgets import Static class Thing(Widget): - def render(self): + def render(self, style: Style): return "Hello, 3434 World.\n[b]Lorem impsum." diff --git a/sandbox/basic.css b/sandbox/basic.css index a2ced4868..3934f521c 100644 --- a/sandbox/basic.css +++ b/sandbox/basic.css @@ -167,7 +167,6 @@ TweetBody { OptionItem { height: 3; background: $primary; - transition: background 100ms linear; border-right: outer $primary-darken-2; border-left: hidden; content-align: center middle; @@ -224,4 +223,4 @@ Success { .horizontal { layout: horizontal -} \ No newline at end of file +} diff --git a/sandbox/basic.py b/sandbox/basic.py index c56703c57..7e9f78c28 100644 --- a/sandbox/basic.py +++ b/sandbox/basic.py @@ -1,4 +1,5 @@ from rich.console import RenderableType +from rich.style import Style from rich.syntax import Syntax from rich.text import Text @@ -50,12 +51,12 @@ lorem = Text.from_markup( class TweetHeader(Widget): - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: return Text("Lorem Impsum", justify="center") class TweetBody(Widget): - def render(self) -> Text: + def render(self, style: Style) -> Text: return lorem @@ -64,22 +65,22 @@ class Tweet(Widget): class OptionItem(Widget): - def render(self) -> Text: + def render(self, style: Style) -> Text: return Text("Option") class Error(Widget): - def render(self) -> Text: + def render(self, style: Style) -> Text: return Text("This is an error message", justify="center") class Warning(Widget): - def render(self) -> Text: + def render(self, style: Style) -> Text: return Text("This is a warning message", justify="center") class Success(Widget): - def render(self) -> Text: + def render(self, style: Style) -> Text: return Text("This is a success message", justify="center") diff --git a/sandbox/buttons.css b/sandbox/buttons.css index e69de29bb..df19f2b8d 100644 --- a/sandbox/buttons.css +++ b/sandbox/buttons.css @@ -0,0 +1,8 @@ +#foo { + text-style: underline; + background: rebeccapurple; +} + +#foo:hover { + background: greenyellow; +} diff --git a/sandbox/dev_sandbox.py b/sandbox/dev_sandbox.py index 0ead807d4..8d3794dba 100644 --- a/sandbox/dev_sandbox.py +++ b/sandbox/dev_sandbox.py @@ -1,12 +1,13 @@ from rich.console import RenderableType from rich.panel import Panel +from rich.style import Style from textual.app import App from textual.widget import Widget class PanelWidget(Widget): - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: return Panel("hello world!", title="Title") diff --git a/sandbox/scroll_to_widget.py b/sandbox/scroll_to_widget.py new file mode 100644 index 000000000..81b0bf83c --- /dev/null +++ b/sandbox/scroll_to_widget.py @@ -0,0 +1,72 @@ +from rich.console import RenderableType +from rich.text import Text + +from textual.app import App, ComposeResult +from textual.widget import Widget +from textual.widgets import Placeholder + +placeholders_count = 12 + + +class VerticalContainer(Widget): + CSS = """ + VerticalContainer { + layout: vertical; + overflow: hidden auto; + background: darkblue; + } + + VerticalContainer Placeholder { + margin: 1 0; + height: 5; + border: solid lime; + align: center top; + } + """ + + +class Introduction(Widget): + CSS = """ + Introduction { + background: indigo; + color: white; + height: 3; + padding: 1 0; + } + """ + + def render(self) -> RenderableType: + return Text( + "Press keys 0 to 9 to scroll to the Placeholder with that ID.", + justify="center", + ) + + +class MyTestApp(App): + def compose(self) -> ComposeResult: + placeholders = [ + Placeholder(id=f"placeholder_{i}", name=f"Placeholder #{i}") + for i in range(placeholders_count) + ] + + yield VerticalContainer(Introduction(), *placeholders, id="root") + + def on_mount(self): + self.bind("q", "quit") + self.bind("t", "tree") + for widget_index in range(placeholders_count): + self.bind(str(widget_index), f"scroll_to('placeholder_{widget_index}')") + + def action_tree(self): + self.log(self.tree) + + async def action_scroll_to(self, target_placeholder_id: str): + target_placeholder = self.query(f"#{target_placeholder_id}").first() + target_placeholder_container = self.query("#root").first() + target_placeholder_container.scroll_to_widget(target_placeholder, animate=True) + + +app = MyTestApp() + +if __name__ == "__main__": + app.run() diff --git a/sandbox/tabs.py b/sandbox/tabs.py index 6bca6162e..efcfcd7c8 100644 --- a/sandbox/tabs.py +++ b/sandbox/tabs.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from rich.console import RenderableType from rich.padding import Padding from rich.rule import Rule +from rich.style import Style from textual import events from textual.app import App @@ -11,7 +12,7 @@ from textual.widgets.tabs import Tabs, Tab class Hr(Widget): - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: return Rule() @@ -22,7 +23,7 @@ class Info(Widget): super().__init__() self.text = text - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: return Padding(f"{self.text}", pad=(0, 1)) @@ -144,4 +145,5 @@ class BasicApp(App): self.mount(example.widget) -BasicApp.run(css_path="tabs.scss", watch_css=True, log_path="textual.log") +app = BasicApp(css_path="tabs.scss", watch_css=True, log_path="textual.log") +app.run() diff --git a/sandbox/uber.css b/sandbox/uber.css index 14848e9c2..daa40a36d 100644 --- a/sandbox/uber.css +++ b/sandbox/uber.css @@ -7,14 +7,20 @@ App.-show-focus *:focus { background: green; overflow: hidden auto; border: heavy white; + text-style: underline; } #uber1:focus-within { background: darkslateblue; } +#child2 { + text-style: underline; + background: red; +} + .list-item { - height: 10; + height: 20; color: #12a0; background: #ffffff00; } diff --git a/sandbox/uber.py b/sandbox/uber.py index 39b4ede8b..96f70fd0d 100644 --- a/sandbox/uber.py +++ b/sandbox/uber.py @@ -25,7 +25,6 @@ class BasicApp(App): first_child = Placeholder(id="child1", classes="list-item") uber1 = Widget( first_child, - Placeholder(id="child1", classes="list-item"), Placeholder(id="child2", classes="list-item"), Placeholder(id="child3", classes="list-item"), Placeholder(classes="list-item"), @@ -33,6 +32,7 @@ class BasicApp(App): Placeholder(classes="list-item"), ) self.mount(uber1=uber1) + uber1.focus() self.first_child = first_child self.uber = uber1 @@ -50,9 +50,8 @@ class BasicApp(App): def action_print(self): print( - "Printed using builtin [b blue]print[/] function:", - self.screen.tree, - sep=" - ", + "Focused widget is:", + self.focused, ) self.app.set_focus(None) diff --git a/sandbox/vertical_container.py b/sandbox/vertical_container.py index 1f4450f92..d0f470797 100644 --- a/sandbox/vertical_container.py +++ b/sandbox/vertical_container.py @@ -20,7 +20,7 @@ class VerticalContainer(Widget): VerticalContainer Placeholder { margin: 1 0; - height: 3; + height: 5; border: solid lime; align: center top; } @@ -79,10 +79,10 @@ class MyTestApp(App): placeholders = self.query("Placeholder") placeholders_count = len(placeholders) placeholder = Placeholder( - id=f"placeholder_{placeholders_count+1}", - name=f"Placeholder #{placeholders_count+1}", + id=f"placeholder_{placeholders_count}", + name=f"Placeholder #{placeholders_count}", ) - root = self.query_one("#root") + root = self.get_child("root") root.mount(placeholder) self.refresh(repaint=True, layout=True) self.refresh_css() diff --git a/src/textual/_compositor.py b/src/textual/_compositor.py index 1b1b0ed0b..80d1d14ed 100644 --- a/src/textual/_compositor.py +++ b/src/textual/_compositor.py @@ -15,12 +15,12 @@ from __future__ import annotations from operator import attrgetter, itemgetter import sys -from typing import cast, Iterator, Iterable, NamedTuple, TYPE_CHECKING +from typing import Callable, cast, Iterator, Iterable, NamedTuple, TYPE_CHECKING import rich.repr -from rich.console import Console, ConsoleOptions, RenderResult +from rich.console import Console, ConsoleOptions, RenderResult, RenderableType from rich.control import Control -from rich.segment import Segment, SegmentLines +from rich.segment import Segment from rich.style import Style from . import errors @@ -50,17 +50,17 @@ class ReflowResult(NamedTuple): resized: set[Widget] # Widgets that have been resized -class RenderRegion(NamedTuple): +class MapGeometry(NamedTuple): """Defines the absolute location of a Widget.""" region: Region # The region occupied by the widget order: tuple[int, ...] # A tuple of ints defining the painting order clip: Region # A region to clip the widget by (if a Widget is within a container) virtual_size: Size # The virtual size (scrollable region) of a widget if it is a container - container_size: Size # The container size (area no occupied by scrollbars) + container_size: Size # The container size (area not occupied by scrollbars) -RenderRegionMap: TypeAlias = "dict[Widget, RenderRegion]" +CompositorMap: TypeAlias = "dict[Widget, MapGeometry]" @rich.repr.auto @@ -78,6 +78,7 @@ class LayoutUpdate: new_line = Segment.line() move_to = Control.move_to for last, (y, line) in loop_last(enumerate(self.lines, self.region.y)): + yield Control.home() yield move_to(x, y) yield from line if not last: @@ -91,13 +92,40 @@ class LayoutUpdate: yield "height", height +@rich.repr.auto +class SpansUpdate: + """A renderable that applies updated spans to the screen.""" + + def __init__(self, spans: list[tuple[int, int, list[Segment]]]) -> None: + """Apply spans, which consist of a tuple of (LINE, OFFSET, SEGMENTS) + + Args: + spans (list[tuple[int, int, list[Segment]]]): A list of spans. + """ + self.spans = spans + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + move_to = Control.move_to + new_line = Segment.line() + for last, (y, x, segments) in loop_last(self.spans): + yield move_to(x, y) + yield from segments + if not last: + yield new_line + + def __rich_repr__(self) -> rich.repr.Result: + yield [(y, x, "...") for y, x, _segments in self.spans] + + @rich.repr.auto(angular=True) class Compositor: """Responsible for storing information regarding the relative positions of Widgets and rendering them.""" def __init__(self) -> None: # A mapping of Widget on to its "render location" (absolute position / depth) - self.map: RenderRegionMap = {} + self.map: CompositorMap = {} # All widgets considered in the arrangement # Note this may be a superset of self.map.keys() as some widgets may be invisible for various reasons @@ -116,6 +144,42 @@ class Compositor: # The points in each line where the line bisects the left and right edges of the widget self._cuts: list[list[int]] | None = None + @classmethod + def _regions_to_spans( + cls, regions: Iterable[Region] + ) -> Iterable[tuple[int, int, int]]: + """Converts the regions to horizontal spans. Spans will be combined if they overlap + or are contiguous to produce optimal non-overlapping spans. + + Args: + regions (Iterable[Region]): An iterable of Regions. + + Returns: + Iterable[tuple[int, int, int]]: Yields tuples of (Y, X1, X2) + """ + inline_ranges: dict[int, list[tuple[int, int]]] = {} + for region_x, region_y, width, height in regions: + span = (region_x, region_x + width) + for y in range(region_y, region_y + height): + inline_ranges.setdefault(y, []).append(span) + + for y, ranges in sorted(inline_ranges.items()): + if len(ranges) == 1: + # Special case of 1 span + yield (y, *ranges[0]) + else: + ranges.sort() + x1, x2 = ranges[0] + for next_x1, next_x2 in ranges[1:]: + if next_x1 <= x2: + if next_x2 > x2: + x2 = next_x2 + else: + yield (y, x1, x2) + x1 = next_x1 + x2 = next_x2 + yield (y, x1, x2) + def __rich_repr__(self) -> rich.repr.Result: yield "size", self.size yield "widgets", self.widgets @@ -167,7 +231,7 @@ class Compositor: resized=resized_widgets, ) - def _arrange_root(self, root: Widget) -> tuple[RenderRegionMap, set[Widget]]: + def _arrange_root(self, root: Widget) -> tuple[CompositorMap, set[Widget]]: """Arrange a widgets children based on its layout attribute. Args: @@ -180,7 +244,7 @@ class Compositor: ORIGIN = Offset(0, 0) size = root.size - map: RenderRegionMap = {} + map: CompositorMap = {} widgets: set[Widget] = set() get_order = attrgetter("order") @@ -249,7 +313,7 @@ class Compositor: for chrome_widget, chrome_region in widget._arrange_scrollbars( container_size ): - map[chrome_widget] = RenderRegion( + map[chrome_widget] = MapGeometry( chrome_region + container_region.origin + layout_offset, order, clip, @@ -258,7 +322,7 @@ class Compositor: ) # Add the container widget, which will render a background - map[widget] = RenderRegion( + map[widget] = MapGeometry( region + layout_offset, order, clip, @@ -268,7 +332,7 @@ class Compositor: else: # Add the widget to the map - map[widget] = RenderRegion( + map[widget] = MapGeometry( region + layout_offset, order, clip, region.size, container_size ) @@ -338,8 +402,8 @@ class Compositor: return segment.style or Style.null() return Style.null() - def get_widget_region(self, widget: Widget) -> Region: - """Get the Region of a Widget contained in this Layout. + def find_widget(self, widget: Widget) -> MapGeometry: + """Get information regarding the relative position of a widget in the Compositor. Args: widget (Widget): The Widget in this layout you wish to know the Region of. @@ -348,11 +412,11 @@ class Compositor: NoWidget: If the Widget is not contained in this Layout. Returns: - Region: The Region of the Widget. + MapGeometry: Widget's composition information. """ try: - region, *_ = self.map[widget] + region = self.map[widget] except KeyError: raise errors.NoWidget("Widget is not in layout") else: @@ -452,11 +516,7 @@ class Compositor: ] return segment_lines - def render( - self, - *, - crop: Region | None = None, - ) -> SegmentLines: + def render(self, regions: list[Region] | None = None) -> RenderableType: """Render a layout. Args: @@ -467,8 +527,15 @@ class Compositor: """ width, height = self.size screen_region = Region(0, 0, width, height) - - crop_region = crop.intersection(screen_region) if crop else screen_region + if regions: + # Create a crop regions that surrounds all updates + crop = Region.from_union(regions).intersection(screen_region) + spans = list(self._regions_to_spans(regions)) + is_rendered_line = {y for y, _, _ in spans}.__contains__ + else: + crop = screen_region + spans = [] + is_rendered_line = lambda y: True _Segment = Segment divide = _Segment.divide @@ -480,9 +547,8 @@ class Compositor: "Callable[[list[int]], dict[int, list[Segment] | None]]", dict.fromkeys ) # A mapping of cut index to a list of segments for each line - chops: list[dict[int, list[Segment] | None]] = [ - fromkeys(cut_set) for cut_set in cuts - ] + chops: list[dict[int, list[Segment] | None]] + chops = [fromkeys(cut_set) for cut_set in cuts] # Go through all the renders in reverse order and fill buckets with no render renders = self._get_renders(crop) @@ -492,6 +558,8 @@ class Compositor: render_region = intersection(region, clip) for y, line in zip(render_region.y_range, lines): + if not is_rendered_line(y): + continue first_cut, last_cut = render_region.x_extents final_cuts = [cut for cut in cuts[y] if (last_cut >= cut >= first_cut)] @@ -501,6 +569,7 @@ class Compositor: else: render_x = render_region.x relative_cuts = [cut - render_x for cut in final_cuts] + # print(relative_cuts) _, *cut_segments = divide(line, relative_cuts) # Since we are painting front to back, the first segments for a cut "wins" @@ -509,24 +578,25 @@ class Compositor: if chops_line[cut] is None: chops_line[cut] = segments - # Assemble the cut renders in to lists of segments - crop_x, crop_y, crop_x2, crop_y2 = crop_region.corners - render_lines = self._assemble_chops(chops[crop_y:crop_y2]) - - if crop is not None and (crop_x, crop_x2) != (0, width): - render_lines = [ - line_crop(line, crop_x, crop_x2) if line else line - for line in render_lines + if regions: + crop_y, crop_y2 = crop.y_extents + render_lines = self._assemble_chops(chops[crop_y:crop_y2]) + render_spans = [ + (y, x1, line_crop(render_lines[y - crop_y], x1, x2)) + for y, x1, x2 in spans ] + return SpansUpdate(render_spans) - return SegmentLines(render_lines, new_lines=True) + else: + render_lines = self._assemble_chops(chops) + return LayoutUpdate(render_lines, screen_region) def __rich_console__( self, console: Console, options: ConsoleOptions ) -> RenderResult: yield self.render() - def update_widget(self, console: Console, widget: Widget) -> LayoutUpdate | None: + def update_widgets(self, widgets: set[Widget]) -> RenderableType | None: """Update a given widget in the composition. Args: @@ -536,14 +606,12 @@ class Compositor: Returns: LayoutUpdate | None: A renderable or None if nothing to render. """ - if widget not in self.regions: - return None - region, clip = self.regions[widget] - if not region: - return None - update_region = region.intersection(clip) - if not update_region: - return None - update_lines = self.render(crop=update_region).lines - update = LayoutUpdate(update_lines, update_region) + regions: list[Region] = [] + add_region = regions.append + for widget in self.regions.keys() & widgets: + region, clip = self.regions[widget] + update_region = region.intersection(clip) + if update_region: + add_region(update_region) + update = self.render(regions or None) return update diff --git a/src/textual/_layout.py b/src/textual/_layout.py index ea459a3e0..6ba50370a 100644 --- a/src/textual/_layout.py +++ b/src/textual/_layout.py @@ -24,6 +24,9 @@ class Layout(ABC): name: ClassVar[str] = "" + def __repr__(self) -> str: + return f"<{self.name}>" + @abstractmethod def arrange( self, parent: Widget, size: Size, scroll: Offset diff --git a/src/textual/_region_group.py b/src/textual/_region_group.py deleted file mode 100644 index 096007987..000000000 --- a/src/textual/_region_group.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict -from operator import attrgetter -from typing import NamedTuple, Iterable - -from .geometry import Region - - -class InlineRange(NamedTuple): - """Represents a region on a single line.""" - - line_index: int - start: int - end: int - - -def regions_to_ranges(regions: Iterable[Region]) -> Iterable[InlineRange]: - """Converts the regions to non-overlapping horizontal strips, where each strip - represents the region on a single line. Combining the resulting strips therefore - results in a shape identical to the combined original regions. - - Args: - regions (Iterable[Region]): An iterable of Regions. - - Returns: - Iterable[InlineRange]: Yields InlineRange objects representing the content on - a single line, with overlaps removed. - """ - inline_ranges: dict[int, list[InlineRange]] = defaultdict(list) - for region_x, region_y, width, height in regions: - for y in range(region_y, region_y + height): - inline_ranges[y].append( - InlineRange(line_index=y, start=region_x, end=region_x + width - 1) - ) - - get_start = attrgetter("start") - for line_index, ranges in inline_ranges.items(): - sorted_ranges = iter(sorted(ranges, key=get_start)) - _, start, end = next(sorted_ranges) - for next_line_index, next_start, next_end in sorted_ranges: - if next_start <= end + 1: - end = max(end, next_end) - else: - yield InlineRange(line_index, start, end) - start = next_start - end = next_end - yield InlineRange(line_index, start, end) diff --git a/src/textual/app.py b/src/textual/app.py index 63c6cfa7c..05c743418 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -33,6 +33,7 @@ from rich.measure import Measurement from rich.protocol import is_renderable from rich.screen import Screen as ScreenRenderable from rich.segment import Segments +from rich.style import Style from rich.traceback import Traceback from . import actions @@ -143,6 +144,9 @@ class App(Generic[ReturnType], DOMNode): self.driver_class = driver_class or self.get_driver_class() self._title = title self._screen_stack: list[Screen] = [] + self._sync_available = ( + os.environ.get("TERM_PROGRAM", "") != "Apple_Terminal" and not WINDOWS + ) self.focused: Widget | None = None self.mouse_over: Widget | None = None @@ -478,7 +482,7 @@ class App(Generic[ReturnType], DOMNode): self.stylesheet.update(self) self.screen.refresh(layout=True) - def render(self) -> RenderableType: + def render(self, styles: Style) -> RenderableType: return "" def query(self, selector: str | None = None) -> DOMQuery: @@ -639,6 +643,7 @@ class App(Generic[ReturnType], DOMNode): def fatal_error(self) -> None: """Exits the app after an unhandled exception.""" + self.console.bell() traceback = Traceback( show_locals=True, width=None, locals_max_length=5, suppress=[rich] ) @@ -806,20 +811,19 @@ class App(Generic[ReturnType], DOMNode): def refresh(self, *, repaint: bool = True, layout: bool = False) -> None: if not self._running: return - sync_available = ( - os.environ.get("TERM_PROGRAM", "") != "Apple_Terminal" and not WINDOWS - ) if not self._closed: console = self.console try: - if sync_available: + if self._sync_available: console.file.write("\x1bP=1s\x1b\\") console.print( ScreenRenderable( - Control.home(), self.screen._compositor, Control.home() + Control.home(), + self.screen._compositor, + Control.home(), ) ) - if sync_available: + if self._sync_available: console.file.write("\x1bP=2s\x1b\\") console.file.flush() except Exception as error: @@ -942,14 +946,14 @@ class App(Generic[ReturnType], DOMNode): action_target = default_namespace or self action_name = target - log("action", action) + log("", action) await self.dispatch_action(action_target, action_name, params) async def dispatch_action( self, namespace: object, action_name: str, params: Any ) -> None: log( - "dispatch_action", + "", namespace=namespace, action_name=action_name, params=params, diff --git a/src/textual/color.py b/src/textual/color.py index 0a2521db3..fd7fd0bd8 100644 --- a/src/textual/color.py +++ b/src/textual/color.py @@ -23,7 +23,7 @@ from rich.color import Color as RichColor from rich.style import Style from rich.text import Text - +from textual.suggestions import get_suggestion from ._color_constants import COLOR_NAME_TO_RGB from .geometry import clamp @@ -77,6 +77,17 @@ split_pairs4: Callable[[str], tuple[str, str, str, str]] = itemgetter( class ColorParseError(Exception): """A color failed to parse""" + def __init__(self, message: str, suggested_color: str | None = None): + """ + Creates a new ColorParseError + + Args: + message (str): the error message + suggested_color (str | None): a close color we can suggest. Defaults to None. + """ + super().__init__(message) + self.suggested_color = suggested_color + @rich.repr.auto class Color(NamedTuple): @@ -271,7 +282,14 @@ class Color(NamedTuple): return cls(*color_from_name) color_match = RE_COLOR.match(color_text) if color_match is None: - raise ColorParseError(f"failed to parse {color_text!r} as a color") + error_message = f"failed to parse {color_text!r} as a color" + suggested_color = None + if not color_text.startswith("#") and not color_text.startswith("rgb"): + # Seems like we tried to use a color name: let's try to find one that is close enough: + suggested_color = get_suggestion(color_text, COLOR_NAME_TO_RGB.keys()) + if suggested_color: + error_message += f"; did you mean '{suggested_color}'?" + raise ColorParseError(error_message, suggested_color) ( rgb_hex_triple, rgb_hex_quad, diff --git a/src/textual/css/_help_renderables.py b/src/textual/css/_help_renderables.py index 3a9d2e5ab..184256687 100644 --- a/src/textual/css/_help_renderables.py +++ b/src/textual/css/_help_renderables.py @@ -70,13 +70,13 @@ class HelpText: Attributes: summary (str): A succinct summary of the issue. - bullets (Iterable[Bullet]): Bullet points which provide additional - context around the issue. These are rendered below the summary. + bullets (Iterable[Bullet] | None): Bullet points which provide additional + context around the issue. These are rendered below the summary. Defaults to None. """ - def __init__(self, summary: str, *, bullets: Iterable[Bullet]) -> None: + def __init__(self, summary: str, *, bullets: Iterable[Bullet] = None) -> None: self.summary = summary - self.bullets = bullets + self.bullets = bullets or [] def __rich_console__( self, console: Console, options: ConsoleOptions diff --git a/src/textual/css/_help_text.py b/src/textual/css/_help_text.py index dc6cf9ff5..b3ed0f3ec 100644 --- a/src/textual/css/_help_text.py +++ b/src/textual/css/_help_text.py @@ -4,6 +4,7 @@ import sys from dataclasses import dataclass from typing import Iterable +from textual.color import ColorParseError from textual.css._help_renderables import Example, Bullet, HelpText from textual.css.constants import ( VALID_BORDER, @@ -144,13 +145,13 @@ def property_invalid_value_help_text( HelpText: Renderable for displaying the help text for this property """ property_name = _contextualize_property_name(property_name, context) - bullets = [] + summary = f"Invalid CSS property [i]{property_name}[/]" if suggested_property_name: suggested_property_name = _contextualize_property_name( suggested_property_name, context ) - bullets.append(Bullet(f'Did you mean "{suggested_property_name}"?')) - return HelpText(f"Invalid CSS property [i]{property_name}[/]", bullets=bullets) + summary += f'. Did you mean "{suggested_property_name}"?' + return HelpText(summary) def spacing_wrong_number_of_values_help_text( @@ -303,6 +304,8 @@ def string_enum_help_text( def color_property_help_text( property_name: str, context: StylingContext, + *, + error: Exception = None, ) -> HelpText: """Help text to show when the user supplies an invalid value for a color property. For example, an unparseable color string. @@ -310,13 +313,20 @@ def color_property_help_text( Args: property_name (str): The name of the property context (StylingContext | None): The context the property is being used in. + error (ColorParseError | None): The error that caused this help text to be displayed. Defaults to None. Returns: HelpText: Renderable for displaying the help text for this property """ property_name = _contextualize_property_name(property_name, context) + summary = f"Invalid value for the [i]{property_name}[/] property" + suggested_color = ( + error.suggested_color if error and isinstance(error, ColorParseError) else None + ) + if suggested_color: + summary += f'. Did you mean "{suggested_color}"?' return HelpText( - summary=f"Invalid value for the [i]{property_name}[/] property", + summary=summary, bullets=[ Bullet( f"The [i]{property_name}[/] property can only be set to a valid color" diff --git a/src/textual/css/_style_properties.py b/src/textual/css/_style_properties.py index 19f402bb4..b7a161c71 100644 --- a/src/textual/css/_style_properties.py +++ b/src/textual/css/_style_properties.py @@ -782,10 +782,12 @@ class ColorProperty: elif isinstance(color, str): try: parsed_color = Color.parse(color) - except ColorParseError: + except ColorParseError as error: raise StyleValueError( f"Invalid color value '{color}'", - help_text=color_property_help_text(self.name, context="inline"), + help_text=color_property_help_text( + self.name, context="inline", error=error + ), ) if obj.set_rule(self.name, parsed_color): obj.refresh() diff --git a/src/textual/css/_styles_builder.py b/src/textual/css/_styles_builder.py index 41ceea9ac..79935716e 100644 --- a/src/textual/css/_styles_builder.py +++ b/src/textual/css/_styles_builder.py @@ -572,9 +572,11 @@ class StylesBuilder: elif token.name in ("color", "token"): try: color = Color.parse(token.value) - except Exception: + except Exception as error: self.error( - name, token, color_property_help_text(name, context="css") + name, + token, + color_property_help_text(name, context="css", error=error), ) else: self.error(name, token, color_property_help_text(name, context="css")) diff --git a/src/textual/css/styles.py b/src/textual/css/styles.py index 2a7471cd0..681ed797d 100644 --- a/src/textual/css/styles.py +++ b/src/textual/css/styles.py @@ -241,7 +241,13 @@ class StylesBase(ABC): Returns: Spacing: Space around widget. """ - spacing = Spacing() + self.padding + self.border.spacing + spacing = self.padding + self.border.spacing + return spacing + + @property + def content_gutter(self) -> Spacing: + """The spacing that surrounds the content area of the widget.""" + spacing = self.padding + self.border.spacing + self.margin return spacing @abstractmethod diff --git a/src/textual/css/stylesheet.py b/src/textual/css/stylesheet.py index d4831998b..ebeb96970 100644 --- a/src/textual/css/stylesheet.py +++ b/src/textual/css/stylesheet.py @@ -7,7 +7,7 @@ from pathlib import Path, PurePath from typing import cast, Iterable import rich.repr -from rich.console import RenderableType, Console, ConsoleOptions +from rich.console import RenderableType, RenderResult, Console, ConsoleOptions from rich.highlighter import ReprHighlighter from rich.markup import render from rich.padding import Padding @@ -68,10 +68,10 @@ class StylesheetErrors: def __rich_console__( self, console: Console, options: ConsoleOptions - ) -> RenderableType: + ) -> RenderResult: error_count = 0 for rule in self.rules: - for is_last, (token, message) in loop_last(rule.errors): + for token, message in rule.errors: error_count += 1 if token.path: @@ -297,7 +297,6 @@ class Stylesheet: for name, specificity_rules in rule_attributes.items() }, ) - self.replace_rules(node, node_rules, animate=animate) @classmethod @@ -363,8 +362,9 @@ class Stylesheet: setattr(base_styles, key, new_value) else: # Not animated, so we apply the rules directly + get_rule = rules.get for key in modified_rule_keys: - setattr(base_styles, key, rules.get(key)) + setattr(base_styles, key, get_rule(key)) def update(self, root: DOMNode, animate: bool = False) -> None: """Update a node and its children.""" diff --git a/src/textual/events.py b/src/textual/events.py index 20bf7033a..aa2437078 100644 --- a/src/textual/events.py +++ b/src/textual/events.py @@ -395,9 +395,9 @@ class Blur(Event, bubble=False): pass -class DescendantFocus(Event, bubble=True): +class DescendantFocus(Event, verbosity=2, bubble=True): pass -class DescendantBlur(Event, bubble=True): +class DescendantBlur(Event, verbosity=2, bubble=True): pass diff --git a/src/textual/geometry.py b/src/textual/geometry.py index 324707178..3973e3eac 100644 --- a/src/textual/geometry.py +++ b/src/textual/geometry.py @@ -6,7 +6,7 @@ Functions and classes to manage terminal geometry (anything involving coordinate from __future__ import annotations -from typing import Any, cast, NamedTuple, Tuple, Union, TypeVar +from typing import Any, cast, Iterable, NamedTuple, Tuple, Union, TypeVar SpacingDimensions = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int, int]] @@ -181,6 +181,24 @@ class Region(NamedTuple): width: int = 0 height: int = 0 + @classmethod + def from_union(cls, regions: list[Region]) -> Region: + """Create a Region from the union of other regions. + + Args: + regions (Iterable[Region]): One or more regions. + + Returns: + Region: A Region that encloses all other regions. + """ + if not regions: + raise ValueError("At least one region expected") + min_x = min([region.x for region in regions]) + max_x = max([x + width for x, _y, width, _height in regions]) + min_y = min([region.y for region in regions]) + max_y = max([y + height for _x, y, _width, height in regions]) + return cls(min_x, min_y, max_x - min_x, max_y - min_y) + @classmethod def from_corners(cls, x1: int, y1: int, x2: int, y2: int) -> Region: """Construct a Region form the top left and bottom right corners. @@ -257,6 +275,24 @@ class Region(NamedTuple): """Get the start point of the region.""" return Offset(self.x, self.y) + @property + def bottom_left(self) -> Offset: + """Bottom left offset of the region.""" + x, y, _width, height = self + return Offset(x, y + height) + + @property + def top_right(self) -> Offset: + """Top right offset of the region.""" + x, y, width, _height = self + return Offset(x + width, y) + + @property + def bottom_right(self) -> Offset: + """Bottom right of the region.""" + x, y, width, height = self + return Offset(x + width, y + height) + @property def size(self) -> Size: """Get the size of the region.""" @@ -274,17 +310,17 @@ class Region(NamedTuple): @property def x_range(self) -> range: - """A range object for X coordinates""" + """A range object for X coordinates.""" return range(self.x, self.x + self.width) @property def y_range(self) -> range: - """A range object for Y coordinates""" + """A range object for Y coordinates.""" return range(self.y, self.y + self.height) @property def reset_origin(self) -> Region: - """An region of the same size at the origin.""" + """An region of the same size at (0, 0).""" _, _, width, height = self return Region(0, 0, width, height) diff --git a/src/textual/message_pump.py b/src/textual/message_pump.py index 66ed98721..9fefc5ec3 100644 --- a/src/textual/message_pump.py +++ b/src/textual/message_pump.py @@ -149,8 +149,11 @@ class MessagePump: callback: TimerCallback = None, *, name: str | None = None, + pause: bool = False, ) -> Timer: - timer = Timer(self, delay, self, name=name, callback=callback, repeat=0) + timer = Timer( + self, delay, self, name=name, callback=callback, repeat=0, pause=pause + ) self._child_tasks.add(timer.start()) return timer @@ -161,9 +164,16 @@ class MessagePump: *, name: str | None = None, repeat: int = 0, + pause: bool = False, ): timer = Timer( - self, interval, self, name=name, callback=callback, repeat=repeat or None + self, + interval, + self, + name=name, + callback=callback, + repeat=repeat or None, + pause=pause, ) self._child_tasks.add(timer.start()) return timer diff --git a/src/textual/screen.py b/src/textual/screen.py index a575018da..51de850ff 100644 --- a/src/textual/screen.py +++ b/src/textual/screen.py @@ -8,7 +8,7 @@ from rich.style import Style from . import events, messages, errors from .geometry import Offset, Region -from ._compositor import Compositor +from ._compositor import Compositor, MapGeometry from .reactive import Reactive from .widget import Widget @@ -18,14 +18,14 @@ class Screen(Widget): """A widget for the root of the app.""" CSS = """ - + Screen { layout: vertical; overflow-y: auto; background: $surface; color: $text-surface; } - + """ dark = Reactive(False) @@ -33,13 +33,13 @@ class Screen(Widget): def __init__(self, name: str | None = None, id: str | None = None) -> None: super().__init__(name=name, id=id) self._compositor = Compositor() - self._dirty_widgets: list[Widget] = [] + self._dirty_widgets: set[Widget] = set() def watch_dark(self, dark: bool) -> None: pass - def render(self) -> RenderableType: - return self.app.render() + def render(self, style: Style) -> RenderableType: + return self.app.render(style) def get_offset(self, widget: Widget) -> Offset: """Get the absolute offset of a given Widget. @@ -76,7 +76,7 @@ class Screen(Widget): """ return self._compositor.get_style_at(x, y) - def get_widget_region(self, widget: Widget) -> Region: + def find_widget(self, widget: Widget) -> MapGeometry: """Get the screen region of a Widget. Args: @@ -85,24 +85,32 @@ class Screen(Widget): Returns: Region: Region relative to screen. """ - return self._compositor.get_widget_region(widget) + return self._compositor.find_widget(widget) def on_idle(self, event: events.Idle) -> None: # Check for any widgets marked as 'dirty' (needs a repaint) if self._dirty_widgets: - for widget in self._dirty_widgets: - # Repaint widgets - # TODO: Combine these in to a single update. - display_update = self._compositor.update_widget(self.console, widget) - if display_update is not None: - self.app.display(display_update) - # Reset dirty list + self._update_timer.resume() + + def _on_update(self) -> None: + """Called by the _update_timer.""" + + # Render widgets together + if self._dirty_widgets: + self.log(dirty=self._dirty_widgets) + display_update = self._compositor.update_widgets(self._dirty_widgets) + if display_update is not None: + self.app.display(display_update) self._dirty_widgets.clear() + self._update_timer.pause() def refresh_layout(self) -> None: """Refresh the layout (can change size and positions of widgets).""" if not self.size: return + # This paint the entire screen, so replaces the batched dirty widgets + self._update_timer.pause() + self._dirty_widgets.clear() try: hidden, shown, resized = self._compositor.reflow(self, self.size) @@ -133,30 +141,31 @@ class Screen(Widget): self.app.on_exception(error) return self.app.refresh() - self._dirty_widgets.clear() async def handle_update(self, message: messages.Update) -> None: message.stop() widget = message.widget assert isinstance(widget, Widget) - self._dirty_widgets.append(widget) + self._dirty_widgets.add(widget) self.check_idle() async def handle_layout(self, message: messages.Layout) -> None: message.stop() self.refresh_layout() + def on_mount(self, event: events.Mount) -> None: + self._update_timer = self.set_interval(1 / 20, self._on_update, pause=True) + async def on_resize(self, event: events.Resize) -> None: self.size_updated(event.size, event.virtual_size, event.container_size) self.refresh_layout() event.stop() async def _on_mouse_move(self, event: events.MouseMove) -> None: - try: if self.app.mouse_captured: widget = self.app.mouse_captured - region = self.get_widget_region(widget) + region = self.find_widget(widget).region else: widget, region = self.get_widget_at(event.x, event.y) except errors.NoWidget: @@ -195,7 +204,7 @@ class Screen(Widget): try: if self.app.mouse_captured: widget = self.app.mouse_captured - region = self.get_widget_region(widget) + region = self.find_widget(widget).region else: widget, region = self.get_widget_at(event.x, event.y) except errors.NoWidget: diff --git a/src/textual/scrollbar.py b/src/textual/scrollbar.py index 5633f0d0e..bcf3989aa 100644 --- a/src/textual/scrollbar.py +++ b/src/textual/scrollbar.py @@ -205,9 +205,9 @@ class ScrollBar(Widget): yield "window_size", self.window_size yield "position", self.position - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: styles = self.parent.styles - style = Style( + scrollbar_style = Style( bgcolor=( styles.scrollbar_background_hover.rich_color if self.mouse_over @@ -224,12 +224,9 @@ class ScrollBar(Widget): window_size=self.window_size, position=self.position, vertical=self.vertical, - style=style, + style=scrollbar_style, ) - async def on_event(self, event) -> None: - await super().on_event(event) - async def on_enter(self, event: events.Enter) -> None: self.mouse_over = True @@ -284,7 +281,6 @@ class ScrollBar(Widget): if __name__ == "__main__": from rich.console import Console - from rich.segment import Segments console = Console() bar = ScrollBarRender() diff --git a/src/textual/widget.py b/src/textual/widget.py index e46d9bb76..09b315eb0 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -26,6 +26,7 @@ from .box_model import BoxModel, get_box_model from .color import Color from ._context import active_app from ._types import Lines +from .css.styles import Styles from .dom import DOMNode from .geometry import clamp, Offset, Region, Size from .layouts.vertical import VerticalLayout @@ -83,6 +84,7 @@ class Widget(DOMNode): self._virtual_size = Size(0, 0) self._container_size = Size(0, 0) self._layout_required = False + self._repaint_required = False self._default_layout = VerticalLayout() self._animate: BoundAnimator | None = None self._reactive_watches: dict[str, Callable] = {} @@ -159,7 +161,7 @@ class Widget(DOMNode): int: The optimal width of the content. """ console = self.app.console - renderable = self.render() + renderable = self.render(self.styles.rich_style) measurement = Measurement.get(console, console.options, renderable) return measurement.maximum @@ -176,7 +178,7 @@ class Widget(DOMNode): Returns: int: The height of the content. """ - renderable = self.render() + renderable = self.render(self.styles.rich_style) options = self.console.options.update_width(width) segments = self.console.render(renderable, options) # Cheaper than counting the lines returned from render_lines! @@ -272,6 +274,8 @@ class Widget(DOMNode): self.show_horizontal_scrollbar = show_horizontal self.show_vertical_scrollbar = show_vertical + self.horizontal_scrollbar.display = show_horizontal + self.vertical_scrollbar.display = show_vertical @property def scrollbars_enabled(self) -> tuple[bool, bool]: @@ -298,17 +302,20 @@ class Widget(DOMNode): y: float | None = None, *, animate: bool = True, + speed: float | None = None, + duration: float | None = None, ) -> bool: """Scroll to a given (absolute) coordinate, optionally animating. Args: - scroll_x (int | None, optional): X coordinate (column) to scroll to, or ``None`` for no change. Defaults to None. - scroll_y (int | None, optional): Y coordinate (row) to scroll to, or ``None`` for no change. Defaults to None. + x (int | None, optional): X coordinate (column) to scroll to, or ``None`` for no change. Defaults to None. + y (int | None, optional): Y coordinate (row) to scroll to, or ``None`` for no change. Defaults to None. animate (bool, optional): Animate to new scroll position. Defaults to False. - """ - scrolled_x = False - scrolled_y = False + Returns: + bool: True if the scroll position changed, otherwise False. + """ + scrolled_x = scrolled_y = False if animate: # TODO: configure animation speed @@ -316,67 +323,151 @@ class Widget(DOMNode): self.scroll_target_x = x if x != self.scroll_x: self.animate( - "scroll_x", self.scroll_target_x, speed=80, easing="out_cubic" + "scroll_x", + self.scroll_target_x, + speed=speed, + duration=duration, + easing="out_cubic", ) scrolled_x = True if y is not None: self.scroll_target_y = y if y != self.scroll_y: self.animate( - "scroll_y", self.scroll_target_y, speed=80, easing="out_cubic" + "scroll_y", + self.scroll_target_y, + speed=speed, + duration=duration, + easing="out_cubic", ) scrolled_y = True else: if x is not None: + scroll_x = self.scroll_x self.scroll_target_x = self.scroll_x = x - if x != self.scroll_x: - scrolled_x = True + scrolled_x = scroll_x != self.scroll_x if y is not None: + scroll_y = self.scroll_y self.scroll_target_y = self.scroll_y = y - if y != self.scroll_y: - scrolled_y = True - self.refresh(repaint=False, layout=True) + scrolled_y = scroll_y != self.scroll_y + if scrolled_x or scrolled_y: + self.refresh(repaint=False, layout=True) + return scrolled_x or scrolled_y - def scroll_home(self, animate: bool = True) -> bool: + def scroll_relative( + self, + x: float | None = None, + y: float | None = None, + *, + animate: bool = True, + speed: float | None = None, + duration: float | None = None, + ) -> bool: + """Scroll relative to current position. + + Args: + x (int | None, optional): X distance (columns) to scroll, or ``None`` for no change. Defaults to None. + y (int | None, optional): Y distance (rows) to scroll, or ``None`` for no change. Defaults to None. + animate (bool, optional): Animate to new scroll position. Defaults to False. + + Returns: + bool: True if the scroll position changed, otherwise False. + """ + return self.scroll_to( + None if x is None else (self.scroll_x + x), + None if y is None else (self.scroll_y + y), + animate=animate, + speed=speed, + duration=duration, + ) + + def scroll_home(self, *, animate: bool = True) -> bool: return self.scroll_to(0, 0, animate=animate) - def scroll_end(self, animate: bool = True) -> bool: + def scroll_end(self, *, animate: bool = True) -> bool: return self.scroll_to(0, self.max_scroll_y, animate=animate) - def scroll_left(self, animate: bool = True) -> bool: + def scroll_left(self, *, animate: bool = True) -> bool: return self.scroll_to(x=self.scroll_target_x - 1, animate=animate) - def scroll_right(self, animate: bool = True) -> bool: + def scroll_right(self, *, animate: bool = True) -> bool: return self.scroll_to(x=self.scroll_target_x + 1, animate=animate) - def scroll_up(self, animate: bool = True) -> bool: + def scroll_up(self, *, animate: bool = True) -> bool: return self.scroll_to(y=self.scroll_target_y + 1, animate=animate) - def scroll_down(self, animate: bool = True) -> bool: + def scroll_down(self, *, animate: bool = True) -> bool: return self.scroll_to(y=self.scroll_target_y - 1, animate=animate) - def scroll_page_up(self, animate: bool = True) -> bool: + def scroll_page_up(self, *, animate: bool = True) -> bool: return self.scroll_to( y=self.scroll_target_y - self.container_size.height, animate=animate ) - def scroll_page_down(self, animate: bool = True) -> bool: + def scroll_page_down(self, *, animate: bool = True) -> bool: return self.scroll_to( y=self.scroll_target_y + self.container_size.height, animate=animate ) - def scroll_page_left(self, animate: bool = True) -> bool: + def scroll_page_left(self, *, animate: bool = True) -> bool: return self.scroll_to( x=self.scroll_target_x - self.container_size.width, animate=animate ) - def scroll_page_right(self, animate: bool = True) -> bool: + def scroll_page_right(self, *, animate: bool = True) -> bool: return self.scroll_to( x=self.scroll_target_x + self.container_size.width, animate=animate ) + def scroll_to_widget(self, widget: Widget, *, animate: bool = True) -> bool: + """Scroll so that a child widget is in the visible area. + + Args: + widget (Widget): A Widget in the children. + animate (bool, optional): True to animate, or False to jump. Defaults to True. + + Returns: + bool: True if the scroll position changed, otherwise False. + """ + screen = self.screen + try: + widget_geometry = screen.find_widget(widget) + container_geometry = screen.find_widget(self) + except errors.NoWidget: + return False + + widget_region = widget.content_region + widget_geometry.region.origin + container_region = self.content_region + container_geometry.region.origin + + if widget_region in container_region: + # Widget is visible, nothing to do + return False + + # We can either scroll so the widget is at the top of the container, or so that + # it is at the bottom. We want to pick which has the shortest distance + top_delta = widget_region.origin - container_region.origin + + bottom_delta = widget_region.origin - ( + container_region.origin + + Offset(0, container_region.height - widget_region.height) + ) + + if widget_region.width > container_region.width: + delta_x = top_delta.x + else: + delta_x = min(top_delta.x, bottom_delta.x, key=abs) + + if widget_region.height > container_region.height: + delta_y = top_delta.y + else: + delta_y = min(top_delta.y, bottom_delta.y, key=abs) + + return self.scroll_relative( + delta_x or None, delta_y or None, animate=animate, duration=0.2 + ) + def __init_subclass__( cls, can_focus: bool = True, can_focus_children: bool = True ) -> None: @@ -465,8 +556,7 @@ class Widget(DOMNode): Returns: RenderableType: A new renderable. """ - - renderable = self.render() + renderable = self.render(self.styles.rich_style) styles = self.styles parent_styles = self.parent.styles @@ -479,11 +569,12 @@ class Widget(DOMNode): horizontal, vertical = content_align renderable = Align(renderable, horizontal, vertical=vertical) + renderable = Padding(renderable, styles.padding) + renderable_text_style = parent_text_style + text_style if renderable_text_style: - renderable = Styled(renderable, renderable_text_style) - - renderable = Padding(renderable, styles.padding, style=renderable_text_style) + style = Style.from_color(text_style.color, text_style.bgcolor) + renderable = Styled(renderable, style) if styles.border: renderable = Border( @@ -517,14 +608,28 @@ class Widget(DOMNode): def container_size(self) -> Size: return self._container_size + @property + def content_region(self) -> Region: + """A region relative to the Widget origin that contains the content.""" + x, y = self.styles.content_gutter.top_left + width, height = self._container_size + return Region(x, y, width, height) + + @property + def content_offset(self) -> Offset: + """An offset from the Widget origin where the content begins.""" + x, y = self.styles.content_gutter.top_left + return Offset(x, y) + @property def virtual_size(self) -> Size: return self._virtual_size @property def region(self) -> Region: + """The region occupied by this widget, relative to the Screen.""" try: - return self.screen._compositor.get_widget_region(self) + return self.screen.find_widget(self).region except errors.NoWidget: return Region() @@ -632,8 +737,10 @@ class Widget(DOMNode): if self._dirty_regions: self._render_lines() if self.is_container: - self.horizontal_scrollbar.refresh() - self.vertical_scrollbar.refresh() + if self.show_horizontal_scrollbar: + self.horizontal_scrollbar.refresh() + if self.show_vertical_scrollbar: + self.vertical_scrollbar.refresh() lines = self._render_cache.lines[start:end] return lines @@ -669,21 +776,19 @@ class Widget(DOMNode): self._layout_required = True if repaint: self.set_dirty() + self._repaint_required = True self.check_idle() - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: """Get renderable for widget. + Args: + style (Styles): The Styles object for this Widget. + Returns: RenderableType: Any renderable """ - - # Default displays a pretty repr in the center of the screen - - if self.is_container: - return "" - - return self.css_identifier_styled + return "" if self.is_container else self.css_identifier_styled async def action(self, action: str, *params) -> None: await self.app.action(action, self) @@ -705,8 +810,9 @@ class Widget(DOMNode): if self.check_layout(): self._reset_check_layout() self.screen.post_message_no_wait(messages.Layout(self)) - elif self._dirty_regions: + elif self._repaint_required: self.emit_no_wait(messages.Update(self, self)) + self._repaint_required = False def focus(self) -> None: """Give input focus to this widget.""" @@ -766,6 +872,8 @@ class Widget(DOMNode): def on_descendant_focus(self, event: events.DescendantFocus) -> None: self.descendant_has_focus = True + if self.is_container and isinstance(event.sender, Widget): + self.scroll_to_widget(event.sender, animate=True) def on_descendant_blur(self, event: events.DescendantBlur) -> None: self.descendant_has_focus = False diff --git a/src/textual/widgets/_button.py b/src/textual/widgets/_button.py index a03e6003b..5565d9695 100644 --- a/src/textual/widgets/_button.py +++ b/src/textual/widgets/_button.py @@ -3,7 +3,8 @@ from __future__ import annotations from typing import cast from rich.console import RenderableType -from rich.text import Text +from rich.style import Style +from rich.text import Text, TextType from .. import events from ..message import Message @@ -24,8 +25,7 @@ class Button(Widget, can_focus=True): color: $text-primary; content-align: center middle; border: tall $primary-lighten-3; - - margin: 1; + margin: 1 0; text-style: bold; } @@ -48,7 +48,7 @@ class Button(Widget, can_focus=True): def __init__( self, - label: RenderableType | None = None, + label: TextType | None = None, disabled: bool = False, *, name: str | None = None, @@ -57,7 +57,11 @@ class Button(Widget, can_focus=True): ): super().__init__(name=name, id=id, classes=classes) - self.label = self.css_identifier_styled if label is None else label + if label is None: + label = self.css_identifier_styled + + self.label: Text = label + self.disabled = disabled if disabled: self.add_class("-disabled") @@ -70,8 +74,10 @@ class Button(Widget, can_focus=True): return Text.from_markup(label) return label - def render(self) -> RenderableType: - return self.label + def render(self, style: Style) -> RenderableType: + label = self.label.copy() + label.stylize(style) + return label async def on_click(self, event: events.Click) -> None: event.stop() diff --git a/src/textual/widgets/_footer.py b/src/textual/widgets/_footer.py index fba2efe1f..e111f0eb3 100644 --- a/src/textual/widgets/_footer.py +++ b/src/textual/widgets/_footer.py @@ -59,7 +59,7 @@ class Footer(Widget): text.append_text(key_text) return text - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: if self._key_text is None: self._key_text = self.make_key_text() return self._key_text diff --git a/src/textual/widgets/_header.py b/src/textual/widgets/_header.py index a475e0444..965184bdf 100644 --- a/src/textual/widgets/_header.py +++ b/src/textual/widgets/_header.py @@ -6,7 +6,7 @@ from logging import getLogger from rich.console import RenderableType from rich.panel import Panel from rich.repr import Result -from rich.style import StyleType +from rich.style import StyleType, Style from rich.table import Table from .. import events @@ -49,7 +49,7 @@ class Header(Widget): def get_clock(self) -> str: return datetime.now().time().strftime("%X") - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: header_table = Table.grid(padding=(0, 1), expand=True) header_table.style = self.style header_table.add_column(justify="left", ratio=0, width=8) diff --git a/src/textual/widgets/_placeholder.py b/src/textual/widgets/_placeholder.py index c43b3100d..e5d49e755 100644 --- a/src/textual/widgets/_placeholder.py +++ b/src/textual/widgets/_placeholder.py @@ -6,9 +6,8 @@ from rich.console import RenderableType from rich.panel import Panel from rich.pretty import Pretty import rich.repr +from rich.style import Style - -from .. import log from .. import events from ..reactive import Reactive from ..widget import Widget @@ -19,22 +18,23 @@ class Placeholder(Widget, can_focus=True): has_focus: Reactive[bool] = Reactive(False) mouse_over: Reactive[bool] = Reactive(False) - style: Reactive[str] = Reactive("") def __rich_repr__(self) -> rich.repr.Result: yield from super().__rich_repr__() yield "has_focus", self.has_focus, False yield "mouse_over", self.mouse_over, False - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: + # Apply colours only inside render_styled + # Pass the full RICH style object into `render` - not the `Styles` return Panel( Align.center( - Pretty(self, no_wrap=True, overflow="ellipsis"), vertical="middle" + Pretty(self, no_wrap=True, overflow="ellipsis"), + vertical="middle", ), title=self.__class__.__name__, border_style="green" if self.mouse_over else "blue", box=box.HEAVY if self.has_focus else box.ROUNDED, - style=self.style, ) async def on_focus(self, event: events.Focus) -> None: diff --git a/src/textual/widgets/_static.py b/src/textual/widgets/_static.py index 733986e2f..72863f602 100644 --- a/src/textual/widgets/_static.py +++ b/src/textual/widgets/_static.py @@ -1,9 +1,8 @@ from __future__ import annotations from rich.console import RenderableType -from rich.padding import Padding, PaddingDimensions -from rich.style import StyleType -from rich.styled import Styled +from rich.style import Style + from ..widget import Widget @@ -15,20 +14,13 @@ class Static(Widget): name: str | None = None, id: str | None = None, classes: str | None = None, - style: StyleType = "", - padding: PaddingDimensions = 0, ) -> None: super().__init__(name=name, id=id, classes=classes) self.renderable = renderable - self.style = style - self.padding = padding - def render(self) -> RenderableType: - renderable = self.renderable - if self.padding: - renderable = Padding(renderable, self.padding) - return Styled(renderable, self.style) + def render(self, style: Style) -> RenderableType: + return self.renderable - async def update(self, renderable: RenderableType) -> None: + def update(self, renderable: RenderableType) -> None: self.renderable = renderable - self.refresh() + self.refresh(layout=True) diff --git a/src/textual/widgets/_tree_control.py b/src/textual/widgets/_tree_control.py index a76030077..43dd9fc12 100644 --- a/src/textual/widgets/_tree_control.py +++ b/src/textual/widgets/_tree_control.py @@ -5,12 +5,11 @@ from typing import Generic, Iterator, NewType, TypeVar import rich.repr from rich.console import RenderableType +from rich.style import Style from rich.text import Text, TextType from rich.tree import Tree from rich.padding import PaddingDimensions -from .. import log -from .. import events from ..reactive import Reactive from .._types import MessageTarget from ..widget import Widget @@ -249,7 +248,7 @@ class TreeControl(Generic[NodeDataType], Widget): push(iter(node.children)) return None - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: return self._tree def render_node(self, node: TreeNode[NodeDataType]) -> RenderableType: diff --git a/src/textual/widgets/tabs.py b/src/textual/widgets/tabs.py index 4346816f6..8d351c510 100644 --- a/src/textual/widgets/tabs.py +++ b/src/textual/widgets/tabs.py @@ -330,7 +330,7 @@ class Tabs(Widget): """ return next((i for i, tab in enumerate(self.tabs) if tab.name == tab_name), 0) - def render(self) -> RenderableType: + def render(self, style: Style) -> RenderableType: return TabsRenderable( self.tabs, tab_padding=self.tab_padding, diff --git a/tests/css/test_stylesheet.py b/tests/css/test_stylesheet.py index e9042675b..9ef433a29 100644 --- a/tests/css/test_stylesheet.py +++ b/tests/css/test_stylesheet.py @@ -90,14 +90,48 @@ def test_did_you_mean_for_css_property_names( _, help_text = err.value.errors.rules[0].errors[0] # type: Any, HelpText displayed_css_property_name = css_property_name.replace("_", "-") - assert ( - help_text.summary == f"Invalid CSS property [i]{displayed_css_property_name}[/]" + expected_summary = f"Invalid CSS property [i]{displayed_css_property_name}[/]" + if expected_property_name_suggestion: + expected_summary += f'. Did you mean "{expected_property_name_suggestion}"?' + assert help_text.summary == expected_summary + + +@pytest.mark.parametrize( + "css_property_name,css_property_value,expected_color_suggestion", + [ + ["color", "blu", "blue"], + ["background", "chartruse", "chartreuse"], + ["tint", "ansi_whi", "ansi_white"], + ["scrollbar-color", "transprnt", "transparent"], + ["color", "xkcd", None], + ], +) +def test_did_you_mean_for_color_names( + css_property_name: str, css_property_value: str, expected_color_suggestion +): + stylesheet = Stylesheet() + css = """ + * { + border: blue; + ${PROPERTY}: ${VALUE}; + } + """.replace( + "${PROPERTY}", css_property_name + ).replace( + "${VALUE}", css_property_value ) - expected_bullets_length = 1 if expected_property_name_suggestion else 0 - assert len(help_text.bullets) == expected_bullets_length - if expected_property_name_suggestion is not None: - expected_suggestion_message = ( - f'Did you mean "{expected_property_name_suggestion}"?' - ) - assert help_text.bullets[0].markup == expected_suggestion_message + stylesheet.add_source(css) + with pytest.raises(StylesheetParseError) as err: + stylesheet.parse() + + _, help_text = err.value.errors.rules[0].errors[0] # type: Any, HelpText + displayed_css_property_name = css_property_name.replace("_", "-") + expected_error_summary = ( + f"Invalid value for the [i]{displayed_css_property_name}[/] property" + ) + + if expected_color_suggestion is not None: + expected_error_summary += f'. Did you mean "{expected_color_suggestion}"?' + + assert help_text.summary == expected_error_summary diff --git a/tests/test_region_group.py b/tests/test_compositor_regions_to_spans.py similarity index 50% rename from tests/test_region_group.py rename to tests/test_compositor_regions_to_spans.py index 930489a89..31de88ad5 100644 --- a/tests/test_region_group.py +++ b/tests/test_compositor_regions_to_spans.py @@ -1,44 +1,54 @@ -from textual._region_group import regions_to_ranges, InlineRange +from textual._compositor import Compositor from textual.geometry import Region def test_regions_to_ranges_no_regions(): - assert list(regions_to_ranges([])) == [] + assert list(Compositor._regions_to_spans([])) == [] def test_regions_to_ranges_single_region(): regions = [Region(0, 0, 3, 2)] - assert list(regions_to_ranges(regions)) == [InlineRange(0, 0, 2), InlineRange(1, 0, 2)] + assert list(Compositor._regions_to_spans(regions)) == [ + (0, 0, 3), + (1, 0, 3), + ] def test_regions_to_ranges_partially_overlapping_regions(): regions = [Region(0, 0, 2, 2), Region(1, 1, 2, 2)] - assert list(regions_to_ranges(regions)) == [ - InlineRange(0, 0, 1), InlineRange(1, 0, 2), InlineRange(2, 1, 2), + assert list(Compositor._regions_to_spans(regions)) == [ + (0, 0, 2), + (1, 0, 3), + (2, 1, 3), ] def test_regions_to_ranges_fully_overlapping_regions(): regions = [Region(1, 1, 3, 3), Region(2, 2, 1, 1), Region(0, 2, 3, 1)] - assert list(regions_to_ranges(regions)) == [ - InlineRange(1, 1, 3), InlineRange(2, 0, 3), InlineRange(3, 1, 3) + assert list(Compositor._regions_to_spans(regions)) == [ + (1, 1, 4), + (2, 0, 4), + (3, 1, 4), ] def test_regions_to_ranges_disjoint_regions_different_lines(): regions = [Region(0, 0, 2, 1), Region(2, 2, 2, 1)] - assert list(regions_to_ranges(regions)) == [InlineRange(0, 0, 1), InlineRange(2, 2, 3)] + assert list(Compositor._regions_to_spans(regions)) == [(0, 0, 2), (2, 2, 4)] def test_regions_to_ranges_disjoint_regions_same_line(): regions = [Region(0, 0, 1, 2), Region(2, 0, 1, 1)] - assert list(regions_to_ranges(regions)) == [ - InlineRange(0, 0, 0), InlineRange(0, 2, 2), InlineRange(1, 0, 0) + assert list(Compositor._regions_to_spans(regions)) == [ + (0, 0, 1), + (0, 2, 3), + (1, 0, 1), ] def test_regions_to_ranges_directly_adjacent_ranges_merged(): regions = [Region(0, 0, 1, 2), Region(1, 0, 1, 2)] - assert list(regions_to_ranges(regions)) == [ - InlineRange(0, 0, 1), InlineRange(1, 0, 1) + assert list(Compositor._regions_to_spans(regions)) == [ + (0, 0, 2), + (1, 0, 2), ] diff --git a/tests/test_geometry.py b/tests/test_geometry.py index 72d5de462..b408a8ad4 100644 --- a/tests/test_geometry.py +++ b/tests/test_geometry.py @@ -114,6 +114,17 @@ def test_region_null(): assert not Region() +def test_region_from_union(): + with pytest.raises(ValueError): + Region.from_union([]) + regions = [ + Region(10, 20, 30, 40), + Region(15, 25, 5, 5), + Region(30, 25, 20, 10), + ] + assert Region.from_union(regions) == Region(10, 20, 40, 40) + + def test_region_from_origin(): assert Region.from_origin(Offset(3, 4), (5, 6)) == Region(3, 4, 5, 6) @@ -132,6 +143,18 @@ def test_region_origin(): assert Region(1, 2, 3, 4).origin == Offset(1, 2) +def test_region_bottom_left(): + assert Region(1, 2, 3, 4).bottom_left == Offset(1, 6) + + +def test_region_top_right(): + assert Region(1, 2, 3, 4).top_right == Offset(4, 2) + + +def test_region_bottom_right(): + assert Region(1, 2, 3, 4).bottom_right == Offset(4, 6) + + def test_region_add(): assert Region(1, 2, 3, 4) + (10, 20) == Region(11, 22, 3, 4) with pytest.raises(TypeError): diff --git a/tests/test_integration_layout.py b/tests/test_integration_layout.py index 141bb0f38..c41b22cb6 100644 --- a/tests/test_integration_layout.py +++ b/tests/test_integration_layout.py @@ -150,12 +150,10 @@ async def test_composition_of_vertical_container_with_children( expected_screen_size = Size(*screen_size) async with app.in_running_state(): - app.log_tree() - # root widget checks: root_widget = cast(Widget, app.get_child("root")) assert root_widget.size == expected_screen_size - root_widget_region = app.screen.get_widget_region(root_widget) + root_widget_region = app.screen.find_widget(root_widget).region assert root_widget_region == ( 0, 0, diff --git a/tests/test_integration_scrolling.py b/tests/test_integration_scrolling.py new file mode 100644 index 000000000..3ee0cab31 --- /dev/null +++ b/tests/test_integration_scrolling.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import sys +from typing import Sequence, cast + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal # pragma: no cover + + +import pytest + +from sandbox.vertical_container import VerticalContainer +from tests.utilities.test_app import AppTest +from textual.app import ComposeResult +from textual.geometry import Size +from textual.widget import Widget +from textual.widgets import Placeholder + +SCREEN_SIZE = Size(100, 30) + + +@pytest.mark.skip("flaky test") +@pytest.mark.asyncio +@pytest.mark.integration_test # this is a slow test, we may want to skip them in some contexts +@pytest.mark.parametrize( + ( + "screen_size", + "placeholders_count", + "scroll_to_placeholder_id", + "scroll_to_animate", + "waiting_duration", + "last_screen_expected_placeholder_ids", + "last_screen_expected_out_of_viewport_placeholder_ids", + ), + ( + [SCREEN_SIZE, 10, None, None, 0.01, (0, 1, 2, 3, 4), "others"], + [SCREEN_SIZE, 10, "placeholder_3", False, 0.01, (0, 1, 2, 3, 4), "others"], + [SCREEN_SIZE, 10, "placeholder_5", False, 0.01, (1, 2, 3, 4, 5), "others"], + [SCREEN_SIZE, 10, "placeholder_7", False, 0.01, (3, 4, 5, 6, 7), "others"], + [SCREEN_SIZE, 10, "placeholder_9", False, 0.01, (5, 6, 7, 8, 9), "others"], + # N.B. Scroll duration is hard-coded to 0.2 in the `scroll_to_widget` method atm + # Waiting for this duration should allow us to see the scroll finished: + [SCREEN_SIZE, 10, "placeholder_9", True, 0.21, (5, 6, 7, 8, 9), "others"], + # After having waited for approximately half of the scrolling duration, we should + # see the middle Placeholders as we're scrolling towards the last of them. + # The state of the screen at this "halfway there" timing looks to not be deterministic though, + # depending on the environment - so let's only assert stuff for the middle placeholders + # and the first and last ones, but without being too specific about the others: + [SCREEN_SIZE, 10, "placeholder_9", True, 0.1, (5, 6, 7), (1, 2, 9)], + ), +) +async def test_scroll_to_widget( + screen_size: Size, + placeholders_count: int, + scroll_to_animate: bool | None, + scroll_to_placeholder_id: str | None, + waiting_duration: float | None, + last_screen_expected_placeholder_ids: Sequence[int], + last_screen_expected_out_of_viewport_placeholder_ids: Sequence[int] + | Literal["others"], +): + class MyTestApp(AppTest): + CSS = """ + Placeholder { + height: 5; /* minimal height to see the name of a Placeholder */ + } + """ + + def compose(self) -> ComposeResult: + placeholders = [ + Placeholder(id=f"placeholder_{i}", name=f"Placeholder #{i}") + for i in range(placeholders_count) + ] + + yield VerticalContainer(*placeholders, id="root") + + app = MyTestApp(size=screen_size, test_name="scroll_to_widget") + + async with app.in_running_state(waiting_duration_post_yield=waiting_duration or 0): + if scroll_to_placeholder_id: + target_widget_container = cast(Widget, app.query("#root").first()) + target_widget = cast( + Widget, app.query(f"#{scroll_to_placeholder_id}").first() + ) + target_widget_container.scroll_to_widget( + target_widget, animate=scroll_to_animate + ) + + last_display_capture = app.last_display_capture + + placeholders_visibility_by_id = { + id_: f"placeholder_{id_}" in last_display_capture + for id_ in range(placeholders_count) + } + + # Let's start by checking placeholders that should be visible: + for placeholder_id in last_screen_expected_placeholder_ids: + assert ( + placeholders_visibility_by_id[placeholder_id] is True + ), f"Placeholder '{placeholder_id}' should be visible but isn't" + + # Ok, now for placeholders that should *not* be visible: + if last_screen_expected_out_of_viewport_placeholder_ids == "others": + # We're simply going to check that all the placeholders that are not in + # `last_screen_expected_placeholder_ids` are not on the screen: + last_screen_expected_out_of_viewport_placeholder_ids = sorted( + tuple( + set(range(placeholders_count)) + - set(last_screen_expected_placeholder_ids) + ) + ) + for placeholder_id in last_screen_expected_out_of_viewport_placeholder_ids: + assert ( + placeholders_visibility_by_id[placeholder_id] is False + ), f"Placeholder '{placeholder_id}' should not be visible but is" diff --git a/tests/test_widget.py b/tests/test_widget.py index 341b7965d..b1c8e9f96 100644 --- a/tests/test_widget.py +++ b/tests/test_widget.py @@ -1,7 +1,5 @@ -from contextlib import nullcontext as does_not_raise -from decimal import Decimal - import pytest +from rich.style import Style from textual.app import App from textual.css.errors import StyleValueError @@ -41,7 +39,7 @@ def test_widget_content_width(): self.text = text super().__init__(id=id) - def render(self) -> str: + def render(self, style: Style) -> str: return self.text widget1 = TextWidget("foo", id="widget1") diff --git a/tests/utilities/test_app.py b/tests/utilities/test_app.py index 9bfb8d5d6..763be0ddc 100644 --- a/tests/utilities/test_app.py +++ b/tests/utilities/test_app.py @@ -4,9 +4,10 @@ import asyncio import contextlib import io from pathlib import Path -from typing import AsyncContextManager +from typing import AsyncContextManager, cast + +from rich.console import Console -from rich.console import Console, Capture from textual import events from textual.app import App, ReturnType, ComposeResult from textual.driver import Driver @@ -16,6 +17,9 @@ from textual.geometry import Size # N.B. These classes would better be named TestApp/TestConsole/TestDriver/etc, # but it makes pytest emit warning as it will try to collect them as classes containing test cases :-/ +# This value is also hard-coded in Textual's `App` class: +CLEAR_SCREEN_SEQUENCE = "\x1bP=1s\x1b\\" + class AppTest(App): def __init__( @@ -25,7 +29,7 @@ class AppTest(App): size: Size, log_verbosity: int = 2, ): - # will log in "/tests/test.[test name].log": + # Tests will log in "/tests/test.[test name].log": log_path = Path(__file__).parent.parent / f"test.{test_name}.log" super().__init__( driver_class=DriverTest, @@ -33,6 +37,11 @@ class AppTest(App): log_verbosity=log_verbosity, log_color_system="256", ) + + # We need this so the `CLEAR_SCREEN_SEQUENCE` is always sent for a screen refresh, + # whatever the environment: + self._sync_available = True + self._size = size self._console = ConsoleTest(width=size.width, height=size.height) self._error_console = ConsoleTest(width=size.width, height=size.height) @@ -49,16 +58,18 @@ class AppTest(App): def in_running_state( self, *, - initialisation_timeout: float = 0.1, - ) -> AsyncContextManager[Capture]: + waiting_duration_after_initialisation: float = 0.1, + waiting_duration_post_yield: float = 0, + ) -> AsyncContextManager: async def run_app() -> None: await self.process_messages() @contextlib.asynccontextmanager async def get_running_state_context_manager(): + self._set_active() run_task = asyncio.create_task(run_app()) timeout_before_yielding_task = asyncio.create_task( - asyncio.sleep(initialisation_timeout) + asyncio.sleep(waiting_duration_after_initialisation) ) done, pending = await asyncio.wait( ( @@ -69,10 +80,11 @@ class AppTest(App): ) if run_task in done or run_task not in pending: raise RuntimeError( - "TestApp is no longer return after its initialization period" + "TestApp is no longer running after its initialization period" ) - with self.console.capture() as capture: - yield capture + yield + if waiting_duration_post_yield: + await asyncio.sleep(waiting_duration_post_yield) assert not run_task.done() await self.shutdown() @@ -83,6 +95,18 @@ class AppTest(App): "Use `async with my_test_app.in_running_state()` rather than `my_test_app.run()`" ) + @property + def total_capture(self) -> str | None: + return self.console.file.getvalue() + + @property + def last_display_capture(self) -> str | None: + total_capture = self.total_capture + if not total_capture: + return None + last_display_start_index = total_capture.rindex(CLEAR_SCREEN_SEQUENCE) + return total_capture[last_display_start_index:] + @property def console(self) -> ConsoleTest: return self._console @@ -110,10 +134,18 @@ class ConsoleTest(Console): file=file, width=width, height=height, - force_terminal=True, + force_terminal=False, legacy_windows=False, ) + @property + def file(self) -> io.StringIO: + return cast(io.StringIO, self._file) + + @property + def is_dumb_terminal(self) -> bool: + return False + class DriverTest(Driver): def start_application_mode(self) -> None: