diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 7e9d86416..d17fd6fe9 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -214,7 +214,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): ) -> None: super().__init__(name=name, id=id, classes=classes) - self.columns: list[Column] = [] + self.columns: dict[ColumnKey, Column] = {} self.rows: dict[RowKey, Row] = {} self.data: dict[RowKey, list[CellType]] = {} @@ -224,7 +224,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict({}) # Maps y-coordinate (from top of table) to (row_index, y-coord within row) pairs - # TODO: Update types self._y_offsets: list[tuple[RowKey, int]] = [] self._row_render_cache: LRUCache[ tuple[RowKey, int, Style, int, int], tuple[SegmentLines, SegmentLines] @@ -421,13 +420,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): """Called to recalculate the virtual (scrollable) size.""" for row_index in new_rows: for column, renderable in zip( - self.columns, self._get_row_renderables(row_index) + self.columns.values(), self._get_row_renderables(row_index) ): content_width = measure(self.app.console, renderable, 1) column.content_width = max(column.content_width, content_width) self._clear_caches() - total_width = sum(column.render_width for column in self.columns) + total_width = sum(column.render_width for column in self.columns.values()) header_height = self.header_height if self.show_header else 0 self.virtual_size = Size( total_width, @@ -435,18 +434,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): ) def _get_cell_region(self, row_index: int, column_index: int) -> Region: - """Get the region of the cell at the given coordinate (row_index, column_index)""" - # This IS used to get the cell region under given a cursor coordinate. - # So we don't want to change this to the key approach, but of course we - # need to look up the row_key first now before proceeding. - # TODO: This is pre-existing method, we'll simply map the indices - # over to the row_keys for now, and likely provide a new means of - row_key = self._row_locations.get_key(row_index) - if row_key not in self.rows: + """Get the region of the cell at the given spatial coordinate.""" + valid_row = 0 <= row_index < len(self.rows) + valid_column = 0 <= column_index < len(self.columns) + valid_cell = valid_row and valid_column + if not valid_cell: return Region(0, 0, 0, 0) + + row_key = self._row_locations.get_key(row_index) row = self.rows[row_key] - x = sum(column.render_width for column in self.columns[:column_index]) - width = self.columns[column_index].render_width + + # The x-coordinate of a cell is the sum of widths of cells to the left. + x = sum(column.render_width for column in self._ordered_columns[:column_index]) + column_key = self._column_locations.get_key(column_index) + width = self.columns[column_key].render_width height = row.height y = row.y if self.show_header: @@ -457,12 +458,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_row_region(self, row_index: int) -> Region: """Get the region of the row at the given index.""" rows = self.rows - if row_index < 0 or row_index >= len(rows): + valid_row = 0 <= row_index < len(rows) + if not valid_row: return Region(0, 0, 0, 0) row_key = self._row_locations.get_key(row_index) row = rows[row_key] - row_width = sum(column.render_width for column in self.columns) + row_width = sum(column.render_width for column in self.columns.values()) y = row.y if self.show_header: y += self.header_height @@ -472,11 +474,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_column_region(self, column_index: int) -> Region: """Get the region of the column at the given index.""" columns = self.columns - if column_index < 0 or column_index >= len(columns): + valid_column = 0 <= column_index < len(columns) + if not valid_column: return Region(0, 0, 0, 0) - x = sum(column.render_width for column in self.columns[:column_index]) - width = columns[column_index].render_width + x = sum(column.render_width for column in self._ordered_columns[:column_index]) + column_key = self._column_locations.get_key(column_index) + width = columns[column_key].render_width header_height = self.header_height if self.show_header else 0 height = len(self._y_offsets) + header_height full_column_region = Region(x, 0, width, height) @@ -518,13 +522,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): text_label = Text.from_markup(label) if isinstance(label, str) else label column_key = ColumnKey(key) + column_index = len(self.columns) content_width = measure(self.app.console, text_label, 1) if width is None: column = Column( column_key, text_label, content_width, - index=len(self.columns), + index=column_index, content_width=content_width, auto_width=True, ) @@ -534,10 +539,11 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): text_label, width, content_width=content_width, - index=len(self.columns), + index=column_index, ) - self.columns.append(column) + self.columns[column_key] = column + self._column_locations[column_key] = column_index self._require_update_dimensions = True self.check_idle() @@ -675,6 +681,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): region = region.translate(-self.scroll_offset) self.refresh(region) + @property + def _ordered_columns(self) -> list[Column]: + column_indices = range(len(self.columns)) + column_keys = [ + self._column_locations.get_key(index) for index in column_indices + ] + ordered_columns = [self.columns.get(key) for key in column_keys] + return ordered_columns + def _get_row_renderables(self, row_index: int) -> list[RenderableType]: """Get renderables for the given row. @@ -686,18 +701,18 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): """ if row_index == -1: - row = [column.label for column in self.columns] + row = [column.label for column in self._ordered_columns] return row row_key = self._row_locations.get_key(row_index) - data = self.data.get(row_key) + row = self.data.get(row_key) empty = Text() - if data is None: + if row is None: return [empty for _ in self.columns] else: return [ Text() if datum is None else default_cell_formatter(datum) or empty - for datum, _ in zip_longest(data, range(len(self.columns))) + for datum, _ in zip_longest(row, range(len(self.columns))) ] def _render_cell( @@ -761,7 +776,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _render_line_in_row( self, - row_index: int, + row_key: RowKey, line_no: int, base_style: Style, cursor_location: Coordinate, @@ -770,7 +785,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): """Render a row in to lines for each cell. Args: - row_index: Index of the row. + row_key: The identifying key for this row. line_no: Line number (y-coordinate) within row. 0 is the first strip of cells in the row, line_no=1 is the next, and so on... base_style: Base style of row. @@ -784,7 +799,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): show_cursor = self.show_cursor cache_key = ( - row_index, + row_key, line_no, base_style, cursor_location, @@ -821,11 +836,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): else: return False + row_index = self._row_locations.get(row_key, -1) if self.fixed_columns: fixed_style = self.get_component_styles("datatable--fixed").rich_style fixed_style += Style.from_meta({"fixed": True}) fixed_row = [] - for column in self.columns[: self.fixed_columns]: + for column in self._ordered_columns[: self.fixed_columns]: cell_location = Coordinate(row_index, column.index) fixed_cell_lines = render_cell( row_index, @@ -841,7 +857,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): else: fixed_row = [] - if row_index == -1: + if row_key is None: row_style = self.get_component_styles("datatable--header").rich_style else: if self.zebra_stripes: @@ -853,7 +869,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): row_style = base_style scrollable_row = [] - for column in self.columns: + for column in self.columns.values(): cell_location = Coordinate(row_index, column.index) cell_lines = render_cell( row_index, @@ -869,27 +885,25 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._row_render_cache[cache_key] = row_pair return row_pair - def _get_offsets(self, y: int) -> tuple[int, int]: - """Get row number and line offset for a given line. + def _get_offsets(self, y: int) -> tuple[RowKey | None, int]: + """Get row key and line offset for a given line. Args: y: Y coordinate relative to DataTable top. Returns: - Line number and line offset within cell. + Row key and line (y) offset within cell. """ header_height = self.header_height y_offsets = self._y_offsets if self.show_header: if y < header_height: - return -1, y + return None, y y -= header_height if y > len(y_offsets): raise LookupError("Y coord {y!r} is greater than total height") - row_key, y_offset_in_row = y_offsets[y] - row_index = self._row_locations.get(row_key) - return row_index, y_offset_in_row + return y_offsets[y] def _render_line(self, y: int, x1: int, x2: int, base_style: Style) -> Strip: """Render a line in to a list of segments. @@ -907,7 +921,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): width = self.size.width try: - row_index, line_no = self._get_offsets(y) + row_key, y_offset_in_row = self._get_offsets(y) except LookupError: return Strip.blank(width, base_style) @@ -927,14 +941,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): return self._line_cache[cache_key] fixed, scrollable = self._render_line_in_row( - row_index, - line_no, + row_key, + y_offset_in_row, base_style, cursor_location=self.cursor_cell, hover_location=self.hover_cell, ) fixed_width = sum( - column.render_width for column in self.columns[: self.fixed_columns] + column.render_width + for column in self._ordered_columns[: self.fixed_columns] ) fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else [] @@ -982,7 +997,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): for row_index in range(self.fixed_rows) if row_index in self.rows ) - left = sum(column.render_width for column in self.columns[: self.fixed_columns]) + left = sum( + column.render_width + for column in self._ordered_columns[: self.fixed_columns] + ) return Spacing(top, 0, 0, left) def _scroll_cursor_into_view(self, animate: bool = False) -> None: diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 9092c7909..9bac2cf8b 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -58,6 +58,7 @@ async def test_datatable_message_emission(): # therefore no highlighted cells), but then a row was added, and # so the cell at (0, 0) became highlighted. expected_messages.append("CellHighlighted") + await pilot.pause(2 / 100) assert messages == expected_messages # Pressing Enter when the cursor is on a cell emits a CellSelected @@ -74,6 +75,7 @@ async def test_datatable_message_emission(): # Switch over to the row cursor... should emit a `RowHighlighted` table.cursor_type = "row" expected_messages.append("RowHighlighted") + await pilot.pause(2 / 100) assert messages == expected_messages # Select the row... @@ -85,6 +87,7 @@ async def test_datatable_message_emission(): # Switching to the column cursor emits a `ColumnHighlighted` table.cursor_type = "column" expected_messages.append("ColumnHighlighted") + await pilot.pause(2 / 100) assert messages == expected_messages # Select the column... @@ -112,6 +115,7 @@ async def test_datatable_message_emission(): # message should be emitted for highlighting the cell. table.show_cursor = True expected_messages.append("CellHighlighted") + await pilot.pause(2 / 100) assert messages == expected_messages # Likewise, if the cursor_type is "none", and we change the @@ -213,26 +217,28 @@ async def test_column_labels() -> None: async with app.run_test(): table = app.query_one(DataTable) table.add_columns("1", "2", "3") - assert [col.label for col in table.columns] == [Text("1"), Text("2"), Text("3")] + actual_labels = [col.label for col in table.columns.values()] + expected_labels = [Text("1"), Text("2"), Text("3")] + assert actual_labels == expected_labels async def test_column_widths() -> None: app = DataTableApp() async with app.run_test() as pilot: table = app.query_one(DataTable) - table.add_columns("foo", "bar") + foo, bar = table.add_columns("foo", "bar") - assert table.columns[0].width == 3 - assert table.columns[1].width == 3 + assert table.columns[foo].width == 3 + assert table.columns[bar].width == 3 table.add_row("Hello", "World!") await pilot.pause() - assert table.columns[0].content_width == 5 - assert table.columns[1].content_width == 6 + assert table.columns[foo].content_width == 5 + assert table.columns[bar].content_width == 6 table.add_row("Hello World!!!", "fo") await pilot.pause() - assert table.columns[0].content_width == 14 - assert table.columns[1].content_width == 6 + assert table.columns[foo].content_width == 14 + assert table.columns[bar].content_width == 6 def test_get_cell_value_returns_value_at_cell():