mirror of
https://github.com/Textualize/textual.git
synced 2025-10-17 02:38:12 +03:00
Rows are internally tracked using RowKey in DataTable
This commit is contained in:
@@ -6,7 +6,7 @@ Key = TypeVar("Key")
|
||||
Value = TypeVar("Value")
|
||||
|
||||
|
||||
class TwoWayMapping(Generic[Key, Value]):
|
||||
class TwoWayDict(Generic[Key, Value]):
|
||||
"""
|
||||
Wraps two dictionaries and uses them to provide efficient access to
|
||||
both values (given keys) and keys (given values).
|
||||
@@ -15,7 +15,7 @@ from rich.text import Text, TextType
|
||||
from .. import events, messages
|
||||
from .._cache import LRUCache
|
||||
from .._segment_tools import line_crop
|
||||
from .._two_way_mapping import TwoWayMapping
|
||||
from .._two_way_dict import TwoWayDict
|
||||
from .._types import SegmentLines
|
||||
from .._typing import Literal
|
||||
from ..binding import Binding
|
||||
@@ -215,14 +215,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
super().__init__(name=name, id=id, classes=classes)
|
||||
|
||||
self.columns: list[Column] = []
|
||||
self.rows: dict[int, Row] = {}
|
||||
self.rows: dict[RowKey, Row] = {}
|
||||
self.data: dict[int, list[CellType]] = {}
|
||||
self.row_count = 0
|
||||
|
||||
# Keep tracking of key -> index for rows/cols.
|
||||
# For a given key, what is the current location of the corresponding row/col?
|
||||
self._column_locations: TwoWayMapping[ColumnKey, int] = TwoWayMapping({})
|
||||
self._row_locations: TwoWayMapping[RowKey, int] = TwoWayMapping({})
|
||||
self._column_locations: TwoWayDict[ColumnKey, int] = TwoWayDict({})
|
||||
self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict({})
|
||||
|
||||
# Maps y-coordinate (from top of table) to (row_index, y-coord within row) pairs
|
||||
# TODO: Update types
|
||||
@@ -240,6 +240,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
|
||||
self._line_no = 0
|
||||
self._require_update_dimensions: bool = False
|
||||
|
||||
# TODO: Check what this is used for and if it needs updated to use keys
|
||||
self._new_rows: set[int] = set()
|
||||
|
||||
self.show_header = show_header
|
||||
@@ -291,10 +293,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
self._line_cache.clear()
|
||||
self._styles_cache.clear()
|
||||
|
||||
def get_row_height(self, row_index: int) -> int:
|
||||
if row_index == -1:
|
||||
def get_row_height(self, row_key: int | RowKey) -> int:
|
||||
# TODO: Update to generate header key ourselves instead of -1,
|
||||
# and remember to update type signature
|
||||
if row_key == -1:
|
||||
return self.header_height
|
||||
return self.rows[row_index].height
|
||||
return self.rows[row_key].height
|
||||
|
||||
async def on_styles_updated(self, message: messages.StylesUpdated) -> None:
|
||||
self._clear_caches()
|
||||
@@ -422,9 +426,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
|
||||
def _get_cell_region(self, row_index: int, column_index: int) -> Region:
|
||||
"""Get the region of the cell at the given coordinate (row_index, column_index)"""
|
||||
if row_index not in self.rows:
|
||||
# This IS used to get the cell region under given a cursor coordinate.
|
||||
# So we don't want to change this to the key approach, but of course we
|
||||
# need to look up the row_key first now before proceeding.
|
||||
# TODO: This is pre-existing method, we'll simply map the indices
|
||||
# over to the row_keys for now, and likely provide a new means of
|
||||
row_key = self._row_locations.get_key(row_index)
|
||||
if row_key not in self.rows:
|
||||
return Region(0, 0, 0, 0)
|
||||
row = self.rows[row_index]
|
||||
row = self.rows[row_key]
|
||||
x = sum(column.render_width for column in self.columns[:column_index])
|
||||
width = self.columns[column_index].render_width
|
||||
height = row.height
|
||||
@@ -439,7 +449,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
rows = self.rows
|
||||
if row_index < 0 or row_index >= len(rows):
|
||||
return Region(0, 0, 0, 0)
|
||||
row = rows[row_index]
|
||||
|
||||
row_key = self._row_locations.get_key(row_index)
|
||||
row = rows[row_key]
|
||||
row_width = sum(column.render_width for column in self.columns)
|
||||
y = row.y
|
||||
if self.show_header:
|
||||
@@ -544,7 +556,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
self._row_locations[row_key] = row_index
|
||||
|
||||
self.data[row_index] = list(cells)
|
||||
self.rows[row_index] = Row(row_key, row_index, height, self._line_no)
|
||||
self.rows[row_key] = Row(row_key, row_index, height, self._line_no)
|
||||
|
||||
for line_no in range(height):
|
||||
self._y_offsets.append((row_index, line_no))
|
||||
@@ -722,12 +734,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
if is_fixed_style:
|
||||
style += self.get_component_styles("datatable--cursor-fixed").rich_style
|
||||
|
||||
cell_key = (row_index, column_index, style, cursor, hover)
|
||||
# TODO: We can hoist `row_key` lookup waaay up to do it inside `_get_offsets`
|
||||
# then just pass it through to here instead of the row_index.
|
||||
row_key = self._row_locations.get_key(row_index)
|
||||
cell_key = (row_key, column_index, style, cursor, hover)
|
||||
if cell_key not in self._cell_render_cache:
|
||||
style += Style.from_meta({"row": row_index, "column": column_index})
|
||||
height = (
|
||||
self.header_height if is_header_row else self.rows[row_index].height
|
||||
)
|
||||
height = self.header_height if is_header_row else self.rows[row_key].height
|
||||
cell = self._get_row_renderables(row_index)[column_index]
|
||||
lines = self.app.console.render_lines(
|
||||
Padding(cell, (0, 1)),
|
||||
@@ -920,13 +933,19 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
def render_line(self, y: int) -> Strip:
|
||||
width, height = self.size
|
||||
scroll_x, scroll_y = self.scroll_offset
|
||||
fixed_top_row_count = sum(
|
||||
self.get_row_height(row_index) for row_index in range(self.fixed_rows)
|
||||
|
||||
fixed_row_keys: list[RowKey] = [
|
||||
self._row_locations.get_key(row_index)
|
||||
for row_index in range(self.fixed_rows)
|
||||
]
|
||||
|
||||
fixed_rows_height = sum(
|
||||
self.get_row_height(row_key) for row_key in fixed_row_keys
|
||||
)
|
||||
if self.show_header:
|
||||
fixed_top_row_count += self.get_row_height(-1)
|
||||
fixed_rows_height += self.get_row_height(-1)
|
||||
|
||||
if y >= fixed_top_row_count:
|
||||
if y >= fixed_rows_height:
|
||||
y += scroll_y
|
||||
|
||||
return self._render_line(y, scroll_x, scroll_x + width, self.rich_style)
|
||||
@@ -943,7 +962,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
def _get_fixed_offset(self) -> Spacing:
|
||||
top = self.header_height if self.show_header else 0
|
||||
top += sum(
|
||||
self.rows[row_index].height
|
||||
self.rows[self._row_locations.get_key(row_index)].height
|
||||
for row_index in range(self.fixed_rows)
|
||||
if row_index in self.rows
|
||||
)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import pytest
|
||||
|
||||
from textual._two_way_mapping import TwoWayMapping
|
||||
from textual._two_way_dict import TwoWayDict
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def map():
|
||||
return TwoWayMapping(
|
||||
return TwoWayDict(
|
||||
{
|
||||
1: 10,
|
||||
2: 20,
|
||||
Reference in New Issue
Block a user