variable sized rows

This commit is contained in:
Will McGugan
2022-06-19 13:23:35 +01:00
parent bb1449315e
commit 5b56bc116a
4 changed files with 78 additions and 40 deletions

View File

@@ -54,6 +54,7 @@ class LRUCache(Generic[CacheKey, CacheValue]):
"""Clear the cache."""
with self._lock:
self.cache.clear()
self.full = False
self.root = []
def set(self, key: CacheKey, value: CacheValue) -> None:

View File

@@ -13,6 +13,7 @@ without having to render the entire screen.
from __future__ import annotations
from itertools import chain
from operator import attrgetter, itemgetter
import sys
from typing import Callable, cast, Iterator, Iterable, NamedTuple, TYPE_CHECKING
@@ -566,9 +567,10 @@ class Compositor:
) -> list[list[Segment]]:
"""Combine chops in to lines."""
segment_lines: list[list[Segment]] = [
sum(
[line for line in bucket.values() if line is not None],
[],
list(
chain.from_iterable(
line for line in bucket.values() if line is not None
)
)
for bucket in chops
]

View File

@@ -522,10 +522,10 @@ class Widget(DOMNode):
)
def scroll_home(self, *, animate: bool = True) -> bool:
return self.scroll_to(0, 0, animate=animate)
return self.scroll_to(0, 0, animate=animate, duration=1)
def scroll_end(self, *, animate: bool = True) -> bool:
return self.scroll_to(0, self.max_scroll_y, animate=animate)
return self.scroll_to(0, self.max_scroll_y, animate=animate, duration=1)
def scroll_left(self, *, animate: bool = True) -> bool:
return self.scroll_to(x=self.scroll_target_x - 1, animate=animate)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from itertools import chain
from typing import Callable, ClassVar, Generic, TypeVar, cast
from rich.console import RenderableType
@@ -44,7 +45,7 @@ class Column:
@dataclass
class Row:
index: int
height: int = 1
height: int
cell_renderables: list[RenderableType] = field(default_factory=list)
@@ -113,26 +114,33 @@ class DataTable(ScrollView, Generic[CellType]):
self.data: dict[int, list[CellType]] = {}
self.row_count = 0
self.line_contents: list[str] = []
self._y_offsets: list[tuple[int, int]] = []
self._row_render_cache: LRUCache[int, tuple[Lines, Lines]] = LRUCache(1000)
self._cell_render_cache: LRUCache[tuple[int, int, Style], Lines] = LRUCache(
10000
)
self._line_cache: LRUCache[tuple[int, int, int, int], list[Segment]] = LRUCache(
1000
)
self._row_render_cache: LRUCache[tuple[int, int, Style], tuple[Lines, Lines]]
self._row_render_cache = LRUCache(1000)
self._cell_render_cache: LRUCache[tuple[int, int, Style], Lines]
self._cell_render_cache = LRUCache(10000)
self._line_cache: LRUCache[tuple[int, int, int, int], list[Segment]]
self._line_cache = LRUCache(1000)
show_header = Reactive(True)
fixed_rows = Reactive(1)
fixed_rows = Reactive(0)
fixed_columns = Reactive(1)
zebra_stripes = Reactive(False)
header_height = Reactive(1)
def _clear_caches(self) -> None:
self._row_render_cache.clear()
self._cell_render_cache.clear()
self._line_cache.clear()
def get_row_height(self, row_index: int) -> int:
if row_index == -1:
return self.header_height
return self.rows[row_index].height
async def handle_styles_updated(self, message: messages.StylesUpdated) -> None:
self._clear_caches()
@@ -147,7 +155,10 @@ class DataTable(ScrollView, Generic[CellType]):
def _update_dimensions(self) -> None:
max_width = sum(column.width for column in self.columns)
self.virtual_size = Size(max_width, len(self.data) + self.show_header)
self.virtual_size = Size(
max_width,
len(self._y_offsets) + (self.header_height if self.show_header else 0),
)
def add_column(self, label: TextType, *, width: int = 10) -> None:
text_label = Text.from_markup(label) if isinstance(label, str) else label
@@ -155,25 +166,28 @@ class DataTable(ScrollView, Generic[CellType]):
self._update_dimensions()
self.refresh()
def add_row(self, *cells: CellType, height: int = 1) -> None:
def add_row(self, *cells: CellType, height: int = 3) -> None:
row_index = self.row_count
self.data[row_index] = list(cells)
self.rows[row_index] = Row(row_index, height=height)
for line_no in range(height):
self._y_offsets.append((row_index, line_no))
self.row_count += 1
self._update_dimensions()
self.refresh()
def get_row(self, row_index: int) -> list[RenderableType]:
def get_row_renderables(self, row_index: int) -> list[RenderableType]:
if row_index == 0 and self.show_header:
if row_index == -1:
row = [column.label for column in self.columns]
return row
data_offset = row_index - 1 if self.show_header else 0
data = self.data.get(data_offset)
data = self.data.get(row_index)
empty = Text()
if data is None:
return [empty for column in self.columns]
return [Text("!") for column in self.columns]
else:
return [default_cell_formatter(datum) or empty for datum in data]
@@ -182,31 +196,38 @@ class DataTable(ScrollView, Generic[CellType]):
cell_key = (row_index, column.index, style)
if cell_key not in self._cell_render_cache:
style += Style.from_meta({"row": row_index, "column": column.index})
cell = self.get_row(row_index)[column.index]
height = (
self.header_height if row_index == -1 else self.rows[row_index].height
)
cell = self.get_row_renderables(row_index)[column.index]
lines = self.app.console.render_lines(
Padding(cell, (0, 1)),
self.app.console.options.update_dimensions(column.width, 1),
self.app.console.options.update_dimensions(column.width, height),
style=style,
)
self._cell_render_cache[cell_key] = lines
return self._cell_render_cache[cell_key]
def _render_row(self, row_index: int, base_style: Style) -> tuple[Lines, Lines]:
if row_index in self._row_render_cache:
return self._row_render_cache[row_index]
def _render_row(
self, row_index: int, line_no: int, base_style: Style
) -> tuple[Lines, Lines]:
cache_key = (row_index, line_no, base_style)
if cache_key in self._row_render_cache:
return self._row_render_cache[cache_key]
if self.fixed_columns:
fixed_style = self.component_styles["datatable--fixed"].node.rich_style
fixed_style += Style.from_meta({"fixed": True})
fixed_row = [
self._render_cell(row_index, column, fixed_style)[0]
self._render_cell(row_index, column, fixed_style)[line_no]
for column in self.columns[: self.fixed_columns]
]
else:
fixed_row = []
if row_index == 0 and self.show_header:
if row_index == -1:
row_style = self.component_styles["datatable--header"].node.rich_style
else:
if self.zebra_stripes:
@@ -218,14 +239,21 @@ class DataTable(ScrollView, Generic[CellType]):
row_style = base_style
scrollable_row = [
self._render_cell(row_index, column, row_style)[0]
self._render_cell(row_index, column, row_style)[line_no]
for column in self.columns
]
row_pair = (fixed_row, scrollable_row)
self._row_render_cache[row_index] = row_pair
self._row_render_cache[cache_key] = row_pair
return row_pair
def _get_offsets(self, y: int) -> tuple[int, int]:
if self.show_header:
if y < self.header_height:
return (-1, y)
y -= self.header_height
return self._y_offsets[y]
def _render_line(
self, y: int, x1: int, x2: int, base_style: Style
) -> list[Segment]:
@@ -236,13 +264,13 @@ class DataTable(ScrollView, Generic[CellType]):
if cache_key in self._line_cache:
return self._line_cache[cache_key]
row_index = y
row_index, line_no = self._get_offsets(y)
fixed, scrollable = self._render_row(row_index, base_style)
fixed, scrollable = self._render_row(row_index, line_no, base_style)
fixed_width = sum(column.width for column in self.columns[: self.fixed_columns])
fixed_line: list[Segment] = sum(fixed, start=[])
scrollable_line: list[Segment] = sum(scrollable, start=[])
fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else []
scrollable_line: list[Segment] = list(chain.from_iterable(scrollable))
segments = fixed_line + line_crop(scrollable_line, x1 + fixed_width, x2, width)
@@ -264,14 +292,21 @@ class DataTable(ScrollView, Generic[CellType]):
base_style = self.rich_style
fixed_top_row_count = sum(
self.get_row_height(row_index) for row_index in range(self.fixed_rows)
)
if self.show_header:
fixed_top_row_count += self.get_row_height(-1)
fixed_lines = [
self._render_line(y, x1, x2, base_style) for y in range(0, self.fixed_rows)
self._render_line(y, x1, x2, base_style)
for y in range(0, fixed_top_row_count)
]
lines = [self._render_line(y, x1, x2, base_style) for y in range(y1, y2)]
for fixed_line, y in zip(fixed_lines, range(y1, y2)):
if y - scroll_y == 0:
lines[0] = fixed_line
for line_index, y in enumerate(range(y1, y2)):
if y - scroll_y < fixed_top_row_count:
lines[line_index] = fixed_lines[line_index]
return lines