diff --git a/sandbox/basic.css b/sandbox/basic.css index 240b4ca28..54983b0c1 100644 --- a/sandbox/basic.css +++ b/sandbox/basic.css @@ -105,11 +105,10 @@ Tweet { .code { height: auto; - - + } -} + TweetHeader { height:1; diff --git a/src/textual/_cache.py b/src/textual/_cache.py index ec9f67889..b06040919 100644 --- a/src/textual/_cache.py +++ b/src/textual/_cache.py @@ -1,13 +1,23 @@ -""" - -LRU Cache operation borrowed from Rich. - -This may become more sophisticated in Textual, but hopefully remain simple in Rich. - -""" - +import sys +from collections import deque +from functools import wraps from threading import Lock -from typing import Dict, Generic, List, Optional, TypeVar, Union, overload +from typing import ( + Callable, + Deque, + Dict, + Generic, + List, + Optional, + TypeVar, + Union, + overload, +) + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec CacheKey = TypeVar("CacheKey") CacheValue = TypeVar("CacheValue") @@ -40,6 +50,12 @@ class LRUCache(Generic[CacheKey, CacheValue]): def __len__(self) -> int: return len(self.cache) + def clear(self) -> None: + """Clear the cache.""" + with self._lock: + self.cache.clear() + self.root = [] + def set(self, key: CacheKey, value: CacheValue) -> None: """Set a value. @@ -122,3 +138,48 @@ class LRUCache(Generic[CacheKey, CacheValue]): def __contains__(self, key: CacheKey) -> bool: return key in self.cache + + +P = ParamSpec("P") +T = TypeVar("T") + + +def fifo_cache(maxsize: int) -> Callable[[Callable[P, T]], Callable[P, T]]: + """A First In First Out cache. + + Args: + maxsize (int): Maximum size of the cache + + """ + + def decorator(func: Callable[P, T]) -> Callable[P, T]: + queue: Deque[object] = deque() + cache: Dict[object, T] = {} + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + try: + return cache[args] + except KeyError: + assert not kwargs, "Will not work with keyword arguments!" + cache[args] = result = func(*args) + queue.append(args) + if len(queue) > maxsize: + del cache[queue.popleft()] + return result + + return wrapper + + return decorator + + +@fifo_cache(10) +def double(n: int) -> int: + return n * n + + +print(double(1)) +print(double(2)) +print(double(2)) +print(double(3)) +print(double(4)) diff --git a/src/textual/_compositor.py b/src/textual/_compositor.py index b56b5f2cb..779102884 100644 --- a/src/textual/_compositor.py +++ b/src/textual/_compositor.py @@ -445,7 +445,6 @@ class Compositor: x -= region.x y -= region.y - # lines = widget.render_lines((y, y + 1), (0, region.width)) lines = widget.render_lines(Region(0, y, region.width, 1)) if not lines: @@ -575,6 +574,7 @@ class Compositor: ] return segment_lines + @timer("render") def render(self, full: bool = False) -> RenderableType | None: """Render a layout. diff --git a/src/textual/app.py b/src/textual/app.py index a747eeac6..d6e8b3d7c 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -194,7 +194,7 @@ class App(Generic[ReturnType], DOMNode): self.design = DEFAULT_COLORS self.stylesheet = Stylesheet(variables=self.get_css_variables()) - self._require_styles_update = False + self._require_stylesheet_update = False self.css_path = css_path or self.CSS_PATH self.registry: set[MessagePump] = set() @@ -584,7 +584,7 @@ class App(Generic[ReturnType], DOMNode): Should be called whenever CSS classes / pseudo classes change. """ - self._require_styles_update = True + self._require_stylesheet_update = True self.check_idle() def mount(self, *anon_widgets: Widget, **widgets: Widget) -> None: @@ -817,9 +817,9 @@ class App(Generic[ReturnType], DOMNode): async def on_idle(self) -> None: """Perform actions when there are no messages in the queue.""" - if self._require_styles_update: - await self.post_message(messages.StylesUpdated(self)) - self._require_styles_update = False + if self._require_stylesheet_update: + self._require_stylesheet_update = False + self.stylesheet.update(self, animate=True) def _register_child(self, parent: DOMNode, child: DOMNode) -> bool: if child not in self.registry: @@ -1135,8 +1135,8 @@ class App(Generic[ReturnType], DOMNode): async def action_toggle_class(self, selector: str, class_name: str) -> None: self.screen.query(selector).toggle_class(class_name) - async def handle_styles_updated(self, message: messages.StylesUpdated) -> None: - self.stylesheet.update(self, animate=True) + # async def handle_styles_updated(self, message: messages.StylesUpdated) -> None: + # self.stylesheet.update(self, animate=True) def handle_terminal_supports_synchronized_output( self, message: messages.TerminalSupportsSynchronizedOutput diff --git a/src/textual/css/stylesheet.py b/src/textual/css/stylesheet.py index 052b5711c..b62da8cd2 100644 --- a/src/textual/css/stylesheet.py +++ b/src/textual/css/stylesheet.py @@ -25,6 +25,7 @@ from .tokenize import tokenize_values, Token from .tokenizer import TokenizeError from .types import Specificity3, Specificity4 from ..dom import DOMNode +from .. import messages class StylesheetParseError(StylesheetError): @@ -375,6 +376,8 @@ class Stylesheet: for key in modified_rule_keys: setattr(base_styles, key, get_rule(key)) + node.post_message_no_wait(messages.StylesUpdated(sender=node)) + def update(self, root: DOMNode, animate: bool = False) -> None: """Update a node and its children.""" apply = self.apply diff --git a/src/textual/dom.py b/src/textual/dom.py index ceb13ac3b..1dfb0a395 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -428,9 +428,6 @@ class DOMNode(MessagePump): node.set_dirty() node._layout_required = True - def on_style_change(self) -> None: - pass - def add_child(self, node: DOMNode) -> None: """Add a new child node. diff --git a/src/textual/geometry.py b/src/textual/geometry.py index f9d5da54a..472d3bbc2 100644 --- a/src/textual/geometry.py +++ b/src/textual/geometry.py @@ -7,7 +7,6 @@ Functions and classes to manage terminal geometry (anything involving coordinate from __future__ import annotations from functools import lru_cache - from typing import Any, cast, Collection, NamedTuple, Tuple, TypeAlias, Union, TypeVar SpacingDimensions: TypeAlias = Union[ diff --git a/src/textual/messages.py b/src/textual/messages.py index 8bdbf17d4..8af6752a6 100644 --- a/src/textual/messages.py +++ b/src/textual/messages.py @@ -57,7 +57,7 @@ class Prompt(Message, system=True): """Used to 'wake up' an event loop.""" def can_replace(self, message: Message) -> bool: - return isinstance(message, StylesUpdated) + return isinstance(message, Prompt) class TerminalSupportsSynchronizedOutput(Message): diff --git a/src/textual/widget.py b/src/textual/widget.py index e8d888bdb..d43f8fa5c 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -854,10 +854,6 @@ class Widget(DOMNode): """Update from CSS if has focus state changes.""" self.app.update_styles() - def on_style_change(self) -> None: - self.set_dirty() - self.check_idle() - def size_updated( self, size: Size, virtual_size: Size, container_size: Size ) -> None: diff --git a/src/textual/widgets/_datatable.py b/src/textual/widgets/_datatable.py index 5f1458040..b378980d5 100644 --- a/src/textual/widgets/_datatable.py +++ b/src/textual/widgets/_datatable.py @@ -1,10 +1,11 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import ClassVar, Generic, TypeVar, cast +from typing import Callable, ClassVar, Generic, TypeVar, cast from rich.console import RenderableType from rich.padding import Padding +from rich.protocol import is_renderable from rich.segment import Segment from rich.style import Style from rich.text import Text, TextType @@ -17,9 +18,20 @@ from ..reactive import Reactive from .._profile import timer from ..scroll_view import ScrollView from ..widget import Widget +from .. import messages CellType = TypeVar("CellType") +CellFormatter = Callable[[object], RenderableType | None] + + +def default_cell_formatter(obj: object) -> RenderableType | None: + if isinstance(obj, str): + return Text.from_markup(obj) + if not is_renderable(obj): + raise TypeError("Table cell contains {obj!r} which is not renderable") + return cast(RenderableType, obj) + @dataclass class Column: @@ -103,14 +115,36 @@ class DataTable(ScrollView, Generic[CellType]): self.line_contents: list[str] = [] - self._cells: dict[int, list[Cell]] = {} - self._cell_render_cache: dict[tuple[int, int], Lines] = LRUCache(10000) + self._row_render_cache: LRUCache[int, tuple[Lines, Lines]] = LRUCache(1000) + self._cell_render_cache: LRUCache[tuple[int, int, Style], Lines] = LRUCache( + 10000 + ) + self._line_cache: LRUCache[tuple[int, int, int, int], list[Segment]] = LRUCache( + 1000 + ) show_header = Reactive(True) fixed_rows = Reactive(1) fixed_columns = Reactive(1) zebra_stripes = Reactive(False) + def _clear_caches(self) -> None: + self._row_render_cache.clear() + self._cell_render_cache.clear() + self._line_cache.clear() + + async def handle_styles_updated(self, message: messages.StylesUpdated) -> None: + self._clear_caches() + + def watch_show_header(self, show_header: bool) -> None: + self._clear_caches() + + def watch_fixed_rows(self, fixed_rows: int) -> None: + self._clear_caches() + + def watch_zebra_stripes(self, zebra_stripes: int) -> None: + self._clear_caches() + def _update_dimensions(self) -> None: max_width = sum(column.width for column in self.columns) self.virtual_size = Size(max_width, len(self.data) + self.show_header) @@ -124,93 +158,103 @@ class DataTable(ScrollView, Generic[CellType]): def add_row(self, *cells: CellType, height: int = 1) -> None: row_index = self.row_count self.data[row_index] = list(cells) - self.rows[row_index] = Row( - row_index, - height=height, - cell_renderables=[ - Text.from_markup(cell) if isinstance(cell, str) else cell - for cell in cells - ], - ) + self.rows[row_index] = Row(row_index, height=height) self.row_count += 1 self._update_dimensions() self.refresh() - def get_row(self, y: int) -> list[CellType | Text]: + def get_row(self, row_index: int) -> list[RenderableType]: - if y == 0 and self.show_header: + if row_index == 0 and self.show_header: row = [column.label for column in self.columns] return row - data_offset = y - 1 if self.show_header else 0 + data_offset = row_index - 1 if self.show_header else 0 data = self.data.get(data_offset) + empty = Text() if data is None: - return [Text() for column in self.columns] + return [empty for column in self.columns] else: - return self.rows[data_offset].cell_renderables + return [default_cell_formatter(datum) or empty for datum in data] - def _render_cell(self, y: int, column: Column) -> Lines: + def _render_cell(self, row_index: int, column: Column, style: Style) -> Lines: - style = Style.from_meta({"y": y, "column": column.index}) - - cell_key = (y, column.index) + cell_key = (row_index, column.index, style) if cell_key not in self._cell_render_cache: - cell = self.get_row(y)[column.index] + style += Style.from_meta({"row": row_index, "column": column.index}) + cell = self.get_row(row_index)[column.index] lines = self.app.console.render_lines( Padding(cell, (0, 1)), self.app.console.options.update_dimensions(column.width, 1), style=style, ) self._cell_render_cache[cell_key] = lines - return self._cell_render_cache[cell_key] - def _render_line(self, y: int, x1: int, x2: int) -> list[Segment]: + def _render_row(self, row_index: int, base_style: Style) -> tuple[Lines, Lines]: + if row_index in self._row_render_cache: + return self._row_render_cache[row_index] - width = self.content_region.width + if self.fixed_columns: + fixed_style = self.component_styles["datatable--fixed"].node.rich_style + fixed_style += Style.from_meta({"fixed": True}) - cell_segments: list[list[Segment]] = [] - rendered_width = 0 - for column in self.columns: - lines = self._render_cell(y, column) - rendered_width += column.width - cell_segments.append(lines[0]) + fixed_row = [ + self._render_cell(row_index, column, fixed_style)[0] + for column in self.columns[: self.fixed_columns] + ] + else: + fixed_row = [] - base_style = self.rich_style - - fixed_style = self.component_styles[ - "datatable--fixed" - ].node.rich_style + Style.from_meta({"fixed": True}) - header_style = self.component_styles[ - "datatable--header" - ].node.rich_style + Style.from_meta({"header": True}) - - fixed: list[Segment] = sum(cell_segments[: self.fixed_columns], start=[]) - fixed_width = sum(column.width for column in self.columns[: self.fixed_columns]) - - fixed = list(Segment.apply_style(fixed, fixed_style)) - - line: list[Segment] = sum(cell_segments, start=[]) - - row_style = base_style - if y == 0: - segments = fixed + line_crop(line, x1 + fixed_width, x2, width) - line = Segment.adjust_line_length(segments, width) + if row_index == 0 and self.show_header: + row_style = self.component_styles["datatable--header"].node.rich_style else: if self.zebra_stripes: component_row_style = ( - "datatable--odd-row" if y % 2 else "datatable--even-row" + "datatable--odd-row" if row_index % 2 else "datatable--even-row" ) row_style = self.component_styles[component_row_style].node.rich_style + else: + row_style = base_style - line = list(Segment.apply_style(line, row_style)) - segments = fixed + line_crop(line, x1 + fixed_width, x2, width) - line = Segment.adjust_line_length(segments, width, style=base_style) + scrollable_row = [ + self._render_cell(row_index, column, row_style)[0] + for column in self.columns + ] - if y == 0 and self.show_header: - line = list(Segment.apply_style(line, header_style)) + row_pair = (fixed_row, scrollable_row) + self._row_render_cache[row_index] = row_pair + return row_pair - return line + def _render_line( + self, y: int, x1: int, x2: int, base_style: Style + ) -> list[Segment]: + + width = self.content_region.width + + cache_key = (y, x1, x2, width) + if cache_key in self._line_cache: + return self._line_cache[cache_key] + + row_index = y + + fixed, scrollable = self._render_row(row_index, base_style) + fixed_width = sum(column.width for column in self.columns[: self.fixed_columns]) + + fixed_line: list[Segment] = sum(fixed, start=[]) + scrollable_line: list[Segment] = sum(scrollable, start=[]) + + segments = fixed_line + line_crop(scrollable_line, x1 + fixed_width, x2, width) + + # line = Segment.adjust_line_length(segments, width, style=base_style) + remaining_width = width - (fixed_width + min(width, (x2 - x1 + fixed_width))) + if remaining_width > 0: + segments.append(Segment(" " * remaining_width, base_style)) + elif remaining_width < 0: + segments = Segment.adjust_line_length(segments, width, style=base_style) + + self._line_cache[cache_key] = segments + return segments @timer("render_lines") def render_lines(self, crop: Region) -> Lines: @@ -218,8 +262,12 @@ class DataTable(ScrollView, Generic[CellType]): scroll_x, scroll_y = self.scroll_offset x1, y1, x2, y2 = crop.translate(scroll_x, scroll_y).corners - fixed_lines = [self._render_line(y, x1, x2) for y in range(0, self.fixed_rows)] - lines = [self._render_line(y, x1, x2) for y in range(y1, y2)] + base_style = self.rich_style + + fixed_lines = [ + self._render_line(y, x1, x2, base_style) for y in range(0, self.fixed_rows) + ] + lines = [self._render_line(y, x1, x2, base_style) for y in range(y1, y2)] for fixed_line, y in zip(fixed_lines, range(y1, y2)): if y - scroll_y == 0: