Rows are internally tracked using RowKey in DataTable

This commit is contained in:
Darren Burns
2023-01-23 15:37:38 +00:00
parent f983ac308d
commit 5b14f8996d
3 changed files with 42 additions and 23 deletions

View File

@@ -6,7 +6,7 @@ Key = TypeVar("Key")
Value = TypeVar("Value")
class TwoWayMapping(Generic[Key, Value]):
class TwoWayDict(Generic[Key, Value]):
"""
Wraps two dictionaries and uses them to provide efficient access to
both values (given keys) and keys (given values).

View File

@@ -15,7 +15,7 @@ from rich.text import Text, TextType
from .. import events, messages
from .._cache import LRUCache
from .._segment_tools import line_crop
from .._two_way_mapping import TwoWayMapping
from .._two_way_dict import TwoWayDict
from .._types import SegmentLines
from .._typing import Literal
from ..binding import Binding
@@ -215,14 +215,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
super().__init__(name=name, id=id, classes=classes)
self.columns: list[Column] = []
self.rows: dict[int, Row] = {}
self.rows: dict[RowKey, Row] = {}
self.data: dict[int, list[CellType]] = {}
self.row_count = 0
# Keep tracking of key -> index for rows/cols.
# For a given key, what is the current location of the corresponding row/col?
self._column_locations: TwoWayMapping[ColumnKey, int] = TwoWayMapping({})
self._row_locations: TwoWayMapping[RowKey, int] = TwoWayMapping({})
self._column_locations: TwoWayDict[ColumnKey, int] = TwoWayDict({})
self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict({})
# Maps y-coordinate (from top of table) to (row_index, y-coord within row) pairs
# TODO: Update types
@@ -240,6 +240,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._line_no = 0
self._require_update_dimensions: bool = False
# TODO: Check what this is used for and if it needs updated to use keys
self._new_rows: set[int] = set()
self.show_header = show_header
@@ -291,10 +293,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._line_cache.clear()
self._styles_cache.clear()
def get_row_height(self, row_index: int) -> int:
if row_index == -1:
def get_row_height(self, row_key: int | RowKey) -> int:
# TODO: Update to generate header key ourselves instead of -1,
# and remember to update type signature
if row_key == -1:
return self.header_height
return self.rows[row_index].height
return self.rows[row_key].height
async def on_styles_updated(self, message: messages.StylesUpdated) -> None:
self._clear_caches()
@@ -422,9 +426,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _get_cell_region(self, row_index: int, column_index: int) -> Region:
"""Get the region of the cell at the given coordinate (row_index, column_index)"""
if row_index not in self.rows:
# This IS used to get the cell region under given a cursor coordinate.
# So we don't want to change this to the key approach, but of course we
# need to look up the row_key first now before proceeding.
# TODO: This is pre-existing method, we'll simply map the indices
# over to the row_keys for now, and likely provide a new means of
row_key = self._row_locations.get_key(row_index)
if row_key not in self.rows:
return Region(0, 0, 0, 0)
row = self.rows[row_index]
row = self.rows[row_key]
x = sum(column.render_width for column in self.columns[:column_index])
width = self.columns[column_index].render_width
height = row.height
@@ -439,7 +449,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
rows = self.rows
if row_index < 0 or row_index >= len(rows):
return Region(0, 0, 0, 0)
row = rows[row_index]
row_key = self._row_locations.get_key(row_index)
row = rows[row_key]
row_width = sum(column.render_width for column in self.columns)
y = row.y
if self.show_header:
@@ -544,7 +556,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._row_locations[row_key] = row_index
self.data[row_index] = list(cells)
self.rows[row_index] = Row(row_key, row_index, height, self._line_no)
self.rows[row_key] = Row(row_key, row_index, height, self._line_no)
for line_no in range(height):
self._y_offsets.append((row_index, line_no))
@@ -722,12 +734,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
if is_fixed_style:
style += self.get_component_styles("datatable--cursor-fixed").rich_style
cell_key = (row_index, column_index, style, cursor, hover)
# TODO: We can hoist `row_key` lookup waaay up to do it inside `_get_offsets`
# then just pass it through to here instead of the row_index.
row_key = self._row_locations.get_key(row_index)
cell_key = (row_key, column_index, style, cursor, hover)
if cell_key not in self._cell_render_cache:
style += Style.from_meta({"row": row_index, "column": column_index})
height = (
self.header_height if is_header_row else self.rows[row_index].height
)
height = self.header_height if is_header_row else self.rows[row_key].height
cell = self._get_row_renderables(row_index)[column_index]
lines = self.app.console.render_lines(
Padding(cell, (0, 1)),
@@ -920,13 +933,19 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def render_line(self, y: int) -> Strip:
width, height = self.size
scroll_x, scroll_y = self.scroll_offset
fixed_top_row_count = sum(
self.get_row_height(row_index) for row_index in range(self.fixed_rows)
fixed_row_keys: list[RowKey] = [
self._row_locations.get_key(row_index)
for row_index in range(self.fixed_rows)
]
fixed_rows_height = sum(
self.get_row_height(row_key) for row_key in fixed_row_keys
)
if self.show_header:
fixed_top_row_count += self.get_row_height(-1)
fixed_rows_height += self.get_row_height(-1)
if y >= fixed_top_row_count:
if y >= fixed_rows_height:
y += scroll_y
return self._render_line(y, scroll_x, scroll_x + width, self.rich_style)
@@ -943,7 +962,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _get_fixed_offset(self) -> Spacing:
top = self.header_height if self.show_header else 0
top += sum(
self.rows[row_index].height
self.rows[self._row_locations.get_key(row_index)].height
for row_index in range(self.fixed_rows)
if row_index in self.rows
)

View File

@@ -1,11 +1,11 @@
import pytest
from textual._two_way_mapping import TwoWayMapping
from textual._two_way_dict import TwoWayDict
@pytest.fixture
def map():
return TwoWayMapping(
return TwoWayDict(
{
1: 10,
2: 20,