render cache

This commit is contained in:
Will McGugan
2022-06-18 18:38:50 +01:00
parent f4c12704fe
commit bb1449315e
10 changed files with 191 additions and 88 deletions

View File

@@ -105,11 +105,10 @@ Tweet {
.code {
height: auto;
}
}
TweetHeader {
height:1;

View File

@@ -1,13 +1,23 @@
"""
LRU Cache operation borrowed from Rich.
This may become more sophisticated in Textual, but hopefully remain simple in Rich.
"""
import sys
from collections import deque
from functools import wraps
from threading import Lock
from typing import Dict, Generic, List, Optional, TypeVar, Union, overload
from typing import (
Callable,
Deque,
Dict,
Generic,
List,
Optional,
TypeVar,
Union,
overload,
)
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
CacheKey = TypeVar("CacheKey")
CacheValue = TypeVar("CacheValue")
@@ -40,6 +50,12 @@ class LRUCache(Generic[CacheKey, CacheValue]):
def __len__(self) -> int:
return len(self.cache)
def clear(self) -> None:
"""Clear the cache."""
with self._lock:
self.cache.clear()
self.root = []
def set(self, key: CacheKey, value: CacheValue) -> None:
"""Set a value.
@@ -122,3 +138,48 @@ class LRUCache(Generic[CacheKey, CacheValue]):
def __contains__(self, key: CacheKey) -> bool:
return key in self.cache
P = ParamSpec("P")
T = TypeVar("T")
def fifo_cache(maxsize: int) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""A First In First Out cache.
Args:
maxsize (int): Maximum size of the cache
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
queue: Deque[object] = deque()
cache: Dict[object, T] = {}
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
return cache[args]
except KeyError:
assert not kwargs, "Will not work with keyword arguments!"
cache[args] = result = func(*args)
queue.append(args)
if len(queue) > maxsize:
del cache[queue.popleft()]
return result
return wrapper
return decorator
@fifo_cache(10)
def double(n: int) -> int:
return n * n
print(double(1))
print(double(2))
print(double(2))
print(double(3))
print(double(4))

View File

@@ -445,7 +445,6 @@ class Compositor:
x -= region.x
y -= region.y
# lines = widget.render_lines((y, y + 1), (0, region.width))
lines = widget.render_lines(Region(0, y, region.width, 1))
if not lines:
@@ -575,6 +574,7 @@ class Compositor:
]
return segment_lines
@timer("render")
def render(self, full: bool = False) -> RenderableType | None:
"""Render a layout.

View File

@@ -194,7 +194,7 @@ class App(Generic[ReturnType], DOMNode):
self.design = DEFAULT_COLORS
self.stylesheet = Stylesheet(variables=self.get_css_variables())
self._require_styles_update = False
self._require_stylesheet_update = False
self.css_path = css_path or self.CSS_PATH
self.registry: set[MessagePump] = set()
@@ -584,7 +584,7 @@ class App(Generic[ReturnType], DOMNode):
Should be called whenever CSS classes / pseudo classes change.
"""
self._require_styles_update = True
self._require_stylesheet_update = True
self.check_idle()
def mount(self, *anon_widgets: Widget, **widgets: Widget) -> None:
@@ -817,9 +817,9 @@ class App(Generic[ReturnType], DOMNode):
async def on_idle(self) -> None:
"""Perform actions when there are no messages in the queue."""
if self._require_styles_update:
await self.post_message(messages.StylesUpdated(self))
self._require_styles_update = False
if self._require_stylesheet_update:
self._require_stylesheet_update = False
self.stylesheet.update(self, animate=True)
def _register_child(self, parent: DOMNode, child: DOMNode) -> bool:
if child not in self.registry:
@@ -1135,8 +1135,8 @@ class App(Generic[ReturnType], DOMNode):
async def action_toggle_class(self, selector: str, class_name: str) -> None:
self.screen.query(selector).toggle_class(class_name)
async def handle_styles_updated(self, message: messages.StylesUpdated) -> None:
self.stylesheet.update(self, animate=True)
# async def handle_styles_updated(self, message: messages.StylesUpdated) -> None:
# self.stylesheet.update(self, animate=True)
def handle_terminal_supports_synchronized_output(
self, message: messages.TerminalSupportsSynchronizedOutput

View File

@@ -25,6 +25,7 @@ from .tokenize import tokenize_values, Token
from .tokenizer import TokenizeError
from .types import Specificity3, Specificity4
from ..dom import DOMNode
from .. import messages
class StylesheetParseError(StylesheetError):
@@ -375,6 +376,8 @@ class Stylesheet:
for key in modified_rule_keys:
setattr(base_styles, key, get_rule(key))
node.post_message_no_wait(messages.StylesUpdated(sender=node))
def update(self, root: DOMNode, animate: bool = False) -> None:
"""Update a node and its children."""
apply = self.apply

View File

@@ -428,9 +428,6 @@ class DOMNode(MessagePump):
node.set_dirty()
node._layout_required = True
def on_style_change(self) -> None:
pass
def add_child(self, node: DOMNode) -> None:
"""Add a new child node.

View File

@@ -7,7 +7,6 @@ Functions and classes to manage terminal geometry (anything involving coordinate
from __future__ import annotations
from functools import lru_cache
from typing import Any, cast, Collection, NamedTuple, Tuple, TypeAlias, Union, TypeVar
SpacingDimensions: TypeAlias = Union[

View File

@@ -57,7 +57,7 @@ class Prompt(Message, system=True):
"""Used to 'wake up' an event loop."""
def can_replace(self, message: Message) -> bool:
return isinstance(message, StylesUpdated)
return isinstance(message, Prompt)
class TerminalSupportsSynchronizedOutput(Message):

View File

@@ -854,10 +854,6 @@ class Widget(DOMNode):
"""Update from CSS if has focus state changes."""
self.app.update_styles()
def on_style_change(self) -> None:
self.set_dirty()
self.check_idle()
def size_updated(
self, size: Size, virtual_size: Size, container_size: Size
) -> None:

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import ClassVar, Generic, TypeVar, cast
from typing import Callable, ClassVar, Generic, TypeVar, cast
from rich.console import RenderableType
from rich.padding import Padding
from rich.protocol import is_renderable
from rich.segment import Segment
from rich.style import Style
from rich.text import Text, TextType
@@ -17,9 +18,20 @@ from ..reactive import Reactive
from .._profile import timer
from ..scroll_view import ScrollView
from ..widget import Widget
from .. import messages
CellType = TypeVar("CellType")
CellFormatter = Callable[[object], RenderableType | None]
def default_cell_formatter(obj: object) -> RenderableType | None:
if isinstance(obj, str):
return Text.from_markup(obj)
if not is_renderable(obj):
raise TypeError("Table cell contains {obj!r} which is not renderable")
return cast(RenderableType, obj)
@dataclass
class Column:
@@ -103,14 +115,36 @@ class DataTable(ScrollView, Generic[CellType]):
self.line_contents: list[str] = []
self._cells: dict[int, list[Cell]] = {}
self._cell_render_cache: dict[tuple[int, int], Lines] = LRUCache(10000)
self._row_render_cache: LRUCache[int, tuple[Lines, Lines]] = LRUCache(1000)
self._cell_render_cache: LRUCache[tuple[int, int, Style], Lines] = LRUCache(
10000
)
self._line_cache: LRUCache[tuple[int, int, int, int], list[Segment]] = LRUCache(
1000
)
show_header = Reactive(True)
fixed_rows = Reactive(1)
fixed_columns = Reactive(1)
zebra_stripes = Reactive(False)
def _clear_caches(self) -> None:
self._row_render_cache.clear()
self._cell_render_cache.clear()
self._line_cache.clear()
async def handle_styles_updated(self, message: messages.StylesUpdated) -> None:
self._clear_caches()
def watch_show_header(self, show_header: bool) -> None:
self._clear_caches()
def watch_fixed_rows(self, fixed_rows: int) -> None:
self._clear_caches()
def watch_zebra_stripes(self, zebra_stripes: int) -> None:
self._clear_caches()
def _update_dimensions(self) -> None:
max_width = sum(column.width for column in self.columns)
self.virtual_size = Size(max_width, len(self.data) + self.show_header)
@@ -124,93 +158,103 @@ class DataTable(ScrollView, Generic[CellType]):
def add_row(self, *cells: CellType, height: int = 1) -> None:
row_index = self.row_count
self.data[row_index] = list(cells)
self.rows[row_index] = Row(
row_index,
height=height,
cell_renderables=[
Text.from_markup(cell) if isinstance(cell, str) else cell
for cell in cells
],
)
self.rows[row_index] = Row(row_index, height=height)
self.row_count += 1
self._update_dimensions()
self.refresh()
def get_row(self, y: int) -> list[CellType | Text]:
def get_row(self, row_index: int) -> list[RenderableType]:
if y == 0 and self.show_header:
if row_index == 0 and self.show_header:
row = [column.label for column in self.columns]
return row
data_offset = y - 1 if self.show_header else 0
data_offset = row_index - 1 if self.show_header else 0
data = self.data.get(data_offset)
empty = Text()
if data is None:
return [Text() for column in self.columns]
return [empty for column in self.columns]
else:
return self.rows[data_offset].cell_renderables
return [default_cell_formatter(datum) or empty for datum in data]
def _render_cell(self, y: int, column: Column) -> Lines:
def _render_cell(self, row_index: int, column: Column, style: Style) -> Lines:
style = Style.from_meta({"y": y, "column": column.index})
cell_key = (y, column.index)
cell_key = (row_index, column.index, style)
if cell_key not in self._cell_render_cache:
cell = self.get_row(y)[column.index]
style += Style.from_meta({"row": row_index, "column": column.index})
cell = self.get_row(row_index)[column.index]
lines = self.app.console.render_lines(
Padding(cell, (0, 1)),
self.app.console.options.update_dimensions(column.width, 1),
style=style,
)
self._cell_render_cache[cell_key] = lines
return self._cell_render_cache[cell_key]
def _render_line(self, y: int, x1: int, x2: int) -> list[Segment]:
def _render_row(self, row_index: int, base_style: Style) -> tuple[Lines, Lines]:
if row_index in self._row_render_cache:
return self._row_render_cache[row_index]
width = self.content_region.width
if self.fixed_columns:
fixed_style = self.component_styles["datatable--fixed"].node.rich_style
fixed_style += Style.from_meta({"fixed": True})
cell_segments: list[list[Segment]] = []
rendered_width = 0
for column in self.columns:
lines = self._render_cell(y, column)
rendered_width += column.width
cell_segments.append(lines[0])
fixed_row = [
self._render_cell(row_index, column, fixed_style)[0]
for column in self.columns[: self.fixed_columns]
]
else:
fixed_row = []
base_style = self.rich_style
fixed_style = self.component_styles[
"datatable--fixed"
].node.rich_style + Style.from_meta({"fixed": True})
header_style = self.component_styles[
"datatable--header"
].node.rich_style + Style.from_meta({"header": True})
fixed: list[Segment] = sum(cell_segments[: self.fixed_columns], start=[])
fixed_width = sum(column.width for column in self.columns[: self.fixed_columns])
fixed = list(Segment.apply_style(fixed, fixed_style))
line: list[Segment] = sum(cell_segments, start=[])
row_style = base_style
if y == 0:
segments = fixed + line_crop(line, x1 + fixed_width, x2, width)
line = Segment.adjust_line_length(segments, width)
if row_index == 0 and self.show_header:
row_style = self.component_styles["datatable--header"].node.rich_style
else:
if self.zebra_stripes:
component_row_style = (
"datatable--odd-row" if y % 2 else "datatable--even-row"
"datatable--odd-row" if row_index % 2 else "datatable--even-row"
)
row_style = self.component_styles[component_row_style].node.rich_style
else:
row_style = base_style
line = list(Segment.apply_style(line, row_style))
segments = fixed + line_crop(line, x1 + fixed_width, x2, width)
line = Segment.adjust_line_length(segments, width, style=base_style)
scrollable_row = [
self._render_cell(row_index, column, row_style)[0]
for column in self.columns
]
if y == 0 and self.show_header:
line = list(Segment.apply_style(line, header_style))
row_pair = (fixed_row, scrollable_row)
self._row_render_cache[row_index] = row_pair
return row_pair
return line
def _render_line(
self, y: int, x1: int, x2: int, base_style: Style
) -> list[Segment]:
width = self.content_region.width
cache_key = (y, x1, x2, width)
if cache_key in self._line_cache:
return self._line_cache[cache_key]
row_index = y
fixed, scrollable = self._render_row(row_index, base_style)
fixed_width = sum(column.width for column in self.columns[: self.fixed_columns])
fixed_line: list[Segment] = sum(fixed, start=[])
scrollable_line: list[Segment] = sum(scrollable, start=[])
segments = fixed_line + line_crop(scrollable_line, x1 + fixed_width, x2, width)
# line = Segment.adjust_line_length(segments, width, style=base_style)
remaining_width = width - (fixed_width + min(width, (x2 - x1 + fixed_width)))
if remaining_width > 0:
segments.append(Segment(" " * remaining_width, base_style))
elif remaining_width < 0:
segments = Segment.adjust_line_length(segments, width, style=base_style)
self._line_cache[cache_key] = segments
return segments
@timer("render_lines")
def render_lines(self, crop: Region) -> Lines:
@@ -218,8 +262,12 @@ class DataTable(ScrollView, Generic[CellType]):
scroll_x, scroll_y = self.scroll_offset
x1, y1, x2, y2 = crop.translate(scroll_x, scroll_y).corners
fixed_lines = [self._render_line(y, x1, x2) for y in range(0, self.fixed_rows)]
lines = [self._render_line(y, x1, x2) for y in range(y1, y2)]
base_style = self.rich_style
fixed_lines = [
self._render_line(y, x1, x2, base_style) for y in range(0, self.fixed_rows)
]
lines = [self._render_line(y, x1, x2, base_style) for y in range(y1, y2)]
for fixed_line, y in zip(fixed_lines, range(y1, y2)):
if y - scroll_y == 0: