mirror of
https://github.com/Textualize/textual.git
synced 2025-10-17 02:38:12 +03:00
Keys for columns in the DataTable
This commit is contained in:
@@ -214,7 +214,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
) -> None:
|
||||
super().__init__(name=name, id=id, classes=classes)
|
||||
|
||||
self.columns: list[Column] = []
|
||||
self.columns: dict[ColumnKey, Column] = {}
|
||||
self.rows: dict[RowKey, Row] = {}
|
||||
self.data: dict[RowKey, list[CellType]] = {}
|
||||
|
||||
@@ -224,7 +224,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict({})
|
||||
|
||||
# Maps y-coordinate (from top of table) to (row_index, y-coord within row) pairs
|
||||
# TODO: Update types
|
||||
self._y_offsets: list[tuple[RowKey, int]] = []
|
||||
self._row_render_cache: LRUCache[
|
||||
tuple[RowKey, int, Style, int, int], tuple[SegmentLines, SegmentLines]
|
||||
@@ -421,13 +420,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
"""Called to recalculate the virtual (scrollable) size."""
|
||||
for row_index in new_rows:
|
||||
for column, renderable in zip(
|
||||
self.columns, self._get_row_renderables(row_index)
|
||||
self.columns.values(), self._get_row_renderables(row_index)
|
||||
):
|
||||
content_width = measure(self.app.console, renderable, 1)
|
||||
column.content_width = max(column.content_width, content_width)
|
||||
|
||||
self._clear_caches()
|
||||
total_width = sum(column.render_width for column in self.columns)
|
||||
total_width = sum(column.render_width for column in self.columns.values())
|
||||
header_height = self.header_height if self.show_header else 0
|
||||
self.virtual_size = Size(
|
||||
total_width,
|
||||
@@ -435,18 +434,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
)
|
||||
|
||||
def _get_cell_region(self, row_index: int, column_index: int) -> Region:
|
||||
"""Get the region of the cell at the given coordinate (row_index, column_index)"""
|
||||
# This IS used to get the cell region under given a cursor coordinate.
|
||||
# So we don't want to change this to the key approach, but of course we
|
||||
# need to look up the row_key first now before proceeding.
|
||||
# TODO: This is pre-existing method, we'll simply map the indices
|
||||
# over to the row_keys for now, and likely provide a new means of
|
||||
row_key = self._row_locations.get_key(row_index)
|
||||
if row_key not in self.rows:
|
||||
"""Get the region of the cell at the given spatial coordinate."""
|
||||
valid_row = 0 <= row_index < len(self.rows)
|
||||
valid_column = 0 <= column_index < len(self.columns)
|
||||
valid_cell = valid_row and valid_column
|
||||
if not valid_cell:
|
||||
return Region(0, 0, 0, 0)
|
||||
|
||||
row_key = self._row_locations.get_key(row_index)
|
||||
row = self.rows[row_key]
|
||||
x = sum(column.render_width for column in self.columns[:column_index])
|
||||
width = self.columns[column_index].render_width
|
||||
|
||||
# The x-coordinate of a cell is the sum of widths of cells to the left.
|
||||
x = sum(column.render_width for column in self._ordered_columns[:column_index])
|
||||
column_key = self._column_locations.get_key(column_index)
|
||||
width = self.columns[column_key].render_width
|
||||
height = row.height
|
||||
y = row.y
|
||||
if self.show_header:
|
||||
@@ -457,12 +458,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
def _get_row_region(self, row_index: int) -> Region:
|
||||
"""Get the region of the row at the given index."""
|
||||
rows = self.rows
|
||||
if row_index < 0 or row_index >= len(rows):
|
||||
valid_row = 0 <= row_index < len(rows)
|
||||
if not valid_row:
|
||||
return Region(0, 0, 0, 0)
|
||||
|
||||
row_key = self._row_locations.get_key(row_index)
|
||||
row = rows[row_key]
|
||||
row_width = sum(column.render_width for column in self.columns)
|
||||
row_width = sum(column.render_width for column in self.columns.values())
|
||||
y = row.y
|
||||
if self.show_header:
|
||||
y += self.header_height
|
||||
@@ -472,11 +474,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
def _get_column_region(self, column_index: int) -> Region:
|
||||
"""Get the region of the column at the given index."""
|
||||
columns = self.columns
|
||||
if column_index < 0 or column_index >= len(columns):
|
||||
valid_column = 0 <= column_index < len(columns)
|
||||
if not valid_column:
|
||||
return Region(0, 0, 0, 0)
|
||||
|
||||
x = sum(column.render_width for column in self.columns[:column_index])
|
||||
width = columns[column_index].render_width
|
||||
x = sum(column.render_width for column in self._ordered_columns[:column_index])
|
||||
column_key = self._column_locations.get_key(column_index)
|
||||
width = columns[column_key].render_width
|
||||
header_height = self.header_height if self.show_header else 0
|
||||
height = len(self._y_offsets) + header_height
|
||||
full_column_region = Region(x, 0, width, height)
|
||||
@@ -518,13 +522,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
text_label = Text.from_markup(label) if isinstance(label, str) else label
|
||||
|
||||
column_key = ColumnKey(key)
|
||||
column_index = len(self.columns)
|
||||
content_width = measure(self.app.console, text_label, 1)
|
||||
if width is None:
|
||||
column = Column(
|
||||
column_key,
|
||||
text_label,
|
||||
content_width,
|
||||
index=len(self.columns),
|
||||
index=column_index,
|
||||
content_width=content_width,
|
||||
auto_width=True,
|
||||
)
|
||||
@@ -534,10 +539,11 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
text_label,
|
||||
width,
|
||||
content_width=content_width,
|
||||
index=len(self.columns),
|
||||
index=column_index,
|
||||
)
|
||||
|
||||
self.columns.append(column)
|
||||
self.columns[column_key] = column
|
||||
self._column_locations[column_key] = column_index
|
||||
self._require_update_dimensions = True
|
||||
self.check_idle()
|
||||
|
||||
@@ -675,6 +681,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
region = region.translate(-self.scroll_offset)
|
||||
self.refresh(region)
|
||||
|
||||
@property
|
||||
def _ordered_columns(self) -> list[Column]:
|
||||
column_indices = range(len(self.columns))
|
||||
column_keys = [
|
||||
self._column_locations.get_key(index) for index in column_indices
|
||||
]
|
||||
ordered_columns = [self.columns.get(key) for key in column_keys]
|
||||
return ordered_columns
|
||||
|
||||
def _get_row_renderables(self, row_index: int) -> list[RenderableType]:
|
||||
"""Get renderables for the given row.
|
||||
|
||||
@@ -686,18 +701,18 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
"""
|
||||
|
||||
if row_index == -1:
|
||||
row = [column.label for column in self.columns]
|
||||
row = [column.label for column in self._ordered_columns]
|
||||
return row
|
||||
|
||||
row_key = self._row_locations.get_key(row_index)
|
||||
data = self.data.get(row_key)
|
||||
row = self.data.get(row_key)
|
||||
empty = Text()
|
||||
if data is None:
|
||||
if row is None:
|
||||
return [empty for _ in self.columns]
|
||||
else:
|
||||
return [
|
||||
Text() if datum is None else default_cell_formatter(datum) or empty
|
||||
for datum, _ in zip_longest(data, range(len(self.columns)))
|
||||
for datum, _ in zip_longest(row, range(len(self.columns)))
|
||||
]
|
||||
|
||||
def _render_cell(
|
||||
@@ -761,7 +776,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
|
||||
def _render_line_in_row(
|
||||
self,
|
||||
row_index: int,
|
||||
row_key: RowKey,
|
||||
line_no: int,
|
||||
base_style: Style,
|
||||
cursor_location: Coordinate,
|
||||
@@ -770,7 +785,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
"""Render a row in to lines for each cell.
|
||||
|
||||
Args:
|
||||
row_index: Index of the row.
|
||||
row_key: The identifying key for this row.
|
||||
line_no: Line number (y-coordinate) within row. 0 is the first strip of
|
||||
cells in the row, line_no=1 is the next, and so on...
|
||||
base_style: Base style of row.
|
||||
@@ -784,7 +799,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
show_cursor = self.show_cursor
|
||||
|
||||
cache_key = (
|
||||
row_index,
|
||||
row_key,
|
||||
line_no,
|
||||
base_style,
|
||||
cursor_location,
|
||||
@@ -821,11 +836,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
else:
|
||||
return False
|
||||
|
||||
row_index = self._row_locations.get(row_key, -1)
|
||||
if self.fixed_columns:
|
||||
fixed_style = self.get_component_styles("datatable--fixed").rich_style
|
||||
fixed_style += Style.from_meta({"fixed": True})
|
||||
fixed_row = []
|
||||
for column in self.columns[: self.fixed_columns]:
|
||||
for column in self._ordered_columns[: self.fixed_columns]:
|
||||
cell_location = Coordinate(row_index, column.index)
|
||||
fixed_cell_lines = render_cell(
|
||||
row_index,
|
||||
@@ -841,7 +857,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
else:
|
||||
fixed_row = []
|
||||
|
||||
if row_index == -1:
|
||||
if row_key is None:
|
||||
row_style = self.get_component_styles("datatable--header").rich_style
|
||||
else:
|
||||
if self.zebra_stripes:
|
||||
@@ -853,7 +869,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
row_style = base_style
|
||||
|
||||
scrollable_row = []
|
||||
for column in self.columns:
|
||||
for column in self.columns.values():
|
||||
cell_location = Coordinate(row_index, column.index)
|
||||
cell_lines = render_cell(
|
||||
row_index,
|
||||
@@ -869,27 +885,25 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
self._row_render_cache[cache_key] = row_pair
|
||||
return row_pair
|
||||
|
||||
def _get_offsets(self, y: int) -> tuple[int, int]:
|
||||
"""Get row number and line offset for a given line.
|
||||
def _get_offsets(self, y: int) -> tuple[RowKey | None, int]:
|
||||
"""Get row key and line offset for a given line.
|
||||
|
||||
Args:
|
||||
y: Y coordinate relative to DataTable top.
|
||||
|
||||
Returns:
|
||||
Line number and line offset within cell.
|
||||
Row key and line (y) offset within cell.
|
||||
"""
|
||||
header_height = self.header_height
|
||||
y_offsets = self._y_offsets
|
||||
if self.show_header:
|
||||
if y < header_height:
|
||||
return -1, y
|
||||
return None, y
|
||||
y -= header_height
|
||||
if y > len(y_offsets):
|
||||
raise LookupError("Y coord {y!r} is greater than total height")
|
||||
|
||||
row_key, y_offset_in_row = y_offsets[y]
|
||||
row_index = self._row_locations.get(row_key)
|
||||
return row_index, y_offset_in_row
|
||||
return y_offsets[y]
|
||||
|
||||
def _render_line(self, y: int, x1: int, x2: int, base_style: Style) -> Strip:
|
||||
"""Render a line in to a list of segments.
|
||||
@@ -907,7 +921,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
width = self.size.width
|
||||
|
||||
try:
|
||||
row_index, line_no = self._get_offsets(y)
|
||||
row_key, y_offset_in_row = self._get_offsets(y)
|
||||
except LookupError:
|
||||
return Strip.blank(width, base_style)
|
||||
|
||||
@@ -927,14 +941,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
return self._line_cache[cache_key]
|
||||
|
||||
fixed, scrollable = self._render_line_in_row(
|
||||
row_index,
|
||||
line_no,
|
||||
row_key,
|
||||
y_offset_in_row,
|
||||
base_style,
|
||||
cursor_location=self.cursor_cell,
|
||||
hover_location=self.hover_cell,
|
||||
)
|
||||
fixed_width = sum(
|
||||
column.render_width for column in self.columns[: self.fixed_columns]
|
||||
column.render_width
|
||||
for column in self._ordered_columns[: self.fixed_columns]
|
||||
)
|
||||
|
||||
fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else []
|
||||
@@ -982,7 +997,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
for row_index in range(self.fixed_rows)
|
||||
if row_index in self.rows
|
||||
)
|
||||
left = sum(column.render_width for column in self.columns[: self.fixed_columns])
|
||||
left = sum(
|
||||
column.render_width
|
||||
for column in self._ordered_columns[: self.fixed_columns]
|
||||
)
|
||||
return Spacing(top, 0, 0, left)
|
||||
|
||||
def _scroll_cursor_into_view(self, animate: bool = False) -> None:
|
||||
|
||||
@@ -58,6 +58,7 @@ async def test_datatable_message_emission():
|
||||
# therefore no highlighted cells), but then a row was added, and
|
||||
# so the cell at (0, 0) became highlighted.
|
||||
expected_messages.append("CellHighlighted")
|
||||
await pilot.pause(2 / 100)
|
||||
assert messages == expected_messages
|
||||
|
||||
# Pressing Enter when the cursor is on a cell emits a CellSelected
|
||||
@@ -74,6 +75,7 @@ async def test_datatable_message_emission():
|
||||
# Switch over to the row cursor... should emit a `RowHighlighted`
|
||||
table.cursor_type = "row"
|
||||
expected_messages.append("RowHighlighted")
|
||||
await pilot.pause(2 / 100)
|
||||
assert messages == expected_messages
|
||||
|
||||
# Select the row...
|
||||
@@ -85,6 +87,7 @@ async def test_datatable_message_emission():
|
||||
# Switching to the column cursor emits a `ColumnHighlighted`
|
||||
table.cursor_type = "column"
|
||||
expected_messages.append("ColumnHighlighted")
|
||||
await pilot.pause(2 / 100)
|
||||
assert messages == expected_messages
|
||||
|
||||
# Select the column...
|
||||
@@ -112,6 +115,7 @@ async def test_datatable_message_emission():
|
||||
# message should be emitted for highlighting the cell.
|
||||
table.show_cursor = True
|
||||
expected_messages.append("CellHighlighted")
|
||||
await pilot.pause(2 / 100)
|
||||
assert messages == expected_messages
|
||||
|
||||
# Likewise, if the cursor_type is "none", and we change the
|
||||
@@ -213,26 +217,28 @@ async def test_column_labels() -> None:
|
||||
async with app.run_test():
|
||||
table = app.query_one(DataTable)
|
||||
table.add_columns("1", "2", "3")
|
||||
assert [col.label for col in table.columns] == [Text("1"), Text("2"), Text("3")]
|
||||
actual_labels = [col.label for col in table.columns.values()]
|
||||
expected_labels = [Text("1"), Text("2"), Text("3")]
|
||||
assert actual_labels == expected_labels
|
||||
|
||||
|
||||
async def test_column_widths() -> None:
|
||||
app = DataTableApp()
|
||||
async with app.run_test() as pilot:
|
||||
table = app.query_one(DataTable)
|
||||
table.add_columns("foo", "bar")
|
||||
foo, bar = table.add_columns("foo", "bar")
|
||||
|
||||
assert table.columns[0].width == 3
|
||||
assert table.columns[1].width == 3
|
||||
assert table.columns[foo].width == 3
|
||||
assert table.columns[bar].width == 3
|
||||
table.add_row("Hello", "World!")
|
||||
await pilot.pause()
|
||||
assert table.columns[0].content_width == 5
|
||||
assert table.columns[1].content_width == 6
|
||||
assert table.columns[foo].content_width == 5
|
||||
assert table.columns[bar].content_width == 6
|
||||
|
||||
table.add_row("Hello World!!!", "fo")
|
||||
await pilot.pause()
|
||||
assert table.columns[0].content_width == 14
|
||||
assert table.columns[1].content_width == 6
|
||||
assert table.columns[foo].content_width == 14
|
||||
assert table.columns[bar].content_width == 6
|
||||
|
||||
|
||||
def test_get_cell_value_returns_value_at_cell():
|
||||
|
||||
Reference in New Issue
Block a user