From 5b14f8996dfe1a4670366154527aa4e5ad9ed68a Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Mon, 23 Jan 2023 15:37:38 +0000 Subject: [PATCH] Rows are internally tracked using RowKey in DataTable --- .../{_two_way_mapping.py => _two_way_dict.py} | 2 +- src/textual/widgets/_data_table.py | 59 ++++++++++++------- ...wo_way_mapping.py => test_two_way_dict.py} | 4 +- 3 files changed, 42 insertions(+), 23 deletions(-) rename src/textual/{_two_way_mapping.py => _two_way_dict.py} (97%) rename tests/{test_two_way_mapping.py => test_two_way_dict.py} (91%) diff --git a/src/textual/_two_way_mapping.py b/src/textual/_two_way_dict.py similarity index 97% rename from src/textual/_two_way_mapping.py rename to src/textual/_two_way_dict.py index f3f04f369..00fffb8d4 100644 --- a/src/textual/_two_way_mapping.py +++ b/src/textual/_two_way_dict.py @@ -6,7 +6,7 @@ Key = TypeVar("Key") Value = TypeVar("Value") -class TwoWayMapping(Generic[Key, Value]): +class TwoWayDict(Generic[Key, Value]): """ Wraps two dictionaries and uses them to provide efficient access to both values (given keys) and keys (given values). diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index bb6665f83..57a09f2e8 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -15,7 +15,7 @@ from rich.text import Text, TextType from .. import events, messages from .._cache import LRUCache from .._segment_tools import line_crop -from .._two_way_mapping import TwoWayMapping +from .._two_way_dict import TwoWayDict from .._types import SegmentLines from .._typing import Literal from ..binding import Binding @@ -215,14 +215,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): super().__init__(name=name, id=id, classes=classes) self.columns: list[Column] = [] - self.rows: dict[int, Row] = {} + self.rows: dict[RowKey, Row] = {} self.data: dict[int, list[CellType]] = {} self.row_count = 0 # Keep tracking of key -> index for rows/cols. # For a given key, what is the current location of the corresponding row/col? - self._column_locations: TwoWayMapping[ColumnKey, int] = TwoWayMapping({}) - self._row_locations: TwoWayMapping[RowKey, int] = TwoWayMapping({}) + self._column_locations: TwoWayDict[ColumnKey, int] = TwoWayDict({}) + self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict({}) # Maps y-coordinate (from top of table) to (row_index, y-coord within row) pairs # TODO: Update types @@ -240,6 +240,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._line_no = 0 self._require_update_dimensions: bool = False + + # TODO: Check what this is used for and if it needs updated to use keys self._new_rows: set[int] = set() self.show_header = show_header @@ -291,10 +293,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._line_cache.clear() self._styles_cache.clear() - def get_row_height(self, row_index: int) -> int: - if row_index == -1: + def get_row_height(self, row_key: int | RowKey) -> int: + # TODO: Update to generate header key ourselves instead of -1, + # and remember to update type signature + if row_key == -1: return self.header_height - return self.rows[row_index].height + return self.rows[row_key].height async def on_styles_updated(self, message: messages.StylesUpdated) -> None: self._clear_caches() @@ -422,9 +426,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_cell_region(self, row_index: int, column_index: int) -> Region: """Get the region of the cell at the given coordinate (row_index, column_index)""" - if row_index not in self.rows: + # This IS used to get the cell region under given a cursor coordinate. + # So we don't want to change this to the key approach, but of course we + # need to look up the row_key first now before proceeding. + # TODO: This is pre-existing method, we'll simply map the indices + # over to the row_keys for now, and likely provide a new means of + row_key = self._row_locations.get_key(row_index) + if row_key not in self.rows: return Region(0, 0, 0, 0) - row = self.rows[row_index] + row = self.rows[row_key] x = sum(column.render_width for column in self.columns[:column_index]) width = self.columns[column_index].render_width height = row.height @@ -439,7 +449,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): rows = self.rows if row_index < 0 or row_index >= len(rows): return Region(0, 0, 0, 0) - row = rows[row_index] + + row_key = self._row_locations.get_key(row_index) + row = rows[row_key] row_width = sum(column.render_width for column in self.columns) y = row.y if self.show_header: @@ -544,7 +556,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._row_locations[row_key] = row_index self.data[row_index] = list(cells) - self.rows[row_index] = Row(row_key, row_index, height, self._line_no) + self.rows[row_key] = Row(row_key, row_index, height, self._line_no) for line_no in range(height): self._y_offsets.append((row_index, line_no)) @@ -722,12 +734,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): if is_fixed_style: style += self.get_component_styles("datatable--cursor-fixed").rich_style - cell_key = (row_index, column_index, style, cursor, hover) + # TODO: We can hoist `row_key` lookup waaay up to do it inside `_get_offsets` + # then just pass it through to here instead of the row_index. + row_key = self._row_locations.get_key(row_index) + cell_key = (row_key, column_index, style, cursor, hover) if cell_key not in self._cell_render_cache: style += Style.from_meta({"row": row_index, "column": column_index}) - height = ( - self.header_height if is_header_row else self.rows[row_index].height - ) + height = self.header_height if is_header_row else self.rows[row_key].height cell = self._get_row_renderables(row_index)[column_index] lines = self.app.console.render_lines( Padding(cell, (0, 1)), @@ -920,13 +933,19 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def render_line(self, y: int) -> Strip: width, height = self.size scroll_x, scroll_y = self.scroll_offset - fixed_top_row_count = sum( - self.get_row_height(row_index) for row_index in range(self.fixed_rows) + + fixed_row_keys: list[RowKey] = [ + self._row_locations.get_key(row_index) + for row_index in range(self.fixed_rows) + ] + + fixed_rows_height = sum( + self.get_row_height(row_key) for row_key in fixed_row_keys ) if self.show_header: - fixed_top_row_count += self.get_row_height(-1) + fixed_rows_height += self.get_row_height(-1) - if y >= fixed_top_row_count: + if y >= fixed_rows_height: y += scroll_y return self._render_line(y, scroll_x, scroll_x + width, self.rich_style) @@ -943,7 +962,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_fixed_offset(self) -> Spacing: top = self.header_height if self.show_header else 0 top += sum( - self.rows[row_index].height + self.rows[self._row_locations.get_key(row_index)].height for row_index in range(self.fixed_rows) if row_index in self.rows ) diff --git a/tests/test_two_way_mapping.py b/tests/test_two_way_dict.py similarity index 91% rename from tests/test_two_way_mapping.py rename to tests/test_two_way_dict.py index 9a88bdeec..26e1cb58e 100644 --- a/tests/test_two_way_mapping.py +++ b/tests/test_two_way_dict.py @@ -1,11 +1,11 @@ import pytest -from textual._two_way_mapping import TwoWayMapping +from textual._two_way_dict import TwoWayDict @pytest.fixture def map(): - return TwoWayMapping( + return TwoWayDict( { 1: 10, 2: 20,