render cache

This commit is contained in:
Will McGugan
2022-06-18 18:38:50 +01:00
parent f4c12704fe
commit bb1449315e
10 changed files with 191 additions and 88 deletions

View File

@@ -105,11 +105,10 @@ Tweet {
.code { .code {
height: auto; height: auto;
} }
}
TweetHeader { TweetHeader {
height:1; height:1;

View File

@@ -1,13 +1,23 @@
""" import sys
from collections import deque
LRU Cache operation borrowed from Rich. from functools import wraps
This may become more sophisticated in Textual, but hopefully remain simple in Rich.
"""
from threading import Lock from threading import Lock
from typing import Dict, Generic, List, Optional, TypeVar, Union, overload from typing import (
Callable,
Deque,
Dict,
Generic,
List,
Optional,
TypeVar,
Union,
overload,
)
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
CacheKey = TypeVar("CacheKey") CacheKey = TypeVar("CacheKey")
CacheValue = TypeVar("CacheValue") CacheValue = TypeVar("CacheValue")
@@ -40,6 +50,12 @@ class LRUCache(Generic[CacheKey, CacheValue]):
def __len__(self) -> int: def __len__(self) -> int:
return len(self.cache) return len(self.cache)
def clear(self) -> None:
"""Clear the cache."""
with self._lock:
self.cache.clear()
self.root = []
def set(self, key: CacheKey, value: CacheValue) -> None: def set(self, key: CacheKey, value: CacheValue) -> None:
"""Set a value. """Set a value.
@@ -122,3 +138,48 @@ class LRUCache(Generic[CacheKey, CacheValue]):
def __contains__(self, key: CacheKey) -> bool: def __contains__(self, key: CacheKey) -> bool:
return key in self.cache return key in self.cache
P = ParamSpec("P")
T = TypeVar("T")
def fifo_cache(maxsize: int) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""A First In First Out cache.
Args:
maxsize (int): Maximum size of the cache
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
queue: Deque[object] = deque()
cache: Dict[object, T] = {}
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
return cache[args]
except KeyError:
assert not kwargs, "Will not work with keyword arguments!"
cache[args] = result = func(*args)
queue.append(args)
if len(queue) > maxsize:
del cache[queue.popleft()]
return result
return wrapper
return decorator
@fifo_cache(10)
def double(n: int) -> int:
return n * n
print(double(1))
print(double(2))
print(double(2))
print(double(3))
print(double(4))

View File

@@ -445,7 +445,6 @@ class Compositor:
x -= region.x x -= region.x
y -= region.y y -= region.y
# lines = widget.render_lines((y, y + 1), (0, region.width))
lines = widget.render_lines(Region(0, y, region.width, 1)) lines = widget.render_lines(Region(0, y, region.width, 1))
if not lines: if not lines:
@@ -575,6 +574,7 @@ class Compositor:
] ]
return segment_lines return segment_lines
@timer("render")
def render(self, full: bool = False) -> RenderableType | None: def render(self, full: bool = False) -> RenderableType | None:
"""Render a layout. """Render a layout.

View File

@@ -194,7 +194,7 @@ class App(Generic[ReturnType], DOMNode):
self.design = DEFAULT_COLORS self.design = DEFAULT_COLORS
self.stylesheet = Stylesheet(variables=self.get_css_variables()) self.stylesheet = Stylesheet(variables=self.get_css_variables())
self._require_styles_update = False self._require_stylesheet_update = False
self.css_path = css_path or self.CSS_PATH self.css_path = css_path or self.CSS_PATH
self.registry: set[MessagePump] = set() self.registry: set[MessagePump] = set()
@@ -584,7 +584,7 @@ class App(Generic[ReturnType], DOMNode):
Should be called whenever CSS classes / pseudo classes change. Should be called whenever CSS classes / pseudo classes change.
""" """
self._require_styles_update = True self._require_stylesheet_update = True
self.check_idle() self.check_idle()
def mount(self, *anon_widgets: Widget, **widgets: Widget) -> None: def mount(self, *anon_widgets: Widget, **widgets: Widget) -> None:
@@ -817,9 +817,9 @@ class App(Generic[ReturnType], DOMNode):
async def on_idle(self) -> None: async def on_idle(self) -> None:
"""Perform actions when there are no messages in the queue.""" """Perform actions when there are no messages in the queue."""
if self._require_styles_update: if self._require_stylesheet_update:
await self.post_message(messages.StylesUpdated(self)) self._require_stylesheet_update = False
self._require_styles_update = False self.stylesheet.update(self, animate=True)
def _register_child(self, parent: DOMNode, child: DOMNode) -> bool: def _register_child(self, parent: DOMNode, child: DOMNode) -> bool:
if child not in self.registry: if child not in self.registry:
@@ -1135,8 +1135,8 @@ class App(Generic[ReturnType], DOMNode):
async def action_toggle_class(self, selector: str, class_name: str) -> None: async def action_toggle_class(self, selector: str, class_name: str) -> None:
self.screen.query(selector).toggle_class(class_name) self.screen.query(selector).toggle_class(class_name)
async def handle_styles_updated(self, message: messages.StylesUpdated) -> None: # async def handle_styles_updated(self, message: messages.StylesUpdated) -> None:
self.stylesheet.update(self, animate=True) # self.stylesheet.update(self, animate=True)
def handle_terminal_supports_synchronized_output( def handle_terminal_supports_synchronized_output(
self, message: messages.TerminalSupportsSynchronizedOutput self, message: messages.TerminalSupportsSynchronizedOutput

View File

@@ -25,6 +25,7 @@ from .tokenize import tokenize_values, Token
from .tokenizer import TokenizeError from .tokenizer import TokenizeError
from .types import Specificity3, Specificity4 from .types import Specificity3, Specificity4
from ..dom import DOMNode from ..dom import DOMNode
from .. import messages
class StylesheetParseError(StylesheetError): class StylesheetParseError(StylesheetError):
@@ -375,6 +376,8 @@ class Stylesheet:
for key in modified_rule_keys: for key in modified_rule_keys:
setattr(base_styles, key, get_rule(key)) setattr(base_styles, key, get_rule(key))
node.post_message_no_wait(messages.StylesUpdated(sender=node))
def update(self, root: DOMNode, animate: bool = False) -> None: def update(self, root: DOMNode, animate: bool = False) -> None:
"""Update a node and its children.""" """Update a node and its children."""
apply = self.apply apply = self.apply

View File

@@ -428,9 +428,6 @@ class DOMNode(MessagePump):
node.set_dirty() node.set_dirty()
node._layout_required = True node._layout_required = True
def on_style_change(self) -> None:
pass
def add_child(self, node: DOMNode) -> None: def add_child(self, node: DOMNode) -> None:
"""Add a new child node. """Add a new child node.

View File

@@ -7,7 +7,6 @@ Functions and classes to manage terminal geometry (anything involving coordinate
from __future__ import annotations from __future__ import annotations
from functools import lru_cache from functools import lru_cache
from typing import Any, cast, Collection, NamedTuple, Tuple, TypeAlias, Union, TypeVar from typing import Any, cast, Collection, NamedTuple, Tuple, TypeAlias, Union, TypeVar
SpacingDimensions: TypeAlias = Union[ SpacingDimensions: TypeAlias = Union[

View File

@@ -57,7 +57,7 @@ class Prompt(Message, system=True):
"""Used to 'wake up' an event loop.""" """Used to 'wake up' an event loop."""
def can_replace(self, message: Message) -> bool: def can_replace(self, message: Message) -> bool:
return isinstance(message, StylesUpdated) return isinstance(message, Prompt)
class TerminalSupportsSynchronizedOutput(Message): class TerminalSupportsSynchronizedOutput(Message):

View File

@@ -854,10 +854,6 @@ class Widget(DOMNode):
"""Update from CSS if has focus state changes.""" """Update from CSS if has focus state changes."""
self.app.update_styles() self.app.update_styles()
def on_style_change(self) -> None:
self.set_dirty()
self.check_idle()
def size_updated( def size_updated(
self, size: Size, virtual_size: Size, container_size: Size self, size: Size, virtual_size: Size, container_size: Size
) -> None: ) -> None:

View File

@@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import ClassVar, Generic, TypeVar, cast from typing import Callable, ClassVar, Generic, TypeVar, cast
from rich.console import RenderableType from rich.console import RenderableType
from rich.padding import Padding from rich.padding import Padding
from rich.protocol import is_renderable
from rich.segment import Segment from rich.segment import Segment
from rich.style import Style from rich.style import Style
from rich.text import Text, TextType from rich.text import Text, TextType
@@ -17,9 +18,20 @@ from ..reactive import Reactive
from .._profile import timer from .._profile import timer
from ..scroll_view import ScrollView from ..scroll_view import ScrollView
from ..widget import Widget from ..widget import Widget
from .. import messages
CellType = TypeVar("CellType") CellType = TypeVar("CellType")
CellFormatter = Callable[[object], RenderableType | None]
def default_cell_formatter(obj: object) -> RenderableType | None:
if isinstance(obj, str):
return Text.from_markup(obj)
if not is_renderable(obj):
raise TypeError("Table cell contains {obj!r} which is not renderable")
return cast(RenderableType, obj)
@dataclass @dataclass
class Column: class Column:
@@ -103,14 +115,36 @@ class DataTable(ScrollView, Generic[CellType]):
self.line_contents: list[str] = [] self.line_contents: list[str] = []
self._cells: dict[int, list[Cell]] = {} self._row_render_cache: LRUCache[int, tuple[Lines, Lines]] = LRUCache(1000)
self._cell_render_cache: dict[tuple[int, int], Lines] = LRUCache(10000) self._cell_render_cache: LRUCache[tuple[int, int, Style], Lines] = LRUCache(
10000
)
self._line_cache: LRUCache[tuple[int, int, int, int], list[Segment]] = LRUCache(
1000
)
show_header = Reactive(True) show_header = Reactive(True)
fixed_rows = Reactive(1) fixed_rows = Reactive(1)
fixed_columns = Reactive(1) fixed_columns = Reactive(1)
zebra_stripes = Reactive(False) zebra_stripes = Reactive(False)
def _clear_caches(self) -> None:
self._row_render_cache.clear()
self._cell_render_cache.clear()
self._line_cache.clear()
async def handle_styles_updated(self, message: messages.StylesUpdated) -> None:
self._clear_caches()
def watch_show_header(self, show_header: bool) -> None:
self._clear_caches()
def watch_fixed_rows(self, fixed_rows: int) -> None:
self._clear_caches()
def watch_zebra_stripes(self, zebra_stripes: int) -> None:
self._clear_caches()
def _update_dimensions(self) -> None: def _update_dimensions(self) -> None:
max_width = sum(column.width for column in self.columns) max_width = sum(column.width for column in self.columns)
self.virtual_size = Size(max_width, len(self.data) + self.show_header) self.virtual_size = Size(max_width, len(self.data) + self.show_header)
@@ -124,93 +158,103 @@ class DataTable(ScrollView, Generic[CellType]):
def add_row(self, *cells: CellType, height: int = 1) -> None: def add_row(self, *cells: CellType, height: int = 1) -> None:
row_index = self.row_count row_index = self.row_count
self.data[row_index] = list(cells) self.data[row_index] = list(cells)
self.rows[row_index] = Row( self.rows[row_index] = Row(row_index, height=height)
row_index,
height=height,
cell_renderables=[
Text.from_markup(cell) if isinstance(cell, str) else cell
for cell in cells
],
)
self.row_count += 1 self.row_count += 1
self._update_dimensions() self._update_dimensions()
self.refresh() self.refresh()
def get_row(self, y: int) -> list[CellType | Text]: def get_row(self, row_index: int) -> list[RenderableType]:
if y == 0 and self.show_header: if row_index == 0 and self.show_header:
row = [column.label for column in self.columns] row = [column.label for column in self.columns]
return row return row
data_offset = y - 1 if self.show_header else 0 data_offset = row_index - 1 if self.show_header else 0
data = self.data.get(data_offset) data = self.data.get(data_offset)
empty = Text()
if data is None: if data is None:
return [Text() for column in self.columns] return [empty for column in self.columns]
else: else:
return self.rows[data_offset].cell_renderables return [default_cell_formatter(datum) or empty for datum in data]
def _render_cell(self, y: int, column: Column) -> Lines: def _render_cell(self, row_index: int, column: Column, style: Style) -> Lines:
style = Style.from_meta({"y": y, "column": column.index}) cell_key = (row_index, column.index, style)
cell_key = (y, column.index)
if cell_key not in self._cell_render_cache: if cell_key not in self._cell_render_cache:
cell = self.get_row(y)[column.index] style += Style.from_meta({"row": row_index, "column": column.index})
cell = self.get_row(row_index)[column.index]
lines = self.app.console.render_lines( lines = self.app.console.render_lines(
Padding(cell, (0, 1)), Padding(cell, (0, 1)),
self.app.console.options.update_dimensions(column.width, 1), self.app.console.options.update_dimensions(column.width, 1),
style=style, style=style,
) )
self._cell_render_cache[cell_key] = lines self._cell_render_cache[cell_key] = lines
return self._cell_render_cache[cell_key] return self._cell_render_cache[cell_key]
def _render_line(self, y: int, x1: int, x2: int) -> list[Segment]: def _render_row(self, row_index: int, base_style: Style) -> tuple[Lines, Lines]:
if row_index in self._row_render_cache:
return self._row_render_cache[row_index]
width = self.content_region.width if self.fixed_columns:
fixed_style = self.component_styles["datatable--fixed"].node.rich_style
fixed_style += Style.from_meta({"fixed": True})
cell_segments: list[list[Segment]] = [] fixed_row = [
rendered_width = 0 self._render_cell(row_index, column, fixed_style)[0]
for column in self.columns: for column in self.columns[: self.fixed_columns]
lines = self._render_cell(y, column) ]
rendered_width += column.width else:
cell_segments.append(lines[0]) fixed_row = []
base_style = self.rich_style if row_index == 0 and self.show_header:
row_style = self.component_styles["datatable--header"].node.rich_style
fixed_style = self.component_styles[
"datatable--fixed"
].node.rich_style + Style.from_meta({"fixed": True})
header_style = self.component_styles[
"datatable--header"
].node.rich_style + Style.from_meta({"header": True})
fixed: list[Segment] = sum(cell_segments[: self.fixed_columns], start=[])
fixed_width = sum(column.width for column in self.columns[: self.fixed_columns])
fixed = list(Segment.apply_style(fixed, fixed_style))
line: list[Segment] = sum(cell_segments, start=[])
row_style = base_style
if y == 0:
segments = fixed + line_crop(line, x1 + fixed_width, x2, width)
line = Segment.adjust_line_length(segments, width)
else: else:
if self.zebra_stripes: if self.zebra_stripes:
component_row_style = ( component_row_style = (
"datatable--odd-row" if y % 2 else "datatable--even-row" "datatable--odd-row" if row_index % 2 else "datatable--even-row"
) )
row_style = self.component_styles[component_row_style].node.rich_style row_style = self.component_styles[component_row_style].node.rich_style
else:
row_style = base_style
line = list(Segment.apply_style(line, row_style)) scrollable_row = [
segments = fixed + line_crop(line, x1 + fixed_width, x2, width) self._render_cell(row_index, column, row_style)[0]
line = Segment.adjust_line_length(segments, width, style=base_style) for column in self.columns
]
if y == 0 and self.show_header: row_pair = (fixed_row, scrollable_row)
line = list(Segment.apply_style(line, header_style)) self._row_render_cache[row_index] = row_pair
return row_pair
return line def _render_line(
self, y: int, x1: int, x2: int, base_style: Style
) -> list[Segment]:
width = self.content_region.width
cache_key = (y, x1, x2, width)
if cache_key in self._line_cache:
return self._line_cache[cache_key]
row_index = y
fixed, scrollable = self._render_row(row_index, base_style)
fixed_width = sum(column.width for column in self.columns[: self.fixed_columns])
fixed_line: list[Segment] = sum(fixed, start=[])
scrollable_line: list[Segment] = sum(scrollable, start=[])
segments = fixed_line + line_crop(scrollable_line, x1 + fixed_width, x2, width)
# line = Segment.adjust_line_length(segments, width, style=base_style)
remaining_width = width - (fixed_width + min(width, (x2 - x1 + fixed_width)))
if remaining_width > 0:
segments.append(Segment(" " * remaining_width, base_style))
elif remaining_width < 0:
segments = Segment.adjust_line_length(segments, width, style=base_style)
self._line_cache[cache_key] = segments
return segments
@timer("render_lines") @timer("render_lines")
def render_lines(self, crop: Region) -> Lines: def render_lines(self, crop: Region) -> Lines:
@@ -218,8 +262,12 @@ class DataTable(ScrollView, Generic[CellType]):
scroll_x, scroll_y = self.scroll_offset scroll_x, scroll_y = self.scroll_offset
x1, y1, x2, y2 = crop.translate(scroll_x, scroll_y).corners x1, y1, x2, y2 = crop.translate(scroll_x, scroll_y).corners
fixed_lines = [self._render_line(y, x1, x2) for y in range(0, self.fixed_rows)] base_style = self.rich_style
lines = [self._render_line(y, x1, x2) for y in range(y1, y2)]
fixed_lines = [
self._render_line(y, x1, x2, base_style) for y in range(0, self.fixed_rows)
]
lines = [self._render_line(y, x1, x2, base_style) for y in range(y1, y2)]
for fixed_line, y in zip(fixed_lines, range(y1, y2)): for fixed_line, y in zip(fixed_lines, range(y1, y2)):
if y - scroll_y == 0: if y - scroll_y == 0: