Keys for columns in the DataTable

This commit is contained in:
Darren Burns
2023-01-24 14:17:29 +00:00
parent 2d498d516d
commit a958c66671
2 changed files with 76 additions and 52 deletions

View File

@@ -214,7 +214,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
) -> None:
super().__init__(name=name, id=id, classes=classes)
self.columns: list[Column] = []
self.columns: dict[ColumnKey, Column] = {}
self.rows: dict[RowKey, Row] = {}
self.data: dict[RowKey, list[CellType]] = {}
@@ -224,7 +224,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict({})
# Maps y-coordinate (from top of table) to (row_index, y-coord within row) pairs
# TODO: Update types
self._y_offsets: list[tuple[RowKey, int]] = []
self._row_render_cache: LRUCache[
tuple[RowKey, int, Style, int, int], tuple[SegmentLines, SegmentLines]
@@ -421,13 +420,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
"""Called to recalculate the virtual (scrollable) size."""
for row_index in new_rows:
for column, renderable in zip(
self.columns, self._get_row_renderables(row_index)
self.columns.values(), self._get_row_renderables(row_index)
):
content_width = measure(self.app.console, renderable, 1)
column.content_width = max(column.content_width, content_width)
self._clear_caches()
total_width = sum(column.render_width for column in self.columns)
total_width = sum(column.render_width for column in self.columns.values())
header_height = self.header_height if self.show_header else 0
self.virtual_size = Size(
total_width,
@@ -435,18 +434,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
)
def _get_cell_region(self, row_index: int, column_index: int) -> Region:
"""Get the region of the cell at the given coordinate (row_index, column_index)"""
# This IS used to get the cell region under given a cursor coordinate.
# So we don't want to change this to the key approach, but of course we
# need to look up the row_key first now before proceeding.
# TODO: This is pre-existing method, we'll simply map the indices
# over to the row_keys for now, and likely provide a new means of
row_key = self._row_locations.get_key(row_index)
if row_key not in self.rows:
"""Get the region of the cell at the given spatial coordinate."""
valid_row = 0 <= row_index < len(self.rows)
valid_column = 0 <= column_index < len(self.columns)
valid_cell = valid_row and valid_column
if not valid_cell:
return Region(0, 0, 0, 0)
row_key = self._row_locations.get_key(row_index)
row = self.rows[row_key]
x = sum(column.render_width for column in self.columns[:column_index])
width = self.columns[column_index].render_width
# The x-coordinate of a cell is the sum of widths of cells to the left.
x = sum(column.render_width for column in self._ordered_columns[:column_index])
column_key = self._column_locations.get_key(column_index)
width = self.columns[column_key].render_width
height = row.height
y = row.y
if self.show_header:
@@ -457,12 +458,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _get_row_region(self, row_index: int) -> Region:
"""Get the region of the row at the given index."""
rows = self.rows
if row_index < 0 or row_index >= len(rows):
valid_row = 0 <= row_index < len(rows)
if not valid_row:
return Region(0, 0, 0, 0)
row_key = self._row_locations.get_key(row_index)
row = rows[row_key]
row_width = sum(column.render_width for column in self.columns)
row_width = sum(column.render_width for column in self.columns.values())
y = row.y
if self.show_header:
y += self.header_height
@@ -472,11 +474,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _get_column_region(self, column_index: int) -> Region:
"""Get the region of the column at the given index."""
columns = self.columns
if column_index < 0 or column_index >= len(columns):
valid_column = 0 <= column_index < len(columns)
if not valid_column:
return Region(0, 0, 0, 0)
x = sum(column.render_width for column in self.columns[:column_index])
width = columns[column_index].render_width
x = sum(column.render_width for column in self._ordered_columns[:column_index])
column_key = self._column_locations.get_key(column_index)
width = columns[column_key].render_width
header_height = self.header_height if self.show_header else 0
height = len(self._y_offsets) + header_height
full_column_region = Region(x, 0, width, height)
@@ -518,13 +522,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
text_label = Text.from_markup(label) if isinstance(label, str) else label
column_key = ColumnKey(key)
column_index = len(self.columns)
content_width = measure(self.app.console, text_label, 1)
if width is None:
column = Column(
column_key,
text_label,
content_width,
index=len(self.columns),
index=column_index,
content_width=content_width,
auto_width=True,
)
@@ -534,10 +539,11 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
text_label,
width,
content_width=content_width,
index=len(self.columns),
index=column_index,
)
self.columns.append(column)
self.columns[column_key] = column
self._column_locations[column_key] = column_index
self._require_update_dimensions = True
self.check_idle()
@@ -675,6 +681,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
region = region.translate(-self.scroll_offset)
self.refresh(region)
@property
def _ordered_columns(self) -> list[Column]:
column_indices = range(len(self.columns))
column_keys = [
self._column_locations.get_key(index) for index in column_indices
]
ordered_columns = [self.columns.get(key) for key in column_keys]
return ordered_columns
def _get_row_renderables(self, row_index: int) -> list[RenderableType]:
"""Get renderables for the given row.
@@ -686,18 +701,18 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
"""
if row_index == -1:
row = [column.label for column in self.columns]
row = [column.label for column in self._ordered_columns]
return row
row_key = self._row_locations.get_key(row_index)
data = self.data.get(row_key)
row = self.data.get(row_key)
empty = Text()
if data is None:
if row is None:
return [empty for _ in self.columns]
else:
return [
Text() if datum is None else default_cell_formatter(datum) or empty
for datum, _ in zip_longest(data, range(len(self.columns)))
for datum, _ in zip_longest(row, range(len(self.columns)))
]
def _render_cell(
@@ -761,7 +776,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _render_line_in_row(
self,
row_index: int,
row_key: RowKey,
line_no: int,
base_style: Style,
cursor_location: Coordinate,
@@ -770,7 +785,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
"""Render a row in to lines for each cell.
Args:
row_index: Index of the row.
row_key: The identifying key for this row.
line_no: Line number (y-coordinate) within row. 0 is the first strip of
cells in the row, line_no=1 is the next, and so on...
base_style: Base style of row.
@@ -784,7 +799,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
show_cursor = self.show_cursor
cache_key = (
row_index,
row_key,
line_no,
base_style,
cursor_location,
@@ -821,11 +836,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
else:
return False
row_index = self._row_locations.get(row_key, -1)
if self.fixed_columns:
fixed_style = self.get_component_styles("datatable--fixed").rich_style
fixed_style += Style.from_meta({"fixed": True})
fixed_row = []
for column in self.columns[: self.fixed_columns]:
for column in self._ordered_columns[: self.fixed_columns]:
cell_location = Coordinate(row_index, column.index)
fixed_cell_lines = render_cell(
row_index,
@@ -841,7 +857,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
else:
fixed_row = []
if row_index == -1:
if row_key is None:
row_style = self.get_component_styles("datatable--header").rich_style
else:
if self.zebra_stripes:
@@ -853,7 +869,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
row_style = base_style
scrollable_row = []
for column in self.columns:
for column in self.columns.values():
cell_location = Coordinate(row_index, column.index)
cell_lines = render_cell(
row_index,
@@ -869,27 +885,25 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._row_render_cache[cache_key] = row_pair
return row_pair
def _get_offsets(self, y: int) -> tuple[int, int]:
"""Get row number and line offset for a given line.
def _get_offsets(self, y: int) -> tuple[RowKey | None, int]:
"""Get row key and line offset for a given line.
Args:
y: Y coordinate relative to DataTable top.
Returns:
Line number and line offset within cell.
Row key and line (y) offset within cell.
"""
header_height = self.header_height
y_offsets = self._y_offsets
if self.show_header:
if y < header_height:
return -1, y
return None, y
y -= header_height
if y > len(y_offsets):
raise LookupError("Y coord {y!r} is greater than total height")
row_key, y_offset_in_row = y_offsets[y]
row_index = self._row_locations.get(row_key)
return row_index, y_offset_in_row
return y_offsets[y]
def _render_line(self, y: int, x1: int, x2: int, base_style: Style) -> Strip:
"""Render a line in to a list of segments.
@@ -907,7 +921,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
width = self.size.width
try:
row_index, line_no = self._get_offsets(y)
row_key, y_offset_in_row = self._get_offsets(y)
except LookupError:
return Strip.blank(width, base_style)
@@ -927,14 +941,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
return self._line_cache[cache_key]
fixed, scrollable = self._render_line_in_row(
row_index,
line_no,
row_key,
y_offset_in_row,
base_style,
cursor_location=self.cursor_cell,
hover_location=self.hover_cell,
)
fixed_width = sum(
column.render_width for column in self.columns[: self.fixed_columns]
column.render_width
for column in self._ordered_columns[: self.fixed_columns]
)
fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else []
@@ -982,7 +997,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
for row_index in range(self.fixed_rows)
if row_index in self.rows
)
left = sum(column.render_width for column in self.columns[: self.fixed_columns])
left = sum(
column.render_width
for column in self._ordered_columns[: self.fixed_columns]
)
return Spacing(top, 0, 0, left)
def _scroll_cursor_into_view(self, animate: bool = False) -> None:

View File

@@ -58,6 +58,7 @@ async def test_datatable_message_emission():
# therefore no highlighted cells), but then a row was added, and
# so the cell at (0, 0) became highlighted.
expected_messages.append("CellHighlighted")
await pilot.pause(2 / 100)
assert messages == expected_messages
# Pressing Enter when the cursor is on a cell emits a CellSelected
@@ -74,6 +75,7 @@ async def test_datatable_message_emission():
# Switch over to the row cursor... should emit a `RowHighlighted`
table.cursor_type = "row"
expected_messages.append("RowHighlighted")
await pilot.pause(2 / 100)
assert messages == expected_messages
# Select the row...
@@ -85,6 +87,7 @@ async def test_datatable_message_emission():
# Switching to the column cursor emits a `ColumnHighlighted`
table.cursor_type = "column"
expected_messages.append("ColumnHighlighted")
await pilot.pause(2 / 100)
assert messages == expected_messages
# Select the column...
@@ -112,6 +115,7 @@ async def test_datatable_message_emission():
# message should be emitted for highlighting the cell.
table.show_cursor = True
expected_messages.append("CellHighlighted")
await pilot.pause(2 / 100)
assert messages == expected_messages
# Likewise, if the cursor_type is "none", and we change the
@@ -213,26 +217,28 @@ async def test_column_labels() -> None:
async with app.run_test():
table = app.query_one(DataTable)
table.add_columns("1", "2", "3")
assert [col.label for col in table.columns] == [Text("1"), Text("2"), Text("3")]
actual_labels = [col.label for col in table.columns.values()]
expected_labels = [Text("1"), Text("2"), Text("3")]
assert actual_labels == expected_labels
async def test_column_widths() -> None:
app = DataTableApp()
async with app.run_test() as pilot:
table = app.query_one(DataTable)
table.add_columns("foo", "bar")
foo, bar = table.add_columns("foo", "bar")
assert table.columns[0].width == 3
assert table.columns[1].width == 3
assert table.columns[foo].width == 3
assert table.columns[bar].width == 3
table.add_row("Hello", "World!")
await pilot.pause()
assert table.columns[0].content_width == 5
assert table.columns[1].content_width == 6
assert table.columns[foo].content_width == 5
assert table.columns[bar].content_width == 6
table.add_row("Hello World!!!", "fo")
await pilot.pause()
assert table.columns[0].content_width == 14
assert table.columns[1].content_width == 6
assert table.columns[foo].content_width == 14
assert table.columns[bar].content_width == 6
def test_get_cell_value_returns_value_at_cell():