Merge branch 'css' of github.com:Textualize/textual into text-input

This commit is contained in:
Darren Burns
2022-05-13 10:27:39 +01:00
47 changed files with 835 additions and 310 deletions

View File

@@ -1,6 +1,7 @@
from datetime import datetime from datetime import datetime
from rich.align import Align from rich.align import Align
from rich.style import Style
from textual.app import App from textual.app import App
from textual.widget import Widget from textual.widget import Widget
@@ -10,7 +11,7 @@ class Clock(Widget):
def on_mount(self): def on_mount(self):
self.set_interval(1, self.refresh) self.set_interval(1, self.refresh)
def render(self): def render(self, style: Style):
time = datetime.now().strftime("%c") time = datetime.now().strftime("%c")
return Align.center(time, vertical="middle") return Align.center(time, vertical="middle")

View File

@@ -1,4 +1,5 @@
from rich.panel import Panel from rich.panel import Panel
from rich.style import Style
from textual.app import App from textual.app import App
from textual.reactive import Reactive from textual.reactive import Reactive
@@ -9,7 +10,7 @@ class Hover(Widget):
mouse_over = Reactive(False) mouse_over = Reactive(False)
def render(self) -> Panel: def render(self, style: Style) -> Panel:
return Panel("Hello [b]World[/b]", style=("on red" if self.mouse_over else "")) return Panel("Hello [b]World[/b]", style=("on red" if self.mouse_over else ""))
def on_enter(self) -> None: def on_enter(self) -> None:

View File

@@ -1,7 +1,7 @@
from pathlib import Path from pathlib import Path
from rich.align import Align
from rich.console import RenderableType from rich.console import RenderableType
from rich.style import Style
from rich.syntax import Syntax from rich.syntax import Syntax
from rich.text import Text from rich.text import Text
@@ -53,12 +53,12 @@ lorem = Text.from_markup(
class TweetHeader(Widget): class TweetHeader(Widget):
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
return Text("Lorem Impsum", justify="center") return Text("Lorem Impsum", justify="center")
class TweetBody(Widget): class TweetBody(Widget):
def render(self) -> Text: def render(self, style: Style) -> Text:
return lorem return lorem
@@ -67,22 +67,22 @@ class Tweet(Widget):
class OptionItem(Widget): class OptionItem(Widget):
def render(self) -> Text: def render(self, style: Style) -> Text:
return Text("Option") return Text("Option")
class Error(Widget): class Error(Widget):
def render(self) -> Text: def render(self, style: Style) -> Text:
return Text("This is an error message", justify="center") return Text("This is an error message", justify="center")
class Warning(Widget): class Warning(Widget):
def render(self) -> Text: def render(self, style: Style) -> Text:
return Text("This is a warning message", justify="center") return Text("This is a warning message", justify="center")
class Success(Widget): class Success(Widget):
def render(self) -> Text: def render(self, style: Style) -> Text:
return Text("This is a success message", justify="center") return Text("This is a success message", justify="center")

View File

@@ -1,11 +1,9 @@
from rich.console import Group
from rich.padding import Padding from rich.padding import Padding
from rich.style import Style
from rich.text import Text from rich.text import Text
from textual.app import App from textual.app import App
from textual.renderables.gradient import VerticalGradient from textual.renderables.gradient import VerticalGradient
from textual import events
from textual.widgets import Placeholder
from textual.widget import Widget from textual.widget import Widget
lorem = Text.from_markup( lorem = Text.from_markup(
@@ -15,12 +13,12 @@ lorem = Text.from_markup(
class Lorem(Widget): class Lorem(Widget):
def render(self) -> Text: def render(self, style: Style) -> Text:
return Padding(lorem, 1) return Padding(lorem, 1)
class Background(Widget): class Background(Widget):
def render(self): def render(self, style: Style):
return VerticalGradient("#212121", "#212121") return VerticalGradient("#212121", "#212121")

View File

@@ -9,13 +9,14 @@ from decimal import Decimal
from rich.align import Align from rich.align import Align
from rich.console import Console, ConsoleOptions, RenderResult, RenderableType from rich.console import Console, ConsoleOptions, RenderResult, RenderableType
from rich.padding import Padding from rich.padding import Padding
from rich.style import Style
from rich.text import Text from rich.text import Text
from textual.app import App from textual.app import App
from textual.reactive import Reactive from textual.reactive import Reactive
from textual.views import GridView from textual.views import GridView
from textual.widget import Widget from textual.widget import Widget
from textual.widgets import Button, ButtonPressed from textual.widgets import Button
try: try:
from pyfiglet import Figlet from pyfiglet import Figlet
@@ -55,7 +56,7 @@ class Numbers(Widget):
value = Reactive("0") value = Reactive("0")
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
"""Build a Rich renderable to render the calculator display.""" """Build a Rich renderable to render the calculator display."""
return Padding( return Padding(
Align.right(FigletText(self.value), vertical="middle"), Align.right(FigletText(self.value), vertical="middle"),

View File

@@ -47,6 +47,9 @@ includes = "src"
[tool.pytest.ini_options] [tool.pytest.ini_options]
asyncio_mode = "auto" asyncio_mode = "auto"
testpaths = ["tests"] testpaths = ["tests"]
markers = [
"integration_test: marks tests as slow integration tests(deselect with '-m \"not integration_test\"')",
]
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]

View File

@@ -1,4 +1,4 @@
from rich.text import Text from rich.style import Style
from textual.app import App, ComposeResult from textual.app import App, ComposeResult
from textual.widget import Widget from textual.widget import Widget
@@ -6,7 +6,7 @@ from textual.widgets import Static
class Thing(Widget): class Thing(Widget):
def render(self): def render(self, style: Style):
return "Hello, 3434 World.\n[b]Lorem impsum." return "Hello, 3434 World.\n[b]Lorem impsum."

View File

@@ -167,7 +167,6 @@ TweetBody {
OptionItem { OptionItem {
height: 3; height: 3;
background: $primary; background: $primary;
transition: background 100ms linear;
border-right: outer $primary-darken-2; border-right: outer $primary-darken-2;
border-left: hidden; border-left: hidden;
content-align: center middle; content-align: center middle;
@@ -224,4 +223,4 @@ Success {
.horizontal { .horizontal {
layout: horizontal layout: horizontal
} }

View File

@@ -1,4 +1,5 @@
from rich.console import RenderableType from rich.console import RenderableType
from rich.style import Style
from rich.syntax import Syntax from rich.syntax import Syntax
from rich.text import Text from rich.text import Text
@@ -50,12 +51,12 @@ lorem = Text.from_markup(
class TweetHeader(Widget): class TweetHeader(Widget):
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
return Text("Lorem Impsum", justify="center") return Text("Lorem Impsum", justify="center")
class TweetBody(Widget): class TweetBody(Widget):
def render(self) -> Text: def render(self, style: Style) -> Text:
return lorem return lorem
@@ -64,22 +65,22 @@ class Tweet(Widget):
class OptionItem(Widget): class OptionItem(Widget):
def render(self) -> Text: def render(self, style: Style) -> Text:
return Text("Option") return Text("Option")
class Error(Widget): class Error(Widget):
def render(self) -> Text: def render(self, style: Style) -> Text:
return Text("This is an error message", justify="center") return Text("This is an error message", justify="center")
class Warning(Widget): class Warning(Widget):
def render(self) -> Text: def render(self, style: Style) -> Text:
return Text("This is a warning message", justify="center") return Text("This is a warning message", justify="center")
class Success(Widget): class Success(Widget):
def render(self) -> Text: def render(self, style: Style) -> Text:
return Text("This is a success message", justify="center") return Text("This is a success message", justify="center")

View File

@@ -0,0 +1,8 @@
#foo {
text-style: underline;
background: rebeccapurple;
}
#foo:hover {
background: greenyellow;
}

View File

@@ -1,12 +1,13 @@
from rich.console import RenderableType from rich.console import RenderableType
from rich.panel import Panel from rich.panel import Panel
from rich.style import Style
from textual.app import App from textual.app import App
from textual.widget import Widget from textual.widget import Widget
class PanelWidget(Widget): class PanelWidget(Widget):
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
return Panel("hello world!", title="Title") return Panel("hello world!", title="Title")

View File

@@ -0,0 +1,72 @@
from rich.console import RenderableType
from rich.text import Text
from textual.app import App, ComposeResult
from textual.widget import Widget
from textual.widgets import Placeholder
placeholders_count = 12
class VerticalContainer(Widget):
CSS = """
VerticalContainer {
layout: vertical;
overflow: hidden auto;
background: darkblue;
}
VerticalContainer Placeholder {
margin: 1 0;
height: 5;
border: solid lime;
align: center top;
}
"""
class Introduction(Widget):
CSS = """
Introduction {
background: indigo;
color: white;
height: 3;
padding: 1 0;
}
"""
def render(self) -> RenderableType:
return Text(
"Press keys 0 to 9 to scroll to the Placeholder with that ID.",
justify="center",
)
class MyTestApp(App):
def compose(self) -> ComposeResult:
placeholders = [
Placeholder(id=f"placeholder_{i}", name=f"Placeholder #{i}")
for i in range(placeholders_count)
]
yield VerticalContainer(Introduction(), *placeholders, id="root")
def on_mount(self):
self.bind("q", "quit")
self.bind("t", "tree")
for widget_index in range(placeholders_count):
self.bind(str(widget_index), f"scroll_to('placeholder_{widget_index}')")
def action_tree(self):
self.log(self.tree)
async def action_scroll_to(self, target_placeholder_id: str):
target_placeholder = self.query(f"#{target_placeholder_id}").first()
target_placeholder_container = self.query("#root").first()
target_placeholder_container.scroll_to_widget(target_placeholder, animate=True)
app = MyTestApp()
if __name__ == "__main__":
app.run()

View File

@@ -3,6 +3,7 @@ from dataclasses import dataclass
from rich.console import RenderableType from rich.console import RenderableType
from rich.padding import Padding from rich.padding import Padding
from rich.rule import Rule from rich.rule import Rule
from rich.style import Style
from textual import events from textual import events
from textual.app import App from textual.app import App
@@ -11,7 +12,7 @@ from textual.widgets.tabs import Tabs, Tab
class Hr(Widget): class Hr(Widget):
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
return Rule() return Rule()
@@ -22,7 +23,7 @@ class Info(Widget):
super().__init__() super().__init__()
self.text = text self.text = text
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
return Padding(f"{self.text}", pad=(0, 1)) return Padding(f"{self.text}", pad=(0, 1))
@@ -144,4 +145,5 @@ class BasicApp(App):
self.mount(example.widget) self.mount(example.widget)
BasicApp.run(css_path="tabs.scss", watch_css=True, log_path="textual.log") app = BasicApp(css_path="tabs.scss", watch_css=True, log_path="textual.log")
app.run()

View File

@@ -7,14 +7,20 @@ App.-show-focus *:focus {
background: green; background: green;
overflow: hidden auto; overflow: hidden auto;
border: heavy white; border: heavy white;
text-style: underline;
} }
#uber1:focus-within { #uber1:focus-within {
background: darkslateblue; background: darkslateblue;
} }
#child2 {
text-style: underline;
background: red;
}
.list-item { .list-item {
height: 10; height: 20;
color: #12a0; color: #12a0;
background: #ffffff00; background: #ffffff00;
} }

View File

@@ -25,7 +25,6 @@ class BasicApp(App):
first_child = Placeholder(id="child1", classes="list-item") first_child = Placeholder(id="child1", classes="list-item")
uber1 = Widget( uber1 = Widget(
first_child, first_child,
Placeholder(id="child1", classes="list-item"),
Placeholder(id="child2", classes="list-item"), Placeholder(id="child2", classes="list-item"),
Placeholder(id="child3", classes="list-item"), Placeholder(id="child3", classes="list-item"),
Placeholder(classes="list-item"), Placeholder(classes="list-item"),
@@ -33,6 +32,7 @@ class BasicApp(App):
Placeholder(classes="list-item"), Placeholder(classes="list-item"),
) )
self.mount(uber1=uber1) self.mount(uber1=uber1)
uber1.focus()
self.first_child = first_child self.first_child = first_child
self.uber = uber1 self.uber = uber1
@@ -50,9 +50,8 @@ class BasicApp(App):
def action_print(self): def action_print(self):
print( print(
"Printed using builtin [b blue]print[/] function:", "Focused widget is:",
self.screen.tree, self.focused,
sep=" - ",
) )
self.app.set_focus(None) self.app.set_focus(None)

View File

@@ -20,7 +20,7 @@ class VerticalContainer(Widget):
VerticalContainer Placeholder { VerticalContainer Placeholder {
margin: 1 0; margin: 1 0;
height: 3; height: 5;
border: solid lime; border: solid lime;
align: center top; align: center top;
} }
@@ -79,10 +79,10 @@ class MyTestApp(App):
placeholders = self.query("Placeholder") placeholders = self.query("Placeholder")
placeholders_count = len(placeholders) placeholders_count = len(placeholders)
placeholder = Placeholder( placeholder = Placeholder(
id=f"placeholder_{placeholders_count+1}", id=f"placeholder_{placeholders_count}",
name=f"Placeholder #{placeholders_count+1}", name=f"Placeholder #{placeholders_count}",
) )
root = self.query_one("#root") root = self.get_child("root")
root.mount(placeholder) root.mount(placeholder)
self.refresh(repaint=True, layout=True) self.refresh(repaint=True, layout=True)
self.refresh_css() self.refresh_css()

View File

@@ -15,12 +15,12 @@ from __future__ import annotations
from operator import attrgetter, itemgetter from operator import attrgetter, itemgetter
import sys import sys
from typing import cast, Iterator, Iterable, NamedTuple, TYPE_CHECKING from typing import Callable, cast, Iterator, Iterable, NamedTuple, TYPE_CHECKING
import rich.repr import rich.repr
from rich.console import Console, ConsoleOptions, RenderResult from rich.console import Console, ConsoleOptions, RenderResult, RenderableType
from rich.control import Control from rich.control import Control
from rich.segment import Segment, SegmentLines from rich.segment import Segment
from rich.style import Style from rich.style import Style
from . import errors from . import errors
@@ -50,17 +50,17 @@ class ReflowResult(NamedTuple):
resized: set[Widget] # Widgets that have been resized resized: set[Widget] # Widgets that have been resized
class RenderRegion(NamedTuple): class MapGeometry(NamedTuple):
"""Defines the absolute location of a Widget.""" """Defines the absolute location of a Widget."""
region: Region # The region occupied by the widget region: Region # The region occupied by the widget
order: tuple[int, ...] # A tuple of ints defining the painting order order: tuple[int, ...] # A tuple of ints defining the painting order
clip: Region # A region to clip the widget by (if a Widget is within a container) clip: Region # A region to clip the widget by (if a Widget is within a container)
virtual_size: Size # The virtual size (scrollable region) of a widget if it is a container virtual_size: Size # The virtual size (scrollable region) of a widget if it is a container
container_size: Size # The container size (area no occupied by scrollbars) container_size: Size # The container size (area not occupied by scrollbars)
RenderRegionMap: TypeAlias = "dict[Widget, RenderRegion]" CompositorMap: TypeAlias = "dict[Widget, MapGeometry]"
@rich.repr.auto @rich.repr.auto
@@ -78,6 +78,7 @@ class LayoutUpdate:
new_line = Segment.line() new_line = Segment.line()
move_to = Control.move_to move_to = Control.move_to
for last, (y, line) in loop_last(enumerate(self.lines, self.region.y)): for last, (y, line) in loop_last(enumerate(self.lines, self.region.y)):
yield Control.home()
yield move_to(x, y) yield move_to(x, y)
yield from line yield from line
if not last: if not last:
@@ -91,13 +92,40 @@ class LayoutUpdate:
yield "height", height yield "height", height
@rich.repr.auto
class SpansUpdate:
"""A renderable that applies updated spans to the screen."""
def __init__(self, spans: list[tuple[int, int, list[Segment]]]) -> None:
"""Apply spans, which consist of a tuple of (LINE, OFFSET, SEGMENTS)
Args:
spans (list[tuple[int, int, list[Segment]]]): A list of spans.
"""
self.spans = spans
def __rich_console__(
self, console: Console, options: ConsoleOptions
) -> RenderResult:
move_to = Control.move_to
new_line = Segment.line()
for last, (y, x, segments) in loop_last(self.spans):
yield move_to(x, y)
yield from segments
if not last:
yield new_line
def __rich_repr__(self) -> rich.repr.Result:
yield [(y, x, "...") for y, x, _segments in self.spans]
@rich.repr.auto(angular=True) @rich.repr.auto(angular=True)
class Compositor: class Compositor:
"""Responsible for storing information regarding the relative positions of Widgets and rendering them.""" """Responsible for storing information regarding the relative positions of Widgets and rendering them."""
def __init__(self) -> None: def __init__(self) -> None:
# A mapping of Widget on to its "render location" (absolute position / depth) # A mapping of Widget on to its "render location" (absolute position / depth)
self.map: RenderRegionMap = {} self.map: CompositorMap = {}
# All widgets considered in the arrangement # All widgets considered in the arrangement
# Note this may be a superset of self.map.keys() as some widgets may be invisible for various reasons # Note this may be a superset of self.map.keys() as some widgets may be invisible for various reasons
@@ -116,6 +144,42 @@ class Compositor:
# The points in each line where the line bisects the left and right edges of the widget # The points in each line where the line bisects the left and right edges of the widget
self._cuts: list[list[int]] | None = None self._cuts: list[list[int]] | None = None
@classmethod
def _regions_to_spans(
cls, regions: Iterable[Region]
) -> Iterable[tuple[int, int, int]]:
"""Converts the regions to horizontal spans. Spans will be combined if they overlap
or are contiguous to produce optimal non-overlapping spans.
Args:
regions (Iterable[Region]): An iterable of Regions.
Returns:
Iterable[tuple[int, int, int]]: Yields tuples of (Y, X1, X2)
"""
inline_ranges: dict[int, list[tuple[int, int]]] = {}
for region_x, region_y, width, height in regions:
span = (region_x, region_x + width)
for y in range(region_y, region_y + height):
inline_ranges.setdefault(y, []).append(span)
for y, ranges in sorted(inline_ranges.items()):
if len(ranges) == 1:
# Special case of 1 span
yield (y, *ranges[0])
else:
ranges.sort()
x1, x2 = ranges[0]
for next_x1, next_x2 in ranges[1:]:
if next_x1 <= x2:
if next_x2 > x2:
x2 = next_x2
else:
yield (y, x1, x2)
x1 = next_x1
x2 = next_x2
yield (y, x1, x2)
def __rich_repr__(self) -> rich.repr.Result: def __rich_repr__(self) -> rich.repr.Result:
yield "size", self.size yield "size", self.size
yield "widgets", self.widgets yield "widgets", self.widgets
@@ -167,7 +231,7 @@ class Compositor:
resized=resized_widgets, resized=resized_widgets,
) )
def _arrange_root(self, root: Widget) -> tuple[RenderRegionMap, set[Widget]]: def _arrange_root(self, root: Widget) -> tuple[CompositorMap, set[Widget]]:
"""Arrange a widgets children based on its layout attribute. """Arrange a widgets children based on its layout attribute.
Args: Args:
@@ -180,7 +244,7 @@ class Compositor:
ORIGIN = Offset(0, 0) ORIGIN = Offset(0, 0)
size = root.size size = root.size
map: RenderRegionMap = {} map: CompositorMap = {}
widgets: set[Widget] = set() widgets: set[Widget] = set()
get_order = attrgetter("order") get_order = attrgetter("order")
@@ -249,7 +313,7 @@ class Compositor:
for chrome_widget, chrome_region in widget._arrange_scrollbars( for chrome_widget, chrome_region in widget._arrange_scrollbars(
container_size container_size
): ):
map[chrome_widget] = RenderRegion( map[chrome_widget] = MapGeometry(
chrome_region + container_region.origin + layout_offset, chrome_region + container_region.origin + layout_offset,
order, order,
clip, clip,
@@ -258,7 +322,7 @@ class Compositor:
) )
# Add the container widget, which will render a background # Add the container widget, which will render a background
map[widget] = RenderRegion( map[widget] = MapGeometry(
region + layout_offset, region + layout_offset,
order, order,
clip, clip,
@@ -268,7 +332,7 @@ class Compositor:
else: else:
# Add the widget to the map # Add the widget to the map
map[widget] = RenderRegion( map[widget] = MapGeometry(
region + layout_offset, order, clip, region.size, container_size region + layout_offset, order, clip, region.size, container_size
) )
@@ -338,8 +402,8 @@ class Compositor:
return segment.style or Style.null() return segment.style or Style.null()
return Style.null() return Style.null()
def get_widget_region(self, widget: Widget) -> Region: def find_widget(self, widget: Widget) -> MapGeometry:
"""Get the Region of a Widget contained in this Layout. """Get information regarding the relative position of a widget in the Compositor.
Args: Args:
widget (Widget): The Widget in this layout you wish to know the Region of. widget (Widget): The Widget in this layout you wish to know the Region of.
@@ -348,11 +412,11 @@ class Compositor:
NoWidget: If the Widget is not contained in this Layout. NoWidget: If the Widget is not contained in this Layout.
Returns: Returns:
Region: The Region of the Widget. MapGeometry: Widget's composition information.
""" """
try: try:
region, *_ = self.map[widget] region = self.map[widget]
except KeyError: except KeyError:
raise errors.NoWidget("Widget is not in layout") raise errors.NoWidget("Widget is not in layout")
else: else:
@@ -452,11 +516,7 @@ class Compositor:
] ]
return segment_lines return segment_lines
def render( def render(self, regions: list[Region] | None = None) -> RenderableType:
self,
*,
crop: Region | None = None,
) -> SegmentLines:
"""Render a layout. """Render a layout.
Args: Args:
@@ -467,8 +527,15 @@ class Compositor:
""" """
width, height = self.size width, height = self.size
screen_region = Region(0, 0, width, height) screen_region = Region(0, 0, width, height)
if regions:
crop_region = crop.intersection(screen_region) if crop else screen_region # Create a crop regions that surrounds all updates
crop = Region.from_union(regions).intersection(screen_region)
spans = list(self._regions_to_spans(regions))
is_rendered_line = {y for y, _, _ in spans}.__contains__
else:
crop = screen_region
spans = []
is_rendered_line = lambda y: True
_Segment = Segment _Segment = Segment
divide = _Segment.divide divide = _Segment.divide
@@ -480,9 +547,8 @@ class Compositor:
"Callable[[list[int]], dict[int, list[Segment] | None]]", dict.fromkeys "Callable[[list[int]], dict[int, list[Segment] | None]]", dict.fromkeys
) )
# A mapping of cut index to a list of segments for each line # A mapping of cut index to a list of segments for each line
chops: list[dict[int, list[Segment] | None]] = [ chops: list[dict[int, list[Segment] | None]]
fromkeys(cut_set) for cut_set in cuts chops = [fromkeys(cut_set) for cut_set in cuts]
]
# Go through all the renders in reverse order and fill buckets with no render # Go through all the renders in reverse order and fill buckets with no render
renders = self._get_renders(crop) renders = self._get_renders(crop)
@@ -492,6 +558,8 @@ class Compositor:
render_region = intersection(region, clip) render_region = intersection(region, clip)
for y, line in zip(render_region.y_range, lines): for y, line in zip(render_region.y_range, lines):
if not is_rendered_line(y):
continue
first_cut, last_cut = render_region.x_extents first_cut, last_cut = render_region.x_extents
final_cuts = [cut for cut in cuts[y] if (last_cut >= cut >= first_cut)] final_cuts = [cut for cut in cuts[y] if (last_cut >= cut >= first_cut)]
@@ -501,6 +569,7 @@ class Compositor:
else: else:
render_x = render_region.x render_x = render_region.x
relative_cuts = [cut - render_x for cut in final_cuts] relative_cuts = [cut - render_x for cut in final_cuts]
# print(relative_cuts)
_, *cut_segments = divide(line, relative_cuts) _, *cut_segments = divide(line, relative_cuts)
# Since we are painting front to back, the first segments for a cut "wins" # Since we are painting front to back, the first segments for a cut "wins"
@@ -509,24 +578,25 @@ class Compositor:
if chops_line[cut] is None: if chops_line[cut] is None:
chops_line[cut] = segments chops_line[cut] = segments
# Assemble the cut renders in to lists of segments if regions:
crop_x, crop_y, crop_x2, crop_y2 = crop_region.corners crop_y, crop_y2 = crop.y_extents
render_lines = self._assemble_chops(chops[crop_y:crop_y2]) render_lines = self._assemble_chops(chops[crop_y:crop_y2])
render_spans = [
if crop is not None and (crop_x, crop_x2) != (0, width): (y, x1, line_crop(render_lines[y - crop_y], x1, x2))
render_lines = [ for y, x1, x2 in spans
line_crop(line, crop_x, crop_x2) if line else line
for line in render_lines
] ]
return SpansUpdate(render_spans)
return SegmentLines(render_lines, new_lines=True) else:
render_lines = self._assemble_chops(chops)
return LayoutUpdate(render_lines, screen_region)
def __rich_console__( def __rich_console__(
self, console: Console, options: ConsoleOptions self, console: Console, options: ConsoleOptions
) -> RenderResult: ) -> RenderResult:
yield self.render() yield self.render()
def update_widget(self, console: Console, widget: Widget) -> LayoutUpdate | None: def update_widgets(self, widgets: set[Widget]) -> RenderableType | None:
"""Update a given widget in the composition. """Update a given widget in the composition.
Args: Args:
@@ -536,14 +606,12 @@ class Compositor:
Returns: Returns:
LayoutUpdate | None: A renderable or None if nothing to render. LayoutUpdate | None: A renderable or None if nothing to render.
""" """
if widget not in self.regions: regions: list[Region] = []
return None add_region = regions.append
region, clip = self.regions[widget] for widget in self.regions.keys() & widgets:
if not region: region, clip = self.regions[widget]
return None update_region = region.intersection(clip)
update_region = region.intersection(clip) if update_region:
if not update_region: add_region(update_region)
return None update = self.render(regions or None)
update_lines = self.render(crop=update_region).lines
update = LayoutUpdate(update_lines, update_region)
return update return update

View File

@@ -24,6 +24,9 @@ class Layout(ABC):
name: ClassVar[str] = "" name: ClassVar[str] = ""
def __repr__(self) -> str:
return f"<{self.name}>"
@abstractmethod @abstractmethod
def arrange( def arrange(
self, parent: Widget, size: Size, scroll: Offset self, parent: Widget, size: Size, scroll: Offset

View File

@@ -1,48 +0,0 @@
from __future__ import annotations
from collections import defaultdict
from operator import attrgetter
from typing import NamedTuple, Iterable
from .geometry import Region
class InlineRange(NamedTuple):
"""Represents a region on a single line."""
line_index: int
start: int
end: int
def regions_to_ranges(regions: Iterable[Region]) -> Iterable[InlineRange]:
"""Converts the regions to non-overlapping horizontal strips, where each strip
represents the region on a single line. Combining the resulting strips therefore
results in a shape identical to the combined original regions.
Args:
regions (Iterable[Region]): An iterable of Regions.
Returns:
Iterable[InlineRange]: Yields InlineRange objects representing the content on
a single line, with overlaps removed.
"""
inline_ranges: dict[int, list[InlineRange]] = defaultdict(list)
for region_x, region_y, width, height in regions:
for y in range(region_y, region_y + height):
inline_ranges[y].append(
InlineRange(line_index=y, start=region_x, end=region_x + width - 1)
)
get_start = attrgetter("start")
for line_index, ranges in inline_ranges.items():
sorted_ranges = iter(sorted(ranges, key=get_start))
_, start, end = next(sorted_ranges)
for next_line_index, next_start, next_end in sorted_ranges:
if next_start <= end + 1:
end = max(end, next_end)
else:
yield InlineRange(line_index, start, end)
start = next_start
end = next_end
yield InlineRange(line_index, start, end)

View File

@@ -33,6 +33,7 @@ from rich.measure import Measurement
from rich.protocol import is_renderable from rich.protocol import is_renderable
from rich.screen import Screen as ScreenRenderable from rich.screen import Screen as ScreenRenderable
from rich.segment import Segments from rich.segment import Segments
from rich.style import Style
from rich.traceback import Traceback from rich.traceback import Traceback
from . import actions from . import actions
@@ -143,6 +144,9 @@ class App(Generic[ReturnType], DOMNode):
self.driver_class = driver_class or self.get_driver_class() self.driver_class = driver_class or self.get_driver_class()
self._title = title self._title = title
self._screen_stack: list[Screen] = [] self._screen_stack: list[Screen] = []
self._sync_available = (
os.environ.get("TERM_PROGRAM", "") != "Apple_Terminal" and not WINDOWS
)
self.focused: Widget | None = None self.focused: Widget | None = None
self.mouse_over: Widget | None = None self.mouse_over: Widget | None = None
@@ -478,7 +482,7 @@ class App(Generic[ReturnType], DOMNode):
self.stylesheet.update(self) self.stylesheet.update(self)
self.screen.refresh(layout=True) self.screen.refresh(layout=True)
def render(self) -> RenderableType: def render(self, styles: Style) -> RenderableType:
return "" return ""
def query(self, selector: str | None = None) -> DOMQuery: def query(self, selector: str | None = None) -> DOMQuery:
@@ -639,6 +643,7 @@ class App(Generic[ReturnType], DOMNode):
def fatal_error(self) -> None: def fatal_error(self) -> None:
"""Exits the app after an unhandled exception.""" """Exits the app after an unhandled exception."""
self.console.bell()
traceback = Traceback( traceback = Traceback(
show_locals=True, width=None, locals_max_length=5, suppress=[rich] show_locals=True, width=None, locals_max_length=5, suppress=[rich]
) )
@@ -806,20 +811,19 @@ class App(Generic[ReturnType], DOMNode):
def refresh(self, *, repaint: bool = True, layout: bool = False) -> None: def refresh(self, *, repaint: bool = True, layout: bool = False) -> None:
if not self._running: if not self._running:
return return
sync_available = (
os.environ.get("TERM_PROGRAM", "") != "Apple_Terminal" and not WINDOWS
)
if not self._closed: if not self._closed:
console = self.console console = self.console
try: try:
if sync_available: if self._sync_available:
console.file.write("\x1bP=1s\x1b\\") console.file.write("\x1bP=1s\x1b\\")
console.print( console.print(
ScreenRenderable( ScreenRenderable(
Control.home(), self.screen._compositor, Control.home() Control.home(),
self.screen._compositor,
Control.home(),
) )
) )
if sync_available: if self._sync_available:
console.file.write("\x1bP=2s\x1b\\") console.file.write("\x1bP=2s\x1b\\")
console.file.flush() console.file.flush()
except Exception as error: except Exception as error:
@@ -942,14 +946,14 @@ class App(Generic[ReturnType], DOMNode):
action_target = default_namespace or self action_target = default_namespace or self
action_name = target action_name = target
log("action", action) log("<action>", action)
await self.dispatch_action(action_target, action_name, params) await self.dispatch_action(action_target, action_name, params)
async def dispatch_action( async def dispatch_action(
self, namespace: object, action_name: str, params: Any self, namespace: object, action_name: str, params: Any
) -> None: ) -> None:
log( log(
"dispatch_action", "<action>",
namespace=namespace, namespace=namespace,
action_name=action_name, action_name=action_name,
params=params, params=params,

View File

@@ -23,7 +23,7 @@ from rich.color import Color as RichColor
from rich.style import Style from rich.style import Style
from rich.text import Text from rich.text import Text
from textual.suggestions import get_suggestion
from ._color_constants import COLOR_NAME_TO_RGB from ._color_constants import COLOR_NAME_TO_RGB
from .geometry import clamp from .geometry import clamp
@@ -77,6 +77,17 @@ split_pairs4: Callable[[str], tuple[str, str, str, str]] = itemgetter(
class ColorParseError(Exception): class ColorParseError(Exception):
"""A color failed to parse""" """A color failed to parse"""
def __init__(self, message: str, suggested_color: str | None = None):
"""
Creates a new ColorParseError
Args:
message (str): the error message
suggested_color (str | None): a close color we can suggest. Defaults to None.
"""
super().__init__(message)
self.suggested_color = suggested_color
@rich.repr.auto @rich.repr.auto
class Color(NamedTuple): class Color(NamedTuple):
@@ -271,7 +282,14 @@ class Color(NamedTuple):
return cls(*color_from_name) return cls(*color_from_name)
color_match = RE_COLOR.match(color_text) color_match = RE_COLOR.match(color_text)
if color_match is None: if color_match is None:
raise ColorParseError(f"failed to parse {color_text!r} as a color") error_message = f"failed to parse {color_text!r} as a color"
suggested_color = None
if not color_text.startswith("#") and not color_text.startswith("rgb"):
# Seems like we tried to use a color name: let's try to find one that is close enough:
suggested_color = get_suggestion(color_text, COLOR_NAME_TO_RGB.keys())
if suggested_color:
error_message += f"; did you mean '{suggested_color}'?"
raise ColorParseError(error_message, suggested_color)
( (
rgb_hex_triple, rgb_hex_triple,
rgb_hex_quad, rgb_hex_quad,

View File

@@ -70,13 +70,13 @@ class HelpText:
Attributes: Attributes:
summary (str): A succinct summary of the issue. summary (str): A succinct summary of the issue.
bullets (Iterable[Bullet]): Bullet points which provide additional bullets (Iterable[Bullet] | None): Bullet points which provide additional
context around the issue. These are rendered below the summary. context around the issue. These are rendered below the summary. Defaults to None.
""" """
def __init__(self, summary: str, *, bullets: Iterable[Bullet]) -> None: def __init__(self, summary: str, *, bullets: Iterable[Bullet] = None) -> None:
self.summary = summary self.summary = summary
self.bullets = bullets self.bullets = bullets or []
def __rich_console__( def __rich_console__(
self, console: Console, options: ConsoleOptions self, console: Console, options: ConsoleOptions

View File

@@ -4,6 +4,7 @@ import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable from typing import Iterable
from textual.color import ColorParseError
from textual.css._help_renderables import Example, Bullet, HelpText from textual.css._help_renderables import Example, Bullet, HelpText
from textual.css.constants import ( from textual.css.constants import (
VALID_BORDER, VALID_BORDER,
@@ -144,13 +145,13 @@ def property_invalid_value_help_text(
HelpText: Renderable for displaying the help text for this property HelpText: Renderable for displaying the help text for this property
""" """
property_name = _contextualize_property_name(property_name, context) property_name = _contextualize_property_name(property_name, context)
bullets = [] summary = f"Invalid CSS property [i]{property_name}[/]"
if suggested_property_name: if suggested_property_name:
suggested_property_name = _contextualize_property_name( suggested_property_name = _contextualize_property_name(
suggested_property_name, context suggested_property_name, context
) )
bullets.append(Bullet(f'Did you mean "{suggested_property_name}"?')) summary += f'. Did you mean "{suggested_property_name}"?'
return HelpText(f"Invalid CSS property [i]{property_name}[/]", bullets=bullets) return HelpText(summary)
def spacing_wrong_number_of_values_help_text( def spacing_wrong_number_of_values_help_text(
@@ -303,6 +304,8 @@ def string_enum_help_text(
def color_property_help_text( def color_property_help_text(
property_name: str, property_name: str,
context: StylingContext, context: StylingContext,
*,
error: Exception = None,
) -> HelpText: ) -> HelpText:
"""Help text to show when the user supplies an invalid value for a color """Help text to show when the user supplies an invalid value for a color
property. For example, an unparseable color string. property. For example, an unparseable color string.
@@ -310,13 +313,20 @@ def color_property_help_text(
Args: Args:
property_name (str): The name of the property property_name (str): The name of the property
context (StylingContext | None): The context the property is being used in. context (StylingContext | None): The context the property is being used in.
error (ColorParseError | None): The error that caused this help text to be displayed. Defaults to None.
Returns: Returns:
HelpText: Renderable for displaying the help text for this property HelpText: Renderable for displaying the help text for this property
""" """
property_name = _contextualize_property_name(property_name, context) property_name = _contextualize_property_name(property_name, context)
summary = f"Invalid value for the [i]{property_name}[/] property"
suggested_color = (
error.suggested_color if error and isinstance(error, ColorParseError) else None
)
if suggested_color:
summary += f'. Did you mean "{suggested_color}"?'
return HelpText( return HelpText(
summary=f"Invalid value for the [i]{property_name}[/] property", summary=summary,
bullets=[ bullets=[
Bullet( Bullet(
f"The [i]{property_name}[/] property can only be set to a valid color" f"The [i]{property_name}[/] property can only be set to a valid color"

View File

@@ -782,10 +782,12 @@ class ColorProperty:
elif isinstance(color, str): elif isinstance(color, str):
try: try:
parsed_color = Color.parse(color) parsed_color = Color.parse(color)
except ColorParseError: except ColorParseError as error:
raise StyleValueError( raise StyleValueError(
f"Invalid color value '{color}'", f"Invalid color value '{color}'",
help_text=color_property_help_text(self.name, context="inline"), help_text=color_property_help_text(
self.name, context="inline", error=error
),
) )
if obj.set_rule(self.name, parsed_color): if obj.set_rule(self.name, parsed_color):
obj.refresh() obj.refresh()

View File

@@ -572,9 +572,11 @@ class StylesBuilder:
elif token.name in ("color", "token"): elif token.name in ("color", "token"):
try: try:
color = Color.parse(token.value) color = Color.parse(token.value)
except Exception: except Exception as error:
self.error( self.error(
name, token, color_property_help_text(name, context="css") name,
token,
color_property_help_text(name, context="css", error=error),
) )
else: else:
self.error(name, token, color_property_help_text(name, context="css")) self.error(name, token, color_property_help_text(name, context="css"))

View File

@@ -241,7 +241,13 @@ class StylesBase(ABC):
Returns: Returns:
Spacing: Space around widget. Spacing: Space around widget.
""" """
spacing = Spacing() + self.padding + self.border.spacing spacing = self.padding + self.border.spacing
return spacing
@property
def content_gutter(self) -> Spacing:
"""The spacing that surrounds the content area of the widget."""
spacing = self.padding + self.border.spacing + self.margin
return spacing return spacing
@abstractmethod @abstractmethod

View File

@@ -7,7 +7,7 @@ from pathlib import Path, PurePath
from typing import cast, Iterable from typing import cast, Iterable
import rich.repr import rich.repr
from rich.console import RenderableType, Console, ConsoleOptions from rich.console import RenderableType, RenderResult, Console, ConsoleOptions
from rich.highlighter import ReprHighlighter from rich.highlighter import ReprHighlighter
from rich.markup import render from rich.markup import render
from rich.padding import Padding from rich.padding import Padding
@@ -68,10 +68,10 @@ class StylesheetErrors:
def __rich_console__( def __rich_console__(
self, console: Console, options: ConsoleOptions self, console: Console, options: ConsoleOptions
) -> RenderableType: ) -> RenderResult:
error_count = 0 error_count = 0
for rule in self.rules: for rule in self.rules:
for is_last, (token, message) in loop_last(rule.errors): for token, message in rule.errors:
error_count += 1 error_count += 1
if token.path: if token.path:
@@ -297,7 +297,6 @@ class Stylesheet:
for name, specificity_rules in rule_attributes.items() for name, specificity_rules in rule_attributes.items()
}, },
) )
self.replace_rules(node, node_rules, animate=animate) self.replace_rules(node, node_rules, animate=animate)
@classmethod @classmethod
@@ -363,8 +362,9 @@ class Stylesheet:
setattr(base_styles, key, new_value) setattr(base_styles, key, new_value)
else: else:
# Not animated, so we apply the rules directly # Not animated, so we apply the rules directly
get_rule = rules.get
for key in modified_rule_keys: for key in modified_rule_keys:
setattr(base_styles, key, rules.get(key)) setattr(base_styles, key, get_rule(key))
def update(self, root: DOMNode, animate: bool = False) -> None: def update(self, root: DOMNode, animate: bool = False) -> None:
"""Update a node and its children.""" """Update a node and its children."""

View File

@@ -395,9 +395,9 @@ class Blur(Event, bubble=False):
pass pass
class DescendantFocus(Event, bubble=True): class DescendantFocus(Event, verbosity=2, bubble=True):
pass pass
class DescendantBlur(Event, bubble=True): class DescendantBlur(Event, verbosity=2, bubble=True):
pass pass

View File

@@ -6,7 +6,7 @@ Functions and classes to manage terminal geometry (anything involving coordinate
from __future__ import annotations from __future__ import annotations
from typing import Any, cast, NamedTuple, Tuple, Union, TypeVar from typing import Any, cast, Iterable, NamedTuple, Tuple, Union, TypeVar
SpacingDimensions = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int, int]] SpacingDimensions = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int, int]]
@@ -181,6 +181,24 @@ class Region(NamedTuple):
width: int = 0 width: int = 0
height: int = 0 height: int = 0
@classmethod
def from_union(cls, regions: list[Region]) -> Region:
"""Create a Region from the union of other regions.
Args:
regions (Iterable[Region]): One or more regions.
Returns:
Region: A Region that encloses all other regions.
"""
if not regions:
raise ValueError("At least one region expected")
min_x = min([region.x for region in regions])
max_x = max([x + width for x, _y, width, _height in regions])
min_y = min([region.y for region in regions])
max_y = max([y + height for _x, y, _width, height in regions])
return cls(min_x, min_y, max_x - min_x, max_y - min_y)
@classmethod @classmethod
def from_corners(cls, x1: int, y1: int, x2: int, y2: int) -> Region: def from_corners(cls, x1: int, y1: int, x2: int, y2: int) -> Region:
"""Construct a Region form the top left and bottom right corners. """Construct a Region form the top left and bottom right corners.
@@ -257,6 +275,24 @@ class Region(NamedTuple):
"""Get the start point of the region.""" """Get the start point of the region."""
return Offset(self.x, self.y) return Offset(self.x, self.y)
@property
def bottom_left(self) -> Offset:
"""Bottom left offset of the region."""
x, y, _width, height = self
return Offset(x, y + height)
@property
def top_right(self) -> Offset:
"""Top right offset of the region."""
x, y, width, _height = self
return Offset(x + width, y)
@property
def bottom_right(self) -> Offset:
"""Bottom right of the region."""
x, y, width, height = self
return Offset(x + width, y + height)
@property @property
def size(self) -> Size: def size(self) -> Size:
"""Get the size of the region.""" """Get the size of the region."""
@@ -274,17 +310,17 @@ class Region(NamedTuple):
@property @property
def x_range(self) -> range: def x_range(self) -> range:
"""A range object for X coordinates""" """A range object for X coordinates."""
return range(self.x, self.x + self.width) return range(self.x, self.x + self.width)
@property @property
def y_range(self) -> range: def y_range(self) -> range:
"""A range object for Y coordinates""" """A range object for Y coordinates."""
return range(self.y, self.y + self.height) return range(self.y, self.y + self.height)
@property @property
def reset_origin(self) -> Region: def reset_origin(self) -> Region:
"""An region of the same size at the origin.""" """An region of the same size at (0, 0)."""
_, _, width, height = self _, _, width, height = self
return Region(0, 0, width, height) return Region(0, 0, width, height)

View File

@@ -149,8 +149,11 @@ class MessagePump:
callback: TimerCallback = None, callback: TimerCallback = None,
*, *,
name: str | None = None, name: str | None = None,
pause: bool = False,
) -> Timer: ) -> Timer:
timer = Timer(self, delay, self, name=name, callback=callback, repeat=0) timer = Timer(
self, delay, self, name=name, callback=callback, repeat=0, pause=pause
)
self._child_tasks.add(timer.start()) self._child_tasks.add(timer.start())
return timer return timer
@@ -161,9 +164,16 @@ class MessagePump:
*, *,
name: str | None = None, name: str | None = None,
repeat: int = 0, repeat: int = 0,
pause: bool = False,
): ):
timer = Timer( timer = Timer(
self, interval, self, name=name, callback=callback, repeat=repeat or None self,
interval,
self,
name=name,
callback=callback,
repeat=repeat or None,
pause=pause,
) )
self._child_tasks.add(timer.start()) self._child_tasks.add(timer.start())
return timer return timer

View File

@@ -8,7 +8,7 @@ from rich.style import Style
from . import events, messages, errors from . import events, messages, errors
from .geometry import Offset, Region from .geometry import Offset, Region
from ._compositor import Compositor from ._compositor import Compositor, MapGeometry
from .reactive import Reactive from .reactive import Reactive
from .widget import Widget from .widget import Widget
@@ -18,14 +18,14 @@ class Screen(Widget):
"""A widget for the root of the app.""" """A widget for the root of the app."""
CSS = """ CSS = """
Screen { Screen {
layout: vertical; layout: vertical;
overflow-y: auto; overflow-y: auto;
background: $surface; background: $surface;
color: $text-surface; color: $text-surface;
} }
""" """
dark = Reactive(False) dark = Reactive(False)
@@ -33,13 +33,13 @@ class Screen(Widget):
def __init__(self, name: str | None = None, id: str | None = None) -> None: def __init__(self, name: str | None = None, id: str | None = None) -> None:
super().__init__(name=name, id=id) super().__init__(name=name, id=id)
self._compositor = Compositor() self._compositor = Compositor()
self._dirty_widgets: list[Widget] = [] self._dirty_widgets: set[Widget] = set()
def watch_dark(self, dark: bool) -> None: def watch_dark(self, dark: bool) -> None:
pass pass
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
return self.app.render() return self.app.render(style)
def get_offset(self, widget: Widget) -> Offset: def get_offset(self, widget: Widget) -> Offset:
"""Get the absolute offset of a given Widget. """Get the absolute offset of a given Widget.
@@ -76,7 +76,7 @@ class Screen(Widget):
""" """
return self._compositor.get_style_at(x, y) return self._compositor.get_style_at(x, y)
def get_widget_region(self, widget: Widget) -> Region: def find_widget(self, widget: Widget) -> MapGeometry:
"""Get the screen region of a Widget. """Get the screen region of a Widget.
Args: Args:
@@ -85,24 +85,32 @@ class Screen(Widget):
Returns: Returns:
Region: Region relative to screen. Region: Region relative to screen.
""" """
return self._compositor.get_widget_region(widget) return self._compositor.find_widget(widget)
def on_idle(self, event: events.Idle) -> None: def on_idle(self, event: events.Idle) -> None:
# Check for any widgets marked as 'dirty' (needs a repaint) # Check for any widgets marked as 'dirty' (needs a repaint)
if self._dirty_widgets: if self._dirty_widgets:
for widget in self._dirty_widgets: self._update_timer.resume()
# Repaint widgets
# TODO: Combine these in to a single update. def _on_update(self) -> None:
display_update = self._compositor.update_widget(self.console, widget) """Called by the _update_timer."""
if display_update is not None:
self.app.display(display_update) # Render widgets together
# Reset dirty list if self._dirty_widgets:
self.log(dirty=self._dirty_widgets)
display_update = self._compositor.update_widgets(self._dirty_widgets)
if display_update is not None:
self.app.display(display_update)
self._dirty_widgets.clear() self._dirty_widgets.clear()
self._update_timer.pause()
def refresh_layout(self) -> None: def refresh_layout(self) -> None:
"""Refresh the layout (can change size and positions of widgets).""" """Refresh the layout (can change size and positions of widgets)."""
if not self.size: if not self.size:
return return
# This paint the entire screen, so replaces the batched dirty widgets
self._update_timer.pause()
self._dirty_widgets.clear()
try: try:
hidden, shown, resized = self._compositor.reflow(self, self.size) hidden, shown, resized = self._compositor.reflow(self, self.size)
@@ -133,30 +141,31 @@ class Screen(Widget):
self.app.on_exception(error) self.app.on_exception(error)
return return
self.app.refresh() self.app.refresh()
self._dirty_widgets.clear()
async def handle_update(self, message: messages.Update) -> None: async def handle_update(self, message: messages.Update) -> None:
message.stop() message.stop()
widget = message.widget widget = message.widget
assert isinstance(widget, Widget) assert isinstance(widget, Widget)
self._dirty_widgets.append(widget) self._dirty_widgets.add(widget)
self.check_idle() self.check_idle()
async def handle_layout(self, message: messages.Layout) -> None: async def handle_layout(self, message: messages.Layout) -> None:
message.stop() message.stop()
self.refresh_layout() self.refresh_layout()
def on_mount(self, event: events.Mount) -> None:
self._update_timer = self.set_interval(1 / 20, self._on_update, pause=True)
async def on_resize(self, event: events.Resize) -> None: async def on_resize(self, event: events.Resize) -> None:
self.size_updated(event.size, event.virtual_size, event.container_size) self.size_updated(event.size, event.virtual_size, event.container_size)
self.refresh_layout() self.refresh_layout()
event.stop() event.stop()
async def _on_mouse_move(self, event: events.MouseMove) -> None: async def _on_mouse_move(self, event: events.MouseMove) -> None:
try: try:
if self.app.mouse_captured: if self.app.mouse_captured:
widget = self.app.mouse_captured widget = self.app.mouse_captured
region = self.get_widget_region(widget) region = self.find_widget(widget).region
else: else:
widget, region = self.get_widget_at(event.x, event.y) widget, region = self.get_widget_at(event.x, event.y)
except errors.NoWidget: except errors.NoWidget:
@@ -195,7 +204,7 @@ class Screen(Widget):
try: try:
if self.app.mouse_captured: if self.app.mouse_captured:
widget = self.app.mouse_captured widget = self.app.mouse_captured
region = self.get_widget_region(widget) region = self.find_widget(widget).region
else: else:
widget, region = self.get_widget_at(event.x, event.y) widget, region = self.get_widget_at(event.x, event.y)
except errors.NoWidget: except errors.NoWidget:

View File

@@ -205,9 +205,9 @@ class ScrollBar(Widget):
yield "window_size", self.window_size yield "window_size", self.window_size
yield "position", self.position yield "position", self.position
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
styles = self.parent.styles styles = self.parent.styles
style = Style( scrollbar_style = Style(
bgcolor=( bgcolor=(
styles.scrollbar_background_hover.rich_color styles.scrollbar_background_hover.rich_color
if self.mouse_over if self.mouse_over
@@ -224,12 +224,9 @@ class ScrollBar(Widget):
window_size=self.window_size, window_size=self.window_size,
position=self.position, position=self.position,
vertical=self.vertical, vertical=self.vertical,
style=style, style=scrollbar_style,
) )
async def on_event(self, event) -> None:
await super().on_event(event)
async def on_enter(self, event: events.Enter) -> None: async def on_enter(self, event: events.Enter) -> None:
self.mouse_over = True self.mouse_over = True
@@ -284,7 +281,6 @@ class ScrollBar(Widget):
if __name__ == "__main__": if __name__ == "__main__":
from rich.console import Console from rich.console import Console
from rich.segment import Segments
console = Console() console = Console()
bar = ScrollBarRender() bar = ScrollBarRender()

View File

@@ -26,6 +26,7 @@ from .box_model import BoxModel, get_box_model
from .color import Color from .color import Color
from ._context import active_app from ._context import active_app
from ._types import Lines from ._types import Lines
from .css.styles import Styles
from .dom import DOMNode from .dom import DOMNode
from .geometry import clamp, Offset, Region, Size from .geometry import clamp, Offset, Region, Size
from .layouts.vertical import VerticalLayout from .layouts.vertical import VerticalLayout
@@ -83,6 +84,7 @@ class Widget(DOMNode):
self._virtual_size = Size(0, 0) self._virtual_size = Size(0, 0)
self._container_size = Size(0, 0) self._container_size = Size(0, 0)
self._layout_required = False self._layout_required = False
self._repaint_required = False
self._default_layout = VerticalLayout() self._default_layout = VerticalLayout()
self._animate: BoundAnimator | None = None self._animate: BoundAnimator | None = None
self._reactive_watches: dict[str, Callable] = {} self._reactive_watches: dict[str, Callable] = {}
@@ -159,7 +161,7 @@ class Widget(DOMNode):
int: The optimal width of the content. int: The optimal width of the content.
""" """
console = self.app.console console = self.app.console
renderable = self.render() renderable = self.render(self.styles.rich_style)
measurement = Measurement.get(console, console.options, renderable) measurement = Measurement.get(console, console.options, renderable)
return measurement.maximum return measurement.maximum
@@ -176,7 +178,7 @@ class Widget(DOMNode):
Returns: Returns:
int: The height of the content. int: The height of the content.
""" """
renderable = self.render() renderable = self.render(self.styles.rich_style)
options = self.console.options.update_width(width) options = self.console.options.update_width(width)
segments = self.console.render(renderable, options) segments = self.console.render(renderable, options)
# Cheaper than counting the lines returned from render_lines! # Cheaper than counting the lines returned from render_lines!
@@ -272,6 +274,8 @@ class Widget(DOMNode):
self.show_horizontal_scrollbar = show_horizontal self.show_horizontal_scrollbar = show_horizontal
self.show_vertical_scrollbar = show_vertical self.show_vertical_scrollbar = show_vertical
self.horizontal_scrollbar.display = show_horizontal
self.vertical_scrollbar.display = show_vertical
@property @property
def scrollbars_enabled(self) -> tuple[bool, bool]: def scrollbars_enabled(self) -> tuple[bool, bool]:
@@ -298,17 +302,20 @@ class Widget(DOMNode):
y: float | None = None, y: float | None = None,
*, *,
animate: bool = True, animate: bool = True,
speed: float | None = None,
duration: float | None = None,
) -> bool: ) -> bool:
"""Scroll to a given (absolute) coordinate, optionally animating. """Scroll to a given (absolute) coordinate, optionally animating.
Args: Args:
scroll_x (int | None, optional): X coordinate (column) to scroll to, or ``None`` for no change. Defaults to None. x (int | None, optional): X coordinate (column) to scroll to, or ``None`` for no change. Defaults to None.
scroll_y (int | None, optional): Y coordinate (row) to scroll to, or ``None`` for no change. Defaults to None. y (int | None, optional): Y coordinate (row) to scroll to, or ``None`` for no change. Defaults to None.
animate (bool, optional): Animate to new scroll position. Defaults to False. animate (bool, optional): Animate to new scroll position. Defaults to False.
"""
scrolled_x = False Returns:
scrolled_y = False bool: True if the scroll position changed, otherwise False.
"""
scrolled_x = scrolled_y = False
if animate: if animate:
# TODO: configure animation speed # TODO: configure animation speed
@@ -316,67 +323,151 @@ class Widget(DOMNode):
self.scroll_target_x = x self.scroll_target_x = x
if x != self.scroll_x: if x != self.scroll_x:
self.animate( self.animate(
"scroll_x", self.scroll_target_x, speed=80, easing="out_cubic" "scroll_x",
self.scroll_target_x,
speed=speed,
duration=duration,
easing="out_cubic",
) )
scrolled_x = True scrolled_x = True
if y is not None: if y is not None:
self.scroll_target_y = y self.scroll_target_y = y
if y != self.scroll_y: if y != self.scroll_y:
self.animate( self.animate(
"scroll_y", self.scroll_target_y, speed=80, easing="out_cubic" "scroll_y",
self.scroll_target_y,
speed=speed,
duration=duration,
easing="out_cubic",
) )
scrolled_y = True scrolled_y = True
else: else:
if x is not None: if x is not None:
scroll_x = self.scroll_x
self.scroll_target_x = self.scroll_x = x self.scroll_target_x = self.scroll_x = x
if x != self.scroll_x: scrolled_x = scroll_x != self.scroll_x
scrolled_x = True
if y is not None: if y is not None:
scroll_y = self.scroll_y
self.scroll_target_y = self.scroll_y = y self.scroll_target_y = self.scroll_y = y
if y != self.scroll_y: scrolled_y = scroll_y != self.scroll_y
scrolled_y = True if scrolled_x or scrolled_y:
self.refresh(repaint=False, layout=True) self.refresh(repaint=False, layout=True)
return scrolled_x or scrolled_y return scrolled_x or scrolled_y
def scroll_home(self, animate: bool = True) -> bool: def scroll_relative(
self,
x: float | None = None,
y: float | None = None,
*,
animate: bool = True,
speed: float | None = None,
duration: float | None = None,
) -> bool:
"""Scroll relative to current position.
Args:
x (int | None, optional): X distance (columns) to scroll, or ``None`` for no change. Defaults to None.
y (int | None, optional): Y distance (rows) to scroll, or ``None`` for no change. Defaults to None.
animate (bool, optional): Animate to new scroll position. Defaults to False.
Returns:
bool: True if the scroll position changed, otherwise False.
"""
return self.scroll_to(
None if x is None else (self.scroll_x + x),
None if y is None else (self.scroll_y + y),
animate=animate,
speed=speed,
duration=duration,
)
def scroll_home(self, *, animate: bool = True) -> bool:
return self.scroll_to(0, 0, animate=animate) return self.scroll_to(0, 0, animate=animate)
def scroll_end(self, animate: bool = True) -> bool: def scroll_end(self, *, animate: bool = True) -> bool:
return self.scroll_to(0, self.max_scroll_y, animate=animate) return self.scroll_to(0, self.max_scroll_y, animate=animate)
def scroll_left(self, animate: bool = True) -> bool: def scroll_left(self, *, animate: bool = True) -> bool:
return self.scroll_to(x=self.scroll_target_x - 1, animate=animate) return self.scroll_to(x=self.scroll_target_x - 1, animate=animate)
def scroll_right(self, animate: bool = True) -> bool: def scroll_right(self, *, animate: bool = True) -> bool:
return self.scroll_to(x=self.scroll_target_x + 1, animate=animate) return self.scroll_to(x=self.scroll_target_x + 1, animate=animate)
def scroll_up(self, animate: bool = True) -> bool: def scroll_up(self, *, animate: bool = True) -> bool:
return self.scroll_to(y=self.scroll_target_y + 1, animate=animate) return self.scroll_to(y=self.scroll_target_y + 1, animate=animate)
def scroll_down(self, animate: bool = True) -> bool: def scroll_down(self, *, animate: bool = True) -> bool:
return self.scroll_to(y=self.scroll_target_y - 1, animate=animate) return self.scroll_to(y=self.scroll_target_y - 1, animate=animate)
def scroll_page_up(self, animate: bool = True) -> bool: def scroll_page_up(self, *, animate: bool = True) -> bool:
return self.scroll_to( return self.scroll_to(
y=self.scroll_target_y - self.container_size.height, animate=animate y=self.scroll_target_y - self.container_size.height, animate=animate
) )
def scroll_page_down(self, animate: bool = True) -> bool: def scroll_page_down(self, *, animate: bool = True) -> bool:
return self.scroll_to( return self.scroll_to(
y=self.scroll_target_y + self.container_size.height, animate=animate y=self.scroll_target_y + self.container_size.height, animate=animate
) )
def scroll_page_left(self, animate: bool = True) -> bool: def scroll_page_left(self, *, animate: bool = True) -> bool:
return self.scroll_to( return self.scroll_to(
x=self.scroll_target_x - self.container_size.width, animate=animate x=self.scroll_target_x - self.container_size.width, animate=animate
) )
def scroll_page_right(self, animate: bool = True) -> bool: def scroll_page_right(self, *, animate: bool = True) -> bool:
return self.scroll_to( return self.scroll_to(
x=self.scroll_target_x + self.container_size.width, animate=animate x=self.scroll_target_x + self.container_size.width, animate=animate
) )
def scroll_to_widget(self, widget: Widget, *, animate: bool = True) -> bool:
"""Scroll so that a child widget is in the visible area.
Args:
widget (Widget): A Widget in the children.
animate (bool, optional): True to animate, or False to jump. Defaults to True.
Returns:
bool: True if the scroll position changed, otherwise False.
"""
screen = self.screen
try:
widget_geometry = screen.find_widget(widget)
container_geometry = screen.find_widget(self)
except errors.NoWidget:
return False
widget_region = widget.content_region + widget_geometry.region.origin
container_region = self.content_region + container_geometry.region.origin
if widget_region in container_region:
# Widget is visible, nothing to do
return False
# We can either scroll so the widget is at the top of the container, or so that
# it is at the bottom. We want to pick which has the shortest distance
top_delta = widget_region.origin - container_region.origin
bottom_delta = widget_region.origin - (
container_region.origin
+ Offset(0, container_region.height - widget_region.height)
)
if widget_region.width > container_region.width:
delta_x = top_delta.x
else:
delta_x = min(top_delta.x, bottom_delta.x, key=abs)
if widget_region.height > container_region.height:
delta_y = top_delta.y
else:
delta_y = min(top_delta.y, bottom_delta.y, key=abs)
return self.scroll_relative(
delta_x or None, delta_y or None, animate=animate, duration=0.2
)
def __init_subclass__( def __init_subclass__(
cls, can_focus: bool = True, can_focus_children: bool = True cls, can_focus: bool = True, can_focus_children: bool = True
) -> None: ) -> None:
@@ -465,8 +556,7 @@ class Widget(DOMNode):
Returns: Returns:
RenderableType: A new renderable. RenderableType: A new renderable.
""" """
renderable = self.render(self.styles.rich_style)
renderable = self.render()
styles = self.styles styles = self.styles
parent_styles = self.parent.styles parent_styles = self.parent.styles
@@ -479,11 +569,12 @@ class Widget(DOMNode):
horizontal, vertical = content_align horizontal, vertical = content_align
renderable = Align(renderable, horizontal, vertical=vertical) renderable = Align(renderable, horizontal, vertical=vertical)
renderable = Padding(renderable, styles.padding)
renderable_text_style = parent_text_style + text_style renderable_text_style = parent_text_style + text_style
if renderable_text_style: if renderable_text_style:
renderable = Styled(renderable, renderable_text_style) style = Style.from_color(text_style.color, text_style.bgcolor)
renderable = Styled(renderable, style)
renderable = Padding(renderable, styles.padding, style=renderable_text_style)
if styles.border: if styles.border:
renderable = Border( renderable = Border(
@@ -517,14 +608,28 @@ class Widget(DOMNode):
def container_size(self) -> Size: def container_size(self) -> Size:
return self._container_size return self._container_size
@property
def content_region(self) -> Region:
"""A region relative to the Widget origin that contains the content."""
x, y = self.styles.content_gutter.top_left
width, height = self._container_size
return Region(x, y, width, height)
@property
def content_offset(self) -> Offset:
"""An offset from the Widget origin where the content begins."""
x, y = self.styles.content_gutter.top_left
return Offset(x, y)
@property @property
def virtual_size(self) -> Size: def virtual_size(self) -> Size:
return self._virtual_size return self._virtual_size
@property @property
def region(self) -> Region: def region(self) -> Region:
"""The region occupied by this widget, relative to the Screen."""
try: try:
return self.screen._compositor.get_widget_region(self) return self.screen.find_widget(self).region
except errors.NoWidget: except errors.NoWidget:
return Region() return Region()
@@ -632,8 +737,10 @@ class Widget(DOMNode):
if self._dirty_regions: if self._dirty_regions:
self._render_lines() self._render_lines()
if self.is_container: if self.is_container:
self.horizontal_scrollbar.refresh() if self.show_horizontal_scrollbar:
self.vertical_scrollbar.refresh() self.horizontal_scrollbar.refresh()
if self.show_vertical_scrollbar:
self.vertical_scrollbar.refresh()
lines = self._render_cache.lines[start:end] lines = self._render_cache.lines[start:end]
return lines return lines
@@ -669,21 +776,19 @@ class Widget(DOMNode):
self._layout_required = True self._layout_required = True
if repaint: if repaint:
self.set_dirty() self.set_dirty()
self._repaint_required = True
self.check_idle() self.check_idle()
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
"""Get renderable for widget. """Get renderable for widget.
Args:
style (Styles): The Styles object for this Widget.
Returns: Returns:
RenderableType: Any renderable RenderableType: Any renderable
""" """
return "" if self.is_container else self.css_identifier_styled
# Default displays a pretty repr in the center of the screen
if self.is_container:
return ""
return self.css_identifier_styled
async def action(self, action: str, *params) -> None: async def action(self, action: str, *params) -> None:
await self.app.action(action, self) await self.app.action(action, self)
@@ -705,8 +810,9 @@ class Widget(DOMNode):
if self.check_layout(): if self.check_layout():
self._reset_check_layout() self._reset_check_layout()
self.screen.post_message_no_wait(messages.Layout(self)) self.screen.post_message_no_wait(messages.Layout(self))
elif self._dirty_regions: elif self._repaint_required:
self.emit_no_wait(messages.Update(self, self)) self.emit_no_wait(messages.Update(self, self))
self._repaint_required = False
def focus(self) -> None: def focus(self) -> None:
"""Give input focus to this widget.""" """Give input focus to this widget."""
@@ -766,6 +872,8 @@ class Widget(DOMNode):
def on_descendant_focus(self, event: events.DescendantFocus) -> None: def on_descendant_focus(self, event: events.DescendantFocus) -> None:
self.descendant_has_focus = True self.descendant_has_focus = True
if self.is_container and isinstance(event.sender, Widget):
self.scroll_to_widget(event.sender, animate=True)
def on_descendant_blur(self, event: events.DescendantBlur) -> None: def on_descendant_blur(self, event: events.DescendantBlur) -> None:
self.descendant_has_focus = False self.descendant_has_focus = False

View File

@@ -3,7 +3,8 @@ from __future__ import annotations
from typing import cast from typing import cast
from rich.console import RenderableType from rich.console import RenderableType
from rich.text import Text from rich.style import Style
from rich.text import Text, TextType
from .. import events from .. import events
from ..message import Message from ..message import Message
@@ -24,8 +25,7 @@ class Button(Widget, can_focus=True):
color: $text-primary; color: $text-primary;
content-align: center middle; content-align: center middle;
border: tall $primary-lighten-3; border: tall $primary-lighten-3;
margin: 1 0;
margin: 1;
text-style: bold; text-style: bold;
} }
@@ -48,7 +48,7 @@ class Button(Widget, can_focus=True):
def __init__( def __init__(
self, self,
label: RenderableType | None = None, label: TextType | None = None,
disabled: bool = False, disabled: bool = False,
*, *,
name: str | None = None, name: str | None = None,
@@ -57,7 +57,11 @@ class Button(Widget, can_focus=True):
): ):
super().__init__(name=name, id=id, classes=classes) super().__init__(name=name, id=id, classes=classes)
self.label = self.css_identifier_styled if label is None else label if label is None:
label = self.css_identifier_styled
self.label: Text = label
self.disabled = disabled self.disabled = disabled
if disabled: if disabled:
self.add_class("-disabled") self.add_class("-disabled")
@@ -70,8 +74,10 @@ class Button(Widget, can_focus=True):
return Text.from_markup(label) return Text.from_markup(label)
return label return label
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
return self.label label = self.label.copy()
label.stylize(style)
return label
async def on_click(self, event: events.Click) -> None: async def on_click(self, event: events.Click) -> None:
event.stop() event.stop()

View File

@@ -59,7 +59,7 @@ class Footer(Widget):
text.append_text(key_text) text.append_text(key_text)
return text return text
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
if self._key_text is None: if self._key_text is None:
self._key_text = self.make_key_text() self._key_text = self.make_key_text()
return self._key_text return self._key_text

View File

@@ -6,7 +6,7 @@ from logging import getLogger
from rich.console import RenderableType from rich.console import RenderableType
from rich.panel import Panel from rich.panel import Panel
from rich.repr import Result from rich.repr import Result
from rich.style import StyleType from rich.style import StyleType, Style
from rich.table import Table from rich.table import Table
from .. import events from .. import events
@@ -49,7 +49,7 @@ class Header(Widget):
def get_clock(self) -> str: def get_clock(self) -> str:
return datetime.now().time().strftime("%X") return datetime.now().time().strftime("%X")
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
header_table = Table.grid(padding=(0, 1), expand=True) header_table = Table.grid(padding=(0, 1), expand=True)
header_table.style = self.style header_table.style = self.style
header_table.add_column(justify="left", ratio=0, width=8) header_table.add_column(justify="left", ratio=0, width=8)

View File

@@ -6,9 +6,8 @@ from rich.console import RenderableType
from rich.panel import Panel from rich.panel import Panel
from rich.pretty import Pretty from rich.pretty import Pretty
import rich.repr import rich.repr
from rich.style import Style
from .. import log
from .. import events from .. import events
from ..reactive import Reactive from ..reactive import Reactive
from ..widget import Widget from ..widget import Widget
@@ -19,22 +18,23 @@ class Placeholder(Widget, can_focus=True):
has_focus: Reactive[bool] = Reactive(False) has_focus: Reactive[bool] = Reactive(False)
mouse_over: Reactive[bool] = Reactive(False) mouse_over: Reactive[bool] = Reactive(False)
style: Reactive[str] = Reactive("")
def __rich_repr__(self) -> rich.repr.Result: def __rich_repr__(self) -> rich.repr.Result:
yield from super().__rich_repr__() yield from super().__rich_repr__()
yield "has_focus", self.has_focus, False yield "has_focus", self.has_focus, False
yield "mouse_over", self.mouse_over, False yield "mouse_over", self.mouse_over, False
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
# Apply colours only inside render_styled
# Pass the full RICH style object into `render` - not the `Styles`
return Panel( return Panel(
Align.center( Align.center(
Pretty(self, no_wrap=True, overflow="ellipsis"), vertical="middle" Pretty(self, no_wrap=True, overflow="ellipsis"),
vertical="middle",
), ),
title=self.__class__.__name__, title=self.__class__.__name__,
border_style="green" if self.mouse_over else "blue", border_style="green" if self.mouse_over else "blue",
box=box.HEAVY if self.has_focus else box.ROUNDED, box=box.HEAVY if self.has_focus else box.ROUNDED,
style=self.style,
) )
async def on_focus(self, event: events.Focus) -> None: async def on_focus(self, event: events.Focus) -> None:

View File

@@ -1,9 +1,8 @@
from __future__ import annotations from __future__ import annotations
from rich.console import RenderableType from rich.console import RenderableType
from rich.padding import Padding, PaddingDimensions from rich.style import Style
from rich.style import StyleType
from rich.styled import Styled
from ..widget import Widget from ..widget import Widget
@@ -15,20 +14,13 @@ class Static(Widget):
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
style: StyleType = "",
padding: PaddingDimensions = 0,
) -> None: ) -> None:
super().__init__(name=name, id=id, classes=classes) super().__init__(name=name, id=id, classes=classes)
self.renderable = renderable self.renderable = renderable
self.style = style
self.padding = padding
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
renderable = self.renderable return self.renderable
if self.padding:
renderable = Padding(renderable, self.padding)
return Styled(renderable, self.style)
async def update(self, renderable: RenderableType) -> None: def update(self, renderable: RenderableType) -> None:
self.renderable = renderable self.renderable = renderable
self.refresh() self.refresh(layout=True)

View File

@@ -5,12 +5,11 @@ from typing import Generic, Iterator, NewType, TypeVar
import rich.repr import rich.repr
from rich.console import RenderableType from rich.console import RenderableType
from rich.style import Style
from rich.text import Text, TextType from rich.text import Text, TextType
from rich.tree import Tree from rich.tree import Tree
from rich.padding import PaddingDimensions from rich.padding import PaddingDimensions
from .. import log
from .. import events
from ..reactive import Reactive from ..reactive import Reactive
from .._types import MessageTarget from .._types import MessageTarget
from ..widget import Widget from ..widget import Widget
@@ -249,7 +248,7 @@ class TreeControl(Generic[NodeDataType], Widget):
push(iter(node.children)) push(iter(node.children))
return None return None
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
return self._tree return self._tree
def render_node(self, node: TreeNode[NodeDataType]) -> RenderableType: def render_node(self, node: TreeNode[NodeDataType]) -> RenderableType:

View File

@@ -330,7 +330,7 @@ class Tabs(Widget):
""" """
return next((i for i, tab in enumerate(self.tabs) if tab.name == tab_name), 0) return next((i for i, tab in enumerate(self.tabs) if tab.name == tab_name), 0)
def render(self) -> RenderableType: def render(self, style: Style) -> RenderableType:
return TabsRenderable( return TabsRenderable(
self.tabs, self.tabs,
tab_padding=self.tab_padding, tab_padding=self.tab_padding,

View File

@@ -90,14 +90,48 @@ def test_did_you_mean_for_css_property_names(
_, help_text = err.value.errors.rules[0].errors[0] # type: Any, HelpText _, help_text = err.value.errors.rules[0].errors[0] # type: Any, HelpText
displayed_css_property_name = css_property_name.replace("_", "-") displayed_css_property_name = css_property_name.replace("_", "-")
assert ( expected_summary = f"Invalid CSS property [i]{displayed_css_property_name}[/]"
help_text.summary == f"Invalid CSS property [i]{displayed_css_property_name}[/]" if expected_property_name_suggestion:
expected_summary += f'. Did you mean "{expected_property_name_suggestion}"?'
assert help_text.summary == expected_summary
@pytest.mark.parametrize(
"css_property_name,css_property_value,expected_color_suggestion",
[
["color", "blu", "blue"],
["background", "chartruse", "chartreuse"],
["tint", "ansi_whi", "ansi_white"],
["scrollbar-color", "transprnt", "transparent"],
["color", "xkcd", None],
],
)
def test_did_you_mean_for_color_names(
css_property_name: str, css_property_value: str, expected_color_suggestion
):
stylesheet = Stylesheet()
css = """
* {
border: blue;
${PROPERTY}: ${VALUE};
}
""".replace(
"${PROPERTY}", css_property_name
).replace(
"${VALUE}", css_property_value
) )
expected_bullets_length = 1 if expected_property_name_suggestion else 0 stylesheet.add_source(css)
assert len(help_text.bullets) == expected_bullets_length with pytest.raises(StylesheetParseError) as err:
if expected_property_name_suggestion is not None: stylesheet.parse()
expected_suggestion_message = (
f'Did you mean "{expected_property_name_suggestion}"?' _, help_text = err.value.errors.rules[0].errors[0] # type: Any, HelpText
) displayed_css_property_name = css_property_name.replace("_", "-")
assert help_text.bullets[0].markup == expected_suggestion_message expected_error_summary = (
f"Invalid value for the [i]{displayed_css_property_name}[/] property"
)
if expected_color_suggestion is not None:
expected_error_summary += f'. Did you mean "{expected_color_suggestion}"?'
assert help_text.summary == expected_error_summary

View File

@@ -1,44 +1,54 @@
from textual._region_group import regions_to_ranges, InlineRange from textual._compositor import Compositor
from textual.geometry import Region from textual.geometry import Region
def test_regions_to_ranges_no_regions(): def test_regions_to_ranges_no_regions():
assert list(regions_to_ranges([])) == [] assert list(Compositor._regions_to_spans([])) == []
def test_regions_to_ranges_single_region(): def test_regions_to_ranges_single_region():
regions = [Region(0, 0, 3, 2)] regions = [Region(0, 0, 3, 2)]
assert list(regions_to_ranges(regions)) == [InlineRange(0, 0, 2), InlineRange(1, 0, 2)] assert list(Compositor._regions_to_spans(regions)) == [
(0, 0, 3),
(1, 0, 3),
]
def test_regions_to_ranges_partially_overlapping_regions(): def test_regions_to_ranges_partially_overlapping_regions():
regions = [Region(0, 0, 2, 2), Region(1, 1, 2, 2)] regions = [Region(0, 0, 2, 2), Region(1, 1, 2, 2)]
assert list(regions_to_ranges(regions)) == [ assert list(Compositor._regions_to_spans(regions)) == [
InlineRange(0, 0, 1), InlineRange(1, 0, 2), InlineRange(2, 1, 2), (0, 0, 2),
(1, 0, 3),
(2, 1, 3),
] ]
def test_regions_to_ranges_fully_overlapping_regions(): def test_regions_to_ranges_fully_overlapping_regions():
regions = [Region(1, 1, 3, 3), Region(2, 2, 1, 1), Region(0, 2, 3, 1)] regions = [Region(1, 1, 3, 3), Region(2, 2, 1, 1), Region(0, 2, 3, 1)]
assert list(regions_to_ranges(regions)) == [ assert list(Compositor._regions_to_spans(regions)) == [
InlineRange(1, 1, 3), InlineRange(2, 0, 3), InlineRange(3, 1, 3) (1, 1, 4),
(2, 0, 4),
(3, 1, 4),
] ]
def test_regions_to_ranges_disjoint_regions_different_lines(): def test_regions_to_ranges_disjoint_regions_different_lines():
regions = [Region(0, 0, 2, 1), Region(2, 2, 2, 1)] regions = [Region(0, 0, 2, 1), Region(2, 2, 2, 1)]
assert list(regions_to_ranges(regions)) == [InlineRange(0, 0, 1), InlineRange(2, 2, 3)] assert list(Compositor._regions_to_spans(regions)) == [(0, 0, 2), (2, 2, 4)]
def test_regions_to_ranges_disjoint_regions_same_line(): def test_regions_to_ranges_disjoint_regions_same_line():
regions = [Region(0, 0, 1, 2), Region(2, 0, 1, 1)] regions = [Region(0, 0, 1, 2), Region(2, 0, 1, 1)]
assert list(regions_to_ranges(regions)) == [ assert list(Compositor._regions_to_spans(regions)) == [
InlineRange(0, 0, 0), InlineRange(0, 2, 2), InlineRange(1, 0, 0) (0, 0, 1),
(0, 2, 3),
(1, 0, 1),
] ]
def test_regions_to_ranges_directly_adjacent_ranges_merged(): def test_regions_to_ranges_directly_adjacent_ranges_merged():
regions = [Region(0, 0, 1, 2), Region(1, 0, 1, 2)] regions = [Region(0, 0, 1, 2), Region(1, 0, 1, 2)]
assert list(regions_to_ranges(regions)) == [ assert list(Compositor._regions_to_spans(regions)) == [
InlineRange(0, 0, 1), InlineRange(1, 0, 1) (0, 0, 2),
(1, 0, 2),
] ]

View File

@@ -114,6 +114,17 @@ def test_region_null():
assert not Region() assert not Region()
def test_region_from_union():
with pytest.raises(ValueError):
Region.from_union([])
regions = [
Region(10, 20, 30, 40),
Region(15, 25, 5, 5),
Region(30, 25, 20, 10),
]
assert Region.from_union(regions) == Region(10, 20, 40, 40)
def test_region_from_origin(): def test_region_from_origin():
assert Region.from_origin(Offset(3, 4), (5, 6)) == Region(3, 4, 5, 6) assert Region.from_origin(Offset(3, 4), (5, 6)) == Region(3, 4, 5, 6)
@@ -132,6 +143,18 @@ def test_region_origin():
assert Region(1, 2, 3, 4).origin == Offset(1, 2) assert Region(1, 2, 3, 4).origin == Offset(1, 2)
def test_region_bottom_left():
assert Region(1, 2, 3, 4).bottom_left == Offset(1, 6)
def test_region_top_right():
assert Region(1, 2, 3, 4).top_right == Offset(4, 2)
def test_region_bottom_right():
assert Region(1, 2, 3, 4).bottom_right == Offset(4, 6)
def test_region_add(): def test_region_add():
assert Region(1, 2, 3, 4) + (10, 20) == Region(11, 22, 3, 4) assert Region(1, 2, 3, 4) + (10, 20) == Region(11, 22, 3, 4)
with pytest.raises(TypeError): with pytest.raises(TypeError):

View File

@@ -150,12 +150,10 @@ async def test_composition_of_vertical_container_with_children(
expected_screen_size = Size(*screen_size) expected_screen_size = Size(*screen_size)
async with app.in_running_state(): async with app.in_running_state():
app.log_tree()
# root widget checks: # root widget checks:
root_widget = cast(Widget, app.get_child("root")) root_widget = cast(Widget, app.get_child("root"))
assert root_widget.size == expected_screen_size assert root_widget.size == expected_screen_size
root_widget_region = app.screen.get_widget_region(root_widget) root_widget_region = app.screen.find_widget(root_widget).region
assert root_widget_region == ( assert root_widget_region == (
0, 0,
0, 0,

View File

@@ -0,0 +1,117 @@
from __future__ import annotations
import sys
from typing import Sequence, cast
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal # pragma: no cover
import pytest
from sandbox.vertical_container import VerticalContainer
from tests.utilities.test_app import AppTest
from textual.app import ComposeResult
from textual.geometry import Size
from textual.widget import Widget
from textual.widgets import Placeholder
SCREEN_SIZE = Size(100, 30)
@pytest.mark.skip("flaky test")
@pytest.mark.asyncio
@pytest.mark.integration_test # this is a slow test, we may want to skip them in some contexts
@pytest.mark.parametrize(
(
"screen_size",
"placeholders_count",
"scroll_to_placeholder_id",
"scroll_to_animate",
"waiting_duration",
"last_screen_expected_placeholder_ids",
"last_screen_expected_out_of_viewport_placeholder_ids",
),
(
[SCREEN_SIZE, 10, None, None, 0.01, (0, 1, 2, 3, 4), "others"],
[SCREEN_SIZE, 10, "placeholder_3", False, 0.01, (0, 1, 2, 3, 4), "others"],
[SCREEN_SIZE, 10, "placeholder_5", False, 0.01, (1, 2, 3, 4, 5), "others"],
[SCREEN_SIZE, 10, "placeholder_7", False, 0.01, (3, 4, 5, 6, 7), "others"],
[SCREEN_SIZE, 10, "placeholder_9", False, 0.01, (5, 6, 7, 8, 9), "others"],
# N.B. Scroll duration is hard-coded to 0.2 in the `scroll_to_widget` method atm
# Waiting for this duration should allow us to see the scroll finished:
[SCREEN_SIZE, 10, "placeholder_9", True, 0.21, (5, 6, 7, 8, 9), "others"],
# After having waited for approximately half of the scrolling duration, we should
# see the middle Placeholders as we're scrolling towards the last of them.
# The state of the screen at this "halfway there" timing looks to not be deterministic though,
# depending on the environment - so let's only assert stuff for the middle placeholders
# and the first and last ones, but without being too specific about the others:
[SCREEN_SIZE, 10, "placeholder_9", True, 0.1, (5, 6, 7), (1, 2, 9)],
),
)
async def test_scroll_to_widget(
screen_size: Size,
placeholders_count: int,
scroll_to_animate: bool | None,
scroll_to_placeholder_id: str | None,
waiting_duration: float | None,
last_screen_expected_placeholder_ids: Sequence[int],
last_screen_expected_out_of_viewport_placeholder_ids: Sequence[int]
| Literal["others"],
):
class MyTestApp(AppTest):
CSS = """
Placeholder {
height: 5; /* minimal height to see the name of a Placeholder */
}
"""
def compose(self) -> ComposeResult:
placeholders = [
Placeholder(id=f"placeholder_{i}", name=f"Placeholder #{i}")
for i in range(placeholders_count)
]
yield VerticalContainer(*placeholders, id="root")
app = MyTestApp(size=screen_size, test_name="scroll_to_widget")
async with app.in_running_state(waiting_duration_post_yield=waiting_duration or 0):
if scroll_to_placeholder_id:
target_widget_container = cast(Widget, app.query("#root").first())
target_widget = cast(
Widget, app.query(f"#{scroll_to_placeholder_id}").first()
)
target_widget_container.scroll_to_widget(
target_widget, animate=scroll_to_animate
)
last_display_capture = app.last_display_capture
placeholders_visibility_by_id = {
id_: f"placeholder_{id_}" in last_display_capture
for id_ in range(placeholders_count)
}
# Let's start by checking placeholders that should be visible:
for placeholder_id in last_screen_expected_placeholder_ids:
assert (
placeholders_visibility_by_id[placeholder_id] is True
), f"Placeholder '{placeholder_id}' should be visible but isn't"
# Ok, now for placeholders that should *not* be visible:
if last_screen_expected_out_of_viewport_placeholder_ids == "others":
# We're simply going to check that all the placeholders that are not in
# `last_screen_expected_placeholder_ids` are not on the screen:
last_screen_expected_out_of_viewport_placeholder_ids = sorted(
tuple(
set(range(placeholders_count))
- set(last_screen_expected_placeholder_ids)
)
)
for placeholder_id in last_screen_expected_out_of_viewport_placeholder_ids:
assert (
placeholders_visibility_by_id[placeholder_id] is False
), f"Placeholder '{placeholder_id}' should not be visible but is"

View File

@@ -1,7 +1,5 @@
from contextlib import nullcontext as does_not_raise
from decimal import Decimal
import pytest import pytest
from rich.style import Style
from textual.app import App from textual.app import App
from textual.css.errors import StyleValueError from textual.css.errors import StyleValueError
@@ -41,7 +39,7 @@ def test_widget_content_width():
self.text = text self.text = text
super().__init__(id=id) super().__init__(id=id)
def render(self) -> str: def render(self, style: Style) -> str:
return self.text return self.text
widget1 = TextWidget("foo", id="widget1") widget1 = TextWidget("foo", id="widget1")

View File

@@ -4,9 +4,10 @@ import asyncio
import contextlib import contextlib
import io import io
from pathlib import Path from pathlib import Path
from typing import AsyncContextManager from typing import AsyncContextManager, cast
from rich.console import Console
from rich.console import Console, Capture
from textual import events from textual import events
from textual.app import App, ReturnType, ComposeResult from textual.app import App, ReturnType, ComposeResult
from textual.driver import Driver from textual.driver import Driver
@@ -16,6 +17,9 @@ from textual.geometry import Size
# N.B. These classes would better be named TestApp/TestConsole/TestDriver/etc, # N.B. These classes would better be named TestApp/TestConsole/TestDriver/etc,
# but it makes pytest emit warning as it will try to collect them as classes containing test cases :-/ # but it makes pytest emit warning as it will try to collect them as classes containing test cases :-/
# This value is also hard-coded in Textual's `App` class:
CLEAR_SCREEN_SEQUENCE = "\x1bP=1s\x1b\\"
class AppTest(App): class AppTest(App):
def __init__( def __init__(
@@ -25,7 +29,7 @@ class AppTest(App):
size: Size, size: Size,
log_verbosity: int = 2, log_verbosity: int = 2,
): ):
# will log in "/tests/test.[test name].log": # Tests will log in "/tests/test.[test name].log":
log_path = Path(__file__).parent.parent / f"test.{test_name}.log" log_path = Path(__file__).parent.parent / f"test.{test_name}.log"
super().__init__( super().__init__(
driver_class=DriverTest, driver_class=DriverTest,
@@ -33,6 +37,11 @@ class AppTest(App):
log_verbosity=log_verbosity, log_verbosity=log_verbosity,
log_color_system="256", log_color_system="256",
) )
# We need this so the `CLEAR_SCREEN_SEQUENCE` is always sent for a screen refresh,
# whatever the environment:
self._sync_available = True
self._size = size self._size = size
self._console = ConsoleTest(width=size.width, height=size.height) self._console = ConsoleTest(width=size.width, height=size.height)
self._error_console = ConsoleTest(width=size.width, height=size.height) self._error_console = ConsoleTest(width=size.width, height=size.height)
@@ -49,16 +58,18 @@ class AppTest(App):
def in_running_state( def in_running_state(
self, self,
*, *,
initialisation_timeout: float = 0.1, waiting_duration_after_initialisation: float = 0.1,
) -> AsyncContextManager[Capture]: waiting_duration_post_yield: float = 0,
) -> AsyncContextManager:
async def run_app() -> None: async def run_app() -> None:
await self.process_messages() await self.process_messages()
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def get_running_state_context_manager(): async def get_running_state_context_manager():
self._set_active()
run_task = asyncio.create_task(run_app()) run_task = asyncio.create_task(run_app())
timeout_before_yielding_task = asyncio.create_task( timeout_before_yielding_task = asyncio.create_task(
asyncio.sleep(initialisation_timeout) asyncio.sleep(waiting_duration_after_initialisation)
) )
done, pending = await asyncio.wait( done, pending = await asyncio.wait(
( (
@@ -69,10 +80,11 @@ class AppTest(App):
) )
if run_task in done or run_task not in pending: if run_task in done or run_task not in pending:
raise RuntimeError( raise RuntimeError(
"TestApp is no longer return after its initialization period" "TestApp is no longer running after its initialization period"
) )
with self.console.capture() as capture: yield
yield capture if waiting_duration_post_yield:
await asyncio.sleep(waiting_duration_post_yield)
assert not run_task.done() assert not run_task.done()
await self.shutdown() await self.shutdown()
@@ -83,6 +95,18 @@ class AppTest(App):
"Use `async with my_test_app.in_running_state()` rather than `my_test_app.run()`" "Use `async with my_test_app.in_running_state()` rather than `my_test_app.run()`"
) )
@property
def total_capture(self) -> str | None:
return self.console.file.getvalue()
@property
def last_display_capture(self) -> str | None:
total_capture = self.total_capture
if not total_capture:
return None
last_display_start_index = total_capture.rindex(CLEAR_SCREEN_SEQUENCE)
return total_capture[last_display_start_index:]
@property @property
def console(self) -> ConsoleTest: def console(self) -> ConsoleTest:
return self._console return self._console
@@ -110,10 +134,18 @@ class ConsoleTest(Console):
file=file, file=file,
width=width, width=width,
height=height, height=height,
force_terminal=True, force_terminal=False,
legacy_windows=False, legacy_windows=False,
) )
@property
def file(self) -> io.StringIO:
return cast(io.StringIO, self._file)
@property
def is_dumb_terminal(self) -> bool:
return False
class DriverTest(Driver): class DriverTest(Driver):
def start_application_mode(self) -> None: def start_application_mode(self) -> None: