From 95b52eef0dffcdf6a94d81cd940e259f6bcc2891 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Tue, 31 Jan 2023 13:03:36 +0000 Subject: [PATCH 01/29] Refresh column widths on idle --- src/textual/widgets/_data_table.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 799403c85..ae912fbbd 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -258,7 +258,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._require_update_dimensions: bool = False self._new_rows: set[RowKey] = set() - self._updated_columns: set[ColumnKey] = set() + self._updated_cells: set[CellKey] = set() """Track which cells were updated, so that we can refresh them once on idle""" self.show_header = show_header @@ -325,19 +325,18 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # Recalculate widths if necessary if update_width: - self._updated_columns.add(column_key) + self._updated_cells.add(CellKey(row_key, column_key)) + self._require_update_dimensions = True - self._require_update_dimensions = True self.refresh() - def update_coordinate(self, coordinate: Coordinate, value: CellType) -> None: + def update_coordinate( + self, coordinate: Coordinate, value: CellType, *, update_width: bool = False + ) -> None: row, column = coordinate row_key = self._row_locations.get_key(row) column_key = self._column_locations.get_key(column) - value = Text.from_markup(value) if isinstance(value, str) else value - self.data[row_key][column_key] = value - self._update_count += 1 - self.refresh_cell(row, column) + self.update_cell(row_key, column_key, value, update_width=update_width) def _get_cells_in_column(self, column_key: ColumnKey) -> Iterable[CellType]: """For a given column key, return the cells in that column in order""" @@ -488,13 +487,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): elif cursor_type == "column": self._highlight_column(column_index) - def _update_column_widths(self, column_keys: set[ColumnKey]) -> None: - for column_key in column_keys: + def _update_column_widths(self, updated_cells: set[CellKey]) -> None: + for row_key, column_key in updated_cells: column = self.columns.get(column_key) console = self.app.console label_width = measure(console, column.label, 1) content_width = column.content_width - new_content_width = measure(console, value, 1) + cell_value = self.data[row_key][column_key] + new_content_width = measure(console, cell_value, 1) if new_content_width < content_width: cells_in_column = self._get_cells_in_column(column_key) @@ -724,9 +724,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # Add the new rows *before* updating the column widths, since # cells in a new row may influence the final width of a column self._update_dimensions(new_rows) - if self._updated_columns: - updated_columns = self._updated_columns.copy() - self._updated_columns.clear() + if self._updated_cells: + updated_columns = self._updated_cells.copy() + self._updated_cells.clear() self._update_column_widths(updated_columns) self.refresh() From 8f928f4b768b0cfaf429706f93c20b121d13e1c0 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Tue, 31 Jan 2023 13:05:02 +0000 Subject: [PATCH 02/29] Import optimising --- src/textual/widgets/_data_table.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 88aeac5e6..2ac8b1b20 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -13,8 +13,6 @@ from typing import ( NamedTuple, Callable, Sequence, - Type, - Optional, ) import rich.repr @@ -30,9 +28,8 @@ from .._cache import LRUCache from .._segment_tools import line_crop from .._two_way_dict import TwoWayDict from .._types import SegmentLines -from .._typing import Literal, TypeAlias -from ..binding import Binding from .._typing import Literal +from .._typing import TypeAlias from ..binding import Binding, BindingType from ..coordinate import Coordinate from ..geometry import Region, Size, Spacing, clamp From 48488e7402b886507ac4c30ea5e19c7cfa0c6e69 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Tue, 31 Jan 2023 13:29:18 +0000 Subject: [PATCH 03/29] Add cell_key to CellHighlighted event --- src/textual/widgets/_data_table.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 2ac8b1b20..e23ff887f 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -255,13 +255,19 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Attributes: value: The value in the highlighted cell. coordinate: The coordinate of the highlighted cell. + cell_key: The key for the highlighted cell. """ def __init__( - self, sender: DataTable, value: CellType, coordinate: Coordinate + self, + sender: DataTable, + value: CellType, + coordinate: Coordinate, + cell_key: CellKey, ) -> None: self.value: CellType = value self.coordinate: Coordinate = coordinate + self.cell_key: CellKey = cell_key super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: @@ -576,7 +582,26 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # In that case, there's nothing for us to do here. return else: - self.emit_no_wait(DataTable.CellHighlighted(self, cell_value, coordinate)) + cell_key = self.coordinate_to_cell_key(coordinate) + self.emit_no_wait( + DataTable.CellHighlighted( + self, cell_value, coordinate=coordinate, cell_key=cell_key + ) + ) + + def coordinate_to_cell_key(self, coordinate: Coordinate) -> CellKey: + """Return the key for the cell currently occupying this coordinate in the DataTable + + Args: + coordinate: The coordinate to exam the current cell key of. + + Returns: + The key of the cell currently occupying this coordinate. + """ + row_index, column_index = coordinate + row_key = self._row_locations.get_key(row_index) + column_key = self._column_locations.get_key(column_index) + return CellKey(row_key, column_key) def _highlight_row(self, row_index: int) -> None: """Apply highlighting to the row at the given index, and emit event.""" From abd35436fb94982a280381ad27319a6c41f40832 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Tue, 31 Jan 2023 13:34:13 +0000 Subject: [PATCH 04/29] Some refactoring, and add cell_key to DataTable.CellSelected --- src/textual/widgets/_data_table.py | 81 ++++++++++++++++-------------- tests/test_data_table.py | 16 +++--- 2 files changed, 51 insertions(+), 46 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index e23ff887f..9847be8a8 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -239,10 +239,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): show_cursor = Reactive(True) cursor_type = Reactive(CELL) - cursor_cell: Reactive[Coordinate] = Reactive( + cursor_coordinate: Reactive[Coordinate] = Reactive( Coordinate(0, 0), repaint=False, always_update=True ) - hover_cell: Reactive[Coordinate] = Reactive(Coordinate(0, 0), repaint=False) + hover_coordinate: Reactive[Coordinate] = Reactive(Coordinate(0, 0), repaint=False) class CellHighlighted(Message, bubble=True): """Emitted when the cursor moves to highlight a new cell. @@ -274,6 +274,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): yield "sender", self.sender yield "value", self.value yield "coordinate", self.coordinate + yield "cell_key", self.cell_key class CellSelected(Message, bubble=True): """Emitted by the `DataTable` widget when a cell is selected. @@ -284,19 +285,26 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Attributes: value: The value in the cell that was selected. coordinate: The coordinate of the cell that was selected. + cell_key: The key for the selected cell. """ def __init__( - self, sender: DataTable, value: CellType, coordinate: Coordinate + self, + sender: DataTable, + value: CellType, + coordinate: Coordinate, + cell_key: CellKey, ) -> None: self.value: CellType = value self.coordinate: Coordinate = coordinate + self.cell_key: CellKey = cell_key super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "value", self.value yield "coordinate", self.coordinate + yield "cell_key", self.cell_key class RowHighlighted(Message, bubble=True): """Emitted when a row is highlighted. This message is only emitted when the @@ -420,19 +428,19 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): @property def hover_row(self) -> int: - return self.hover_cell.row + return self.hover_coordinate.row @property def hover_column(self) -> int: - return self.hover_cell.column + return self.hover_coordinate.column @property def cursor_row(self) -> int: - return self.cursor_cell.row + return self.cursor_coordinate.row @property def cursor_column(self) -> int: - return self.cursor_cell.column + return self.cursor_coordinate.column @property def row_count(self) -> int: @@ -492,8 +500,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): row = self.data.get(row_key) yield row.get(column_key) - def get_cell_value(self, coordinate: Coordinate) -> CellType: - """Get the value from the cell at the given coordinate. + def get_value_at(self, coordinate: Coordinate) -> CellType: + """Get the value from the cell occupying the given coordinate. Args: coordinate: The coordinate to retrieve the value from. @@ -504,11 +512,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Raises: CellDoesNotExist: If there is no cell with the given coordinate. """ - # TODO: Rename to get_value_at()? - # We need to clearly distinguish between coordinates and cell keys - row_index, column_index = coordinate - row_key = self._row_locations.get_key(row_index) - column_key = self._column_locations.get_key(column_index) + row_key, column_key = self.coordinate_to_cell_key(coordinate) try: cell_value = self.data[row_key][column_key] except KeyError: @@ -537,7 +541,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # emit the appropriate [Row|Column|Cell]Highlighted event. self._scroll_cursor_into_view(animate=False) if self.cursor_type == "cell": - self._highlight_cell(self.cursor_cell) + self._highlight_cell(self.cursor_coordinate) elif self.cursor_type == "row": self._highlight_row(self.cursor_row) elif self.cursor_type == "column": @@ -576,7 +580,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): """Apply highlighting to the cell at the coordinate, and emit event.""" self.refresh_cell(*coordinate) try: - cell_value = self.get_cell_value(coordinate) + cell_value = self.get_value_at(coordinate) except CellDoesNotExist: # The cell may not exist e.g. when the table is cleared. # In that case, there's nothing for us to do here. @@ -632,7 +636,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # Refresh cells that were previously impacted by the cursor # but may no longer be. - row_index, column_index = self.cursor_cell + row_index, column_index = self.cursor_coordinate if old == "cell": self.refresh_cell(row_index, column_index) elif old == "row": @@ -643,11 +647,11 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._scroll_cursor_into_view() def _highlight_cursor(self) -> None: - row_index, column_index = self.cursor_cell + row_index, column_index = self.cursor_coordinate cursor_type = self.cursor_type # Apply the highlighting to the newly relevant cells if cursor_type == "cell": - self._highlight_cell(self.cursor_cell) + self._highlight_cell(self.cursor_coordinate) elif cursor_type == "row": self._highlight_row(row_index) elif cursor_type == "column": @@ -756,8 +760,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.columns.clear() self._line_no = 0 self._require_update_dimensions = True - self.cursor_cell = Coordinate(0, 0) - self.hover_cell = Coordinate(0, 0) + self.cursor_coordinate = Coordinate(0, 0) + self.hover_coordinate = Coordinate(0, 0) self.refresh() def add_column( @@ -834,7 +838,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.rows[row_key] = Row(row_key, height) self._new_rows.add(row_key) self._require_update_dimensions = True - self.cursor_cell = self.cursor_cell + self.cursor_coordinate = self.cursor_coordinate # If a position has opened for the cursor to appear, where it previously # could not (e.g. when there's no data in the table), then a highlighted @@ -1212,8 +1216,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): x1, x2, width, - self.cursor_cell, - self.hover_cell, + self.cursor_coordinate, + self.hover_coordinate, base_style, self.cursor_type, self._show_hover_cursor, @@ -1226,8 +1230,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): row_key, y_offset_in_row, base_style, - cursor_location=self.cursor_cell, - hover_location=self.hover_cell, + cursor_location=self.cursor_coordinate, + hover_location=self.hover_coordinate, ) fixed_width = sum( column.render_width @@ -1268,7 +1272,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): meta = event.style.meta if meta and self.show_cursor and self.cursor_type != "none": try: - self.hover_cell = Coordinate(meta["row"], meta["column"]) + self.hover_coordinate = Coordinate(meta["row"], meta["column"]) except KeyError: pass @@ -1351,7 +1355,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): elif cursor_type == "row": self.refresh_row(self.hover_row) elif cursor_type == "cell": - self.refresh_cell(*self.hover_cell) + self.refresh_cell(*self.hover_coordinate) def on_click(self, event: events.Click) -> None: self._set_hover_cursor(True) @@ -1360,7 +1364,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._emit_selected_message() meta = self.get_style_at(event.x, event.y).meta if meta: - self.cursor_cell = Coordinate(meta["row"], meta["column"]) + self.cursor_coordinate = Coordinate(meta["row"], meta["column"]) self._scroll_cursor_into_view(animate=True) event.stop() @@ -1368,7 +1372,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._set_hover_cursor(False) cursor_type = self.cursor_type if self.show_cursor and (cursor_type == "cell" or cursor_type == "row"): - self.cursor_cell = self.cursor_cell.up() + self.cursor_coordinate = self.cursor_coordinate.up() self._scroll_cursor_into_view() else: # If the cursor doesn't move up (e.g. column cursor can't go up), @@ -1379,7 +1383,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._set_hover_cursor(False) cursor_type = self.cursor_type if self.show_cursor and (cursor_type == "cell" or cursor_type == "row"): - self.cursor_cell = self.cursor_cell.down() + self.cursor_coordinate = self.cursor_coordinate.down() self._scroll_cursor_into_view() else: super().action_scroll_down() @@ -1388,7 +1392,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._set_hover_cursor(False) cursor_type = self.cursor_type if self.show_cursor and (cursor_type == "cell" or cursor_type == "column"): - self.cursor_cell = self.cursor_cell.right() + self.cursor_coordinate = self.cursor_coordinate.right() self._scroll_cursor_into_view(animate=True) else: super().action_scroll_right() @@ -1397,7 +1401,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._set_hover_cursor(False) cursor_type = self.cursor_type if self.show_cursor and (cursor_type == "cell" or cursor_type == "column"): - self.cursor_cell = self.cursor_cell.left() + self.cursor_coordinate = self.cursor_coordinate.left() self._scroll_cursor_into_view(animate=True) else: super().action_scroll_left() @@ -1409,19 +1413,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _emit_selected_message(self): """Emit the appropriate message for a selection based on the `cursor_type`.""" - cursor_cell = self.cursor_cell + cursor_coordinate = self.cursor_coordinate cursor_type = self.cursor_type if cursor_type == "cell": self.emit_no_wait( DataTable.CellSelected( self, - self.get_cell_value(cursor_cell), - cursor_cell, + self.get_value_at(cursor_coordinate), + coordinate=cursor_coordinate, + cell_key=self.coordinate_to_cell_key(cursor_coordinate), ) ) elif cursor_type == "row": - row, _ = cursor_cell + row, _ = cursor_coordinate self.emit_no_wait(DataTable.RowSelected(self, row)) elif cursor_type == "column": - _, column = cursor_cell + _, column = cursor_coordinate self.emit_no_wait(DataTable.ColumnSelected(self, column)) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 1c53f9b6a..28de49c75 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -193,19 +193,19 @@ async def test_clear(): app = DataTableApp() async with app.run_test(): table = app.query_one(DataTable) - assert table.cursor_cell == Coordinate(0, 0) - assert table.hover_cell == Coordinate(0, 0) + assert table.cursor_coordinate == Coordinate(0, 0) + assert table.hover_coordinate == Coordinate(0, 0) # Add some data and update cursor positions table.add_column("Column0") table.add_rows([["Row0"], ["Row1"], ["Row2"]]) - table.cursor_cell = Coordinate(1, 0) - table.hover_cell = Coordinate(2, 0) + table.cursor_coordinate = Coordinate(1, 0) + table.hover_coordinate = Coordinate(2, 0) # Ensure the cursor positions are reset to origin on clear() table.clear() - assert table.cursor_cell == Coordinate(0, 0) - assert table.hover_cell == Coordinate(0, 0) + assert table.cursor_coordinate == Coordinate(0, 0) + assert table.hover_coordinate == Coordinate(0, 0) # Ensure that the table has been cleared assert table.data == {} @@ -253,7 +253,7 @@ async def test_get_cell_value_returns_value_at_cell(): table = app.query_one(DataTable) table.add_columns("A", "B") table.add_rows(ROWS) - assert table.get_cell_value(Coordinate(0, 0)) == Text("0/0") + assert table.get_value_at(Coordinate(0, 0)) == Text("0/0") async def test_get_cell_value_exception(): @@ -263,7 +263,7 @@ async def test_get_cell_value_exception(): table.add_columns("A", "B") table.add_rows(ROWS) with pytest.raises(CellDoesNotExist): - table.get_cell_value(Coordinate(9999, 0)) + table.get_value_at(Coordinate(9999, 0)) def test_key_equals_equivalent_string(): From e02ef1e22ce0ea998728ce3035c916220b3a1517 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Tue, 31 Jan 2023 13:42:53 +0000 Subject: [PATCH 05/29] Update watcher/validator names in DataTable --- src/textual/widgets/_data_table.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 9847be8a8..9dbebc892 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -556,11 +556,11 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def watch_zebra_stripes(self, zebra_stripes: bool) -> None: self._clear_caches() - def watch_hover_cell(self, old: Coordinate, value: Coordinate) -> None: + def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None: self.refresh_cell(*old) self.refresh_cell(*value) - def watch_cursor_cell( + def watch_cursor_coordinate( self, old_coordinate: Coordinate, new_coordinate: Coordinate ) -> None: if old_coordinate != new_coordinate: @@ -620,11 +620,11 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): if column_index < len(self.columns): self.emit_no_wait(DataTable.ColumnHighlighted(self, column_index)) - def validate_cursor_cell(self, value: Coordinate) -> Coordinate: - return self._clamp_cursor_cell(value) + def validate_cursor_coordinate(self, value: Coordinate) -> Coordinate: + return self._clamp_cursor_coordinate(value) - def _clamp_cursor_cell(self, cursor_cell: Coordinate) -> Coordinate: - row, column = cursor_cell + def _clamp_cursor_coordinate(self, coordinate: Coordinate) -> Coordinate: + row, column = coordinate row = clamp(row, 0, self.row_count - 1) column = clamp(column, self.fixed_columns, len(self.columns) - 1) return Coordinate(row, column) From f97cdd679747633732a3805c50992eb2b25fa3cf Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Tue, 31 Jan 2023 15:18:33 +0000 Subject: [PATCH 06/29] Remove redundant attribute. Add more DataTable docstrings. --- src/textual/widgets/_data_table.py | 64 ++++++++++++++++++++---------- 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 9dbebc892..88315f855 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -392,16 +392,24 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): ) -> None: super().__init__(name=name, id=id, classes=classes) self.data: dict[RowKey, dict[ColumnKey, CellType]] = {} + """Contains the cells of the table, indexed by row key and column key. + The final positioning of a cell on screen cannot be determined solely by this + structure. Instead, we must check _row_locations and _column_locations to find + where each cell currently resides in space.""" - # Metadata on rows and columns in the table self.columns: dict[ColumnKey, Column] = {} + """Metadata about the columns of the table, indexed by their key.""" self.rows: dict[RowKey, Row] = {} + """Metadata about the rows of the table, indexed by their key.""" # Keep tracking of key -> index for rows/cols. These allow us to retrieve, # given a row or column key, the index that row or column is currently present at, - # and mean that rows and columns are location independent - they can move around. - self._column_locations: TwoWayDict[ColumnKey, int] = TwoWayDict({}) + # and mean that rows and columns are location independent - they can move around + # without requiring us to modify the underlying data. self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict({}) + """Maps row keys to row indices which represent row order.""" + self._column_locations: TwoWayDict[ColumnKey, int] = TwoWayDict({}) + """Maps column keys to column indices which represent column order.""" self._row_render_cache: LRUCache[ RowCacheKey, tuple[SegmentLines, SegmentLines] @@ -409,22 +417,31 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._cell_render_cache: LRUCache[CellCacheKey, SegmentLines] = LRUCache(10000) self._line_cache: LRUCache[LineCacheKey, Strip] = LRUCache(1000) - self._line_no = 0 self._require_update_dimensions: bool = False - + """Set to re-calculate dimensions on idle.""" self._new_rows: set[RowKey] = set() + """Tracking newly added rows to be used in re-calculation of dimensions on idle.""" self._updated_cells: set[CellKey] = set() - """Track which cells were updated, so that we can refresh them once on idle""" + """Track which cells were updated, so that we can refresh them once on idle.""" self.show_header = show_header - self.fixed_rows = fixed_rows - self.fixed_columns = fixed_columns - self.zebra_stripes = zebra_stripes + """Show/hide the header row (the row of column labels).""" self.header_height = header_height + """The height of the header row (the row of column labels).""" + self.fixed_rows = fixed_rows + """The number of rows to fix (prevented from scrolling).""" + self.fixed_columns = fixed_columns + """The number of columns to fix (prevented from scrolling).""" + self.zebra_stripes = zebra_stripes + """Apply zebra-stripe effect on row backgrounds (light, dark, light, dark, ...).""" self.show_cursor = show_cursor + """Show/hide both the keyboard and hover cursor.""" self._show_hover_cursor = False + """Used to hide the mouse hover cursor when the user uses the keyboard.""" self._update_count = 0 + """The number of update operations that have occurred. Used for cache invalidation.""" self._header_row_key = RowKey() + """The header is a special row which is not part of the data. This key is used to retrieve it.""" @property def hover_row(self) -> int: @@ -488,9 +505,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def update_coordinate( self, coordinate: Coordinate, value: CellType, *, update_width: bool = False ) -> None: - row, column = coordinate - row_key = self._row_locations.get_key(row) - column_key = self._column_locations.get_key(column) + """Update the content inside the cell currently occupying the given coordinate. + + Args: + coordinate: The coordinate to update the cell at. + value: The new value to place inside the cell. + update_width: Whether to resize the column width to accommodate + for the new cell content. + """ + row_key, column_key = self.coordinate_to_cell_key(coordinate) self.update_cell(row_key, column_key, value, update_width=update_width) def _get_cells_in_column(self, column_key: ColumnKey) -> Iterable[CellType]: @@ -758,7 +781,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.rows.clear() if columns: self.columns.clear() - self._line_no = 0 self._require_update_dimensions = True self.cursor_coordinate = Coordinate(0, 0) self.hover_coordinate = Coordinate(0, 0) @@ -887,18 +909,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): return row_keys def on_idle(self) -> None: + # Add the new rows *before* updating the column widths, since + # cells in a new row may influence the final width of a column if self._require_update_dimensions: self._require_update_dimensions = False new_rows = self._new_rows.copy() self._new_rows.clear() - # Add the new rows *before* updating the column widths, since - # cells in a new row may influence the final width of a column self._update_dimensions(new_rows) - if self._updated_cells: - updated_columns = self._updated_cells.copy() - self._updated_cells.clear() - self._update_column_widths(updated_columns) - self.refresh() + + if self._updated_cells: + # Cell contents have already been updated at this point. + # Now we only need to worry about measuring column widths. + updated_columns = self._updated_cells.copy() + self._updated_cells.clear() + self._update_column_widths(updated_columns) def refresh_cell(self, row_index: int, column_index: int) -> None: """Refresh a cell. From 25abe4dbdf1e35e18ee4bdfa6cac1056492bb930 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Tue, 31 Jan 2023 15:30:23 +0000 Subject: [PATCH 07/29] Expose ordered rows and ordered columns publically --- src/textual/widgets/_data_table.py | 32 ++++++++++++++---------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 88315f855..8b4a91477 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -466,7 +466,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): @property def _y_offsets(self) -> list[tuple[RowKey, int]]: y_offsets: list[tuple[RowKey, int]] = [] - for row in self._ordered_rows: + for row in self.ordered_rows: row_key = row.key row_height = row.height y_offsets += [(row_key, y) for y in range(row_height)] @@ -518,7 +518,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_cells_in_column(self, column_key: ColumnKey) -> Iterable[CellType]: """For a given column key, return the cells in that column in order""" - for row_metadata in self._ordered_rows: + for row_metadata in self.ordered_rows: row_key = row_metadata.key row = self.data.get(row_key) yield row.get(column_key) @@ -703,7 +703,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): if row_index is None: continue for column, renderable in zip( - self._ordered_columns, self._get_row_renderables(row_index) + self.ordered_columns, self._get_row_renderables(row_index) ): content_width = measure(self.app.console, renderable, 1) column.content_width = max(column.content_width, content_width) @@ -728,11 +728,11 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): row = self.rows[row_key] # The x-coordinate of a cell is the sum of widths of cells to the left. - x = sum(column.render_width for column in self._ordered_columns[:column_index]) + x = sum(column.render_width for column in self.ordered_columns[:column_index]) column_key = self._column_locations.get_key(column_index) width = self.columns[column_key].render_width height = row.height - y = sum(ordered_row.height for ordered_row in self._ordered_rows[:row_index]) + y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index]) if self.show_header: y += self.header_height cell_region = Region(x, y, width, height) @@ -748,7 +748,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): row_key = self._row_locations.get_key(row_index) row = rows[row_key] row_width = sum(column.render_width for column in self.columns.values()) - y = sum(ordered_row.height for ordered_row in self._ordered_rows[:row_index]) + y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index]) if self.show_header: y += self.header_height row_region = Region(0, y, row_width, row.height) @@ -761,7 +761,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): if not valid_column: return Region(0, 0, 0, 0) - x = sum(column.render_width for column in self._ordered_columns[:column_index]) + x = sum(column.render_width for column in self.ordered_columns[:column_index]) column_key = self._column_locations.get_key(column_index) width = columns[column_key].render_width header_height = self.header_height if self.show_header else 0 @@ -855,7 +855,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._row_locations[row_key] = row_index self.data[row_key] = { column.key: Text(cell) if isinstance(cell, str) else cell - for column, cell in zip_longest(self._ordered_columns, cells) + for column, cell in zip_longest(self.ordered_columns, cells) } self.rows[row_key] = Row(row_key, height) self._new_rows.add(row_key) @@ -970,7 +970,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.refresh(region) @property - def _ordered_columns(self) -> list[Column]: + def ordered_columns(self) -> list[Column]: column_indices = range(len(self.columns)) column_keys = [ self._column_locations.get_key(index) for index in column_indices @@ -979,7 +979,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): return ordered_columns @property - def _ordered_rows(self) -> list[Row]: + def ordered_rows(self) -> list[Row]: row_indices = range(self.row_count) ordered_rows = [] for row_index in row_indices: @@ -999,7 +999,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): """ # TODO: We have quite a few back and forward key/index conversions, could probably reduce them - ordered_columns = self._ordered_columns + ordered_columns = self.ordered_columns if row_index == -1: row = [column.label for column in ordered_columns] return row @@ -1150,7 +1150,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): fixed_style += Style.from_meta({"fixed": True}) fixed_row = [] for column_index, column in enumerate( - self._ordered_columns[: self.fixed_columns] + self.ordered_columns[: self.fixed_columns] ): cell_location = Coordinate(row_index, column_index) fixed_cell_lines = render_cell( @@ -1179,7 +1179,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): row_style = base_style scrollable_row = [] - for column_index, column in enumerate(self._ordered_columns): + for column_index, column in enumerate(self.ordered_columns): cell_location = Coordinate(row_index, column_index) cell_lines = render_cell( row_index, @@ -1258,8 +1258,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): hover_location=self.hover_coordinate, ) fixed_width = sum( - column.render_width - for column in self._ordered_columns[: self.fixed_columns] + column.render_width for column in self.ordered_columns[: self.fixed_columns] ) fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else [] @@ -1308,8 +1307,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): if row_index in self.rows ) left = sum( - column.render_width - for column in self._ordered_columns[: self.fixed_columns] + column.render_width for column in self.ordered_columns[: self.fixed_columns] ) return Spacing(top, 0, 0, left) From 0b2b7a964628b58056c46c33a91c116a300ca35e Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Tue, 31 Jan 2023 16:43:33 +0000 Subject: [PATCH 08/29] Docstring improvements --- src/textual/widgets/_data_table.py | 31 ++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 8b4a91477..e393545ee 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -530,16 +530,30 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): coordinate: The coordinate to retrieve the value from. Returns: - The value of the cell. + The value of the cell at the coordinate. Raises: CellDoesNotExist: If there is no cell with the given coordinate. """ row_key, column_key = self.coordinate_to_cell_key(coordinate) + return self.get_cell_value(row_key, column_key) + + def get_cell_value(self, row_key: RowKey, column_key: ColumnKey) -> CellType: + """Given a row key and column key, return the value of the corresponding cell. + + Args: + row_key: The row key of the cell. + column_key: The column key of the cell. + + Returns: + The value of the cell identified by the row and column keys. + """ try: cell_value = self.data[row_key][column_key] except KeyError: - raise CellDoesNotExist(f"No cell exists at {coordinate!r}") from None + raise CellDoesNotExist( + f"No cell exists for row_key={row_key!r}, column_key={column_key!r}." + ) return cell_value def _clear_caches(self) -> None: @@ -549,6 +563,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._styles_cache.clear() def get_row_height(self, row_key: RowKey) -> int: + """Given a row key, return the height of that row in terminal cells. + + Args: + row_key: The key of the row. + + Returns: + The height of the row, measured in terminal character cells. + """ if row_key is self._header_row_key: return self.header_height return self.rows[row_key].height @@ -1216,7 +1238,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): return y_offsets[y] def _render_line(self, y: int, x1: int, x2: int, base_style: Style) -> Strip: - """Render a line in to a list of segments. + """Render a (possibly cropped) line in to a Strip (a list of segments + representing a horizontal line). Args: y: Y coordinate of line @@ -1225,7 +1248,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): base_style: Style to apply to line. Returns: - List of segments for rendering. + The Strip which represents this cropped line. """ width = self.size.width From 655b2b3ea7d4a932d747f999e4e3d67ae64704c5 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 10:57:03 +0000 Subject: [PATCH 09/29] Docstring updates --- src/textual/widgets/_data_table.py | 39 ++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index e393545ee..2bc819ab7 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -445,22 +445,27 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): @property def hover_row(self) -> int: + """The index of the row that the mouse cursor is currently hovering above.""" return self.hover_coordinate.row @property def hover_column(self) -> int: + """The index of the column that the mouse cursor is currently hovering above.""" return self.hover_coordinate.column @property def cursor_row(self) -> int: + """The index of the row that the DataTable cursor is currently on.""" return self.cursor_coordinate.row @property def cursor_column(self) -> int: + """The index of the column that the DataTable cursor is currently on.""" return self.cursor_coordinate.column @property def row_count(self) -> int: + """The number of rows currently present in the DataTable.""" return len(self.rows) @property @@ -586,7 +591,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # emit the appropriate [Row|Column|Cell]Highlighted event. self._scroll_cursor_into_view(animate=False) if self.cursor_type == "cell": - self._highlight_cell(self.cursor_coordinate) + self._highlight_coordinate(self.cursor_coordinate) elif self.cursor_type == "row": self._highlight_row(self.cursor_row) elif self.cursor_type == "column": @@ -613,7 +618,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # message to tell users of the newly highlighted row/cell/column. if self.cursor_type == "cell": self.refresh_cell(*old_coordinate) - self._highlight_cell(new_coordinate) + self._highlight_coordinate(new_coordinate) elif self.cursor_type == "row": self.refresh_row(old_coordinate.row) self._highlight_row(new_coordinate.row) @@ -621,7 +626,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.refresh_column(old_coordinate.column) self._highlight_column(new_coordinate.column) - def _highlight_cell(self, coordinate: Coordinate) -> None: + def _highlight_coordinate(self, coordinate: Coordinate) -> None: """Apply highlighting to the cell at the coordinate, and emit event.""" self.refresh_cell(*coordinate) try: @@ -669,6 +674,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): return self._clamp_cursor_coordinate(value) def _clamp_cursor_coordinate(self, coordinate: Coordinate) -> Coordinate: + """Clamp a coordinate such that it falls within the boundaries of the table.""" row, column = coordinate row = clamp(row, 0, self.row_count - 1) column = clamp(column, self.fixed_columns, len(self.columns) - 1) @@ -692,11 +698,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._scroll_cursor_into_view() def _highlight_cursor(self) -> None: + """Applies the appropriate highlighting and raises the appropriate + [Row|Column|Cell]Highlighted event for the given cursor coordinate + and cursor type.""" row_index, column_index = self.cursor_coordinate cursor_type = self.cursor_type # Apply the highlighting to the newly relevant cells if cursor_type == "cell": - self._highlight_cell(self.cursor_coordinate) + self._highlight_coordinate(self.cursor_coordinate) elif cursor_type == "row": self._highlight_row(row_index) elif cursor_type == "column": @@ -931,9 +940,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): return row_keys def on_idle(self) -> None: - # Add the new rows *before* updating the column widths, since - # cells in a new row may influence the final width of a column + """Runs when the message pump is empty, and so we use this for + some expensive calculations like re-computing dimensions of the + whole DataTable and re-computing column widths after some cells + have been updated. This is more efficient in the case of high + frequency updates, ensuring we only do expensive computations once.""" if self._require_update_dimensions: + # Add the new rows *before* updating the column widths, since + # cells in a new row may influence the final width of a column self._require_update_dimensions = False new_rows = self._new_rows.copy() self._new_rows.clear() @@ -993,6 +1007,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): @property def ordered_columns(self) -> list[Column]: + """The list of Columns in the DataTable, ordered as they currently appear on screen.""" column_indices = range(len(self.columns)) column_keys = [ self._column_locations.get_key(index) for index in column_indices @@ -1002,6 +1017,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): @property def ordered_rows(self) -> list[Row]: + """The list of Rows in the DataTable, ordered as they currently appear on screen.""" row_indices = range(self.row_count) ordered_rows = [] for row_index in row_indices: @@ -1113,12 +1129,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): cursor_location: Coordinate, hover_location: Coordinate, ) -> tuple[SegmentLines, SegmentLines]: - """Render a row in to lines for each cell. + """Render a single line from a row in the DataTable. Args: row_key: The identifying key for this row. line_no: Line number (y-coordinate) within row. 0 is the first strip of - cells in the row, line_no=1 is the next, and so on... + cells in the row, line_no=1 is the next line in the row, and so on... base_style: Base style of row. cursor_location: The location of the cursor in the DataTable. hover_location: The location of the hover cursor in the DataTable. @@ -1314,6 +1330,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): return self._render_line(y, scroll_x, scroll_x + width, self.rich_style) def on_mouse_move(self, event: events.MouseMove): + """If the hover cursor is visible, display it by extracting the row + and column metadata from the segments present in the cells.""" self._set_hover_cursor(True) meta = event.style.meta if meta and self.show_cursor and self.cursor_type != "none": @@ -1323,6 +1341,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): pass def _get_fixed_offset(self) -> Spacing: + """Calculate the "fixed offset", that is the space to the top and left + that is occupied by fixed rows and columns respectively. Fixed rows and columns + are rows and columns that do not participate in scrolling.""" top = self.header_height if self.show_header else 0 top += sum( self.rows[self._row_locations.get_key(row_index)].height @@ -1370,6 +1391,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.refresh() def _scroll_cursor_into_view(self, animate: bool = False) -> None: + """When the cursor is at a boundary of the DataTable and moves out + of view, this method handles scrolling to ensure it remains visible.""" fixed_offset = self._get_fixed_offset() top, _, _, left = fixed_offset From 07e964d2ba6b67c39660ced64042f75b6758c113 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 11:14:08 +0000 Subject: [PATCH 10/29] More docstrings for the DataTable, new private property refactor for total_row_height --- src/textual/widgets/_data_table.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 2bc819ab7..d2cabe73f 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -470,6 +470,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): @property def _y_offsets(self) -> list[tuple[RowKey, int]]: + """Contains a 2-tuple for each line (not row!) of the DataTable. Given a y-coordinate, + we can index into this list to find which row that y-coordinate lands on, and the + y-offset *within* that row. The length of the returned list is therefore the total + height of all rows within the DataTable.""" y_offsets: list[tuple[RowKey, int]] = [] for row in self.ordered_rows: row_key = row.key @@ -477,6 +481,11 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): y_offsets += [(row_key, y) for y in range(row_height)] return y_offsets + @property + def _total_row_height(self) -> int: + """The total height of all rows within the DataTable""" + return len(self._y_offsets) + def update_cell( self, row_key: RowKey | str, @@ -744,7 +753,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): header_height = self.header_height if self.show_header else 0 self.virtual_size = Size( total_width, - len(self._y_offsets) + header_height, + self._total_row_height + header_height, ) def _get_cell_region(self, row_index: int, column_index: int) -> Region: @@ -796,7 +805,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): column_key = self._column_locations.get_key(column_index) width = columns[column_key].render_width header_height = self.header_height if self.show_header else 0 - height = len(self._y_offsets) + header_height + height = self._total_row_height + header_height full_column_region = Region(x, 0, width, height) return full_column_region From cc3d744168e3098a05500731fae722c666649d6c Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 11:15:31 +0000 Subject: [PATCH 11/29] Add row_key to RowHighlighted event in DataTable --- src/textual/widgets/_data_table.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index d2cabe73f..8c27cf9ce 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -315,13 +315,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): cursor_row: The y-coordinate of the cursor that highlighted the row. """ - def __init__(self, sender: DataTable, cursor_row: int) -> None: + def __init__(self, sender: DataTable, cursor_row: int, row_key: RowKey) -> None: self.cursor_row: int = cursor_row + self.row_key: RowKey = row_key super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "cursor_row", self.cursor_row + yield "row_key", self.row_key class RowSelected(Message, bubble=True): """Emitted when a row is selected. This message is only emitted when the From bf42ac94f7f03bf89a8a712c1f6eedd4f5e8cc3a Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 11:34:39 +0000 Subject: [PATCH 12/29] Ensure row_key is included in RowHighlighted event --- src/textual/widgets/_data_table.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 8c27cf9ce..cf5146787 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -416,8 +416,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._row_render_cache: LRUCache[ RowCacheKey, tuple[SegmentLines, SegmentLines] ] = LRUCache(1000) + """For each row (a row can have a height of multiple lines), we maintain a cache + of the fixed and scrollable lines within that row to minimise how often we need to + re-render it.""" self._cell_render_cache: LRUCache[CellCacheKey, SegmentLines] = LRUCache(10000) + """Cache for individual cells.""" self._line_cache: LRUCache[LineCacheKey, Strip] = LRUCache(1000) + """Cache for lines within rows.""" self._require_update_dimensions: bool = False """Set to re-calculate dimensions on idle.""" @@ -673,7 +678,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.refresh_row(row_index) is_valid_row = row_index < len(self.data) if is_valid_row: - self.emit_no_wait(DataTable.RowHighlighted(self, row_index)) + row_key = self._row_locations.get_key(row_index) + self.emit_no_wait(DataTable.RowHighlighted(self, row_index, row_key)) def _highlight_column(self, column_index: int) -> None: """Apply highlighting to the column at the given index, and emit event.""" From c9629b1755cf01d11ead0a80faf5cf6814e0e1a7 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 13:53:48 +0000 Subject: [PATCH 13/29] Ensure keys are included in emitted events from DataTable --- src/textual/widgets/_data_table.py | 38 +++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index cf5146787..6bec7fd11 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -313,6 +313,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Attributes: cursor_row: The y-coordinate of the cursor that highlighted the row. + row_key: The key of the row that was highlighted. """ def __init__(self, sender: DataTable, cursor_row: int, row_key: RowKey) -> None: @@ -333,15 +334,18 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Attributes: cursor_row: The y-coordinate of the cursor that made the selection. + row_key: The key of the row that was selected. """ - def __init__(self, sender: DataTable, cursor_row: int) -> None: + def __init__(self, sender: DataTable, cursor_row: int, row_key: RowKey) -> None: self.cursor_row: int = cursor_row + self.row_key: RowKey = row_key super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "cursor_row", self.cursor_row + yield "row_key", self.row_key class ColumnHighlighted(Message, bubble=True): """Emitted when a column is highlighted. This message is only emitted when the @@ -351,15 +355,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Attributes: cursor_column: The x-coordinate of the column that was highlighted. + column_key: The key of the column that was highlighted. """ - def __init__(self, sender: DataTable, cursor_column: int) -> None: + def __init__( + self, sender: DataTable, cursor_column: int, column_key: ColumnKey + ) -> None: self.cursor_column: int = cursor_column + self.column_key = column_key super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "cursor_column", self.cursor_column + yield "column_key", self.column_key class ColumnSelected(Message, bubble=True): """Emitted when a column is selected. This message is only emitted when the @@ -369,15 +378,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Attributes: cursor_column: The x-coordinate of the column that was selected. + column_key: The key of the column that was selected. """ - def __init__(self, sender: DataTable, cursor_column: int) -> None: + def __init__( + self, sender: DataTable, cursor_column: int, column_key: ColumnKey + ) -> None: self.cursor_column: int = cursor_column + self.column_key = column_key super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "cursor_column", self.cursor_column + yield "column_key", self.column_key def __init__( self, @@ -685,7 +699,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): """Apply highlighting to the column at the given index, and emit event.""" self.refresh_column(column_index) if column_index < len(self.columns): - self.emit_no_wait(DataTable.ColumnHighlighted(self, column_index)) + column_key = self._column_locations.get_key(column_index) + self.emit_no_wait( + DataTable.ColumnHighlighted(self, column_index, column_key) + ) def validate_cursor_coordinate(self, value: Coordinate) -> Coordinate: return self._clamp_cursor_coordinate(value) @@ -1500,18 +1517,21 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): """Emit the appropriate message for a selection based on the `cursor_type`.""" cursor_coordinate = self.cursor_coordinate cursor_type = self.cursor_type + cell_key = self.coordinate_to_cell_key(cursor_coordinate) if cursor_type == "cell": self.emit_no_wait( DataTable.CellSelected( self, self.get_value_at(cursor_coordinate), coordinate=cursor_coordinate, - cell_key=self.coordinate_to_cell_key(cursor_coordinate), + cell_key=cell_key, ) ) elif cursor_type == "row": - row, _ = cursor_coordinate - self.emit_no_wait(DataTable.RowSelected(self, row)) + row_index, _ = cursor_coordinate + row_key, _ = cell_key + self.emit_no_wait(DataTable.RowSelected(self, row_index, row_key)) elif cursor_type == "column": - _, column = cursor_coordinate - self.emit_no_wait(DataTable.ColumnSelected(self, column)) + _, column_index = cursor_coordinate + _, column_key = cell_key + self.emit_no_wait(DataTable.ColumnSelected(self, column_index, column_key)) From 53685ee2b58381f5c35e11b5adc89342e2f67867 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 13:59:44 +0000 Subject: [PATCH 14/29] Docstring update in DataTable --- src/textual/widgets/_data_table.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 6bec7fd11..aa7ccaac3 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -552,7 +552,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.update_cell(row_key, column_key, value, update_width=update_width) def _get_cells_in_column(self, column_key: ColumnKey) -> Iterable[CellType]: - """For a given column key, return the cells in that column in order""" + """For a given column key, return the cells in that column in the + order they currently appear on screen.""" for row_metadata in self.ordered_rows: row_key = row_metadata.key row = self.data.get(row_key) From 67d79e16dad694b27c734d70de507b1bba5a4c86 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 14:10:01 +0000 Subject: [PATCH 15/29] Simplify _get_offsets to return header row key --- src/textual/widgets/_data_table.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index aa7ccaac3..aedd06180 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -1138,8 +1138,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): if is_fixed_style: style += self.get_component_styles("datatable--cursor-fixed").rich_style - # TODO: We can hoist `row_key` lookup waaay up to do it inside `_get_offsets` - # then just pass it through to here instead of the row_index. row_key = self._row_locations.get_key(row_index) column_key = self._column_locations.get_key(column_index) cell_cache_key = (row_key, column_key, style, cursor, hover, self._update_count) @@ -1240,7 +1238,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): else: fixed_row = [] - if row_key is None: + is_header_row = row_key is self._header_row_key + if is_header_row: row_style = self.get_component_styles("datatable--header").rich_style else: if self.zebra_stripes: @@ -1268,7 +1267,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._row_render_cache[cache_key] = row_pair return row_pair - def _get_offsets(self, y: int) -> tuple[RowKey | None, int]: + def _get_offsets(self, y: int) -> tuple[RowKey, int]: """Get row key and line offset for a given line. Args: @@ -1281,7 +1280,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): y_offsets = self._y_offsets if self.show_header: if y < header_height: - return None, y + return self._header_row_key, y y -= header_height if y > len(y_offsets): raise LookupError("Y coord {y!r} is greater than total height") From a7383e6a83823824decec010d47d51b20ebdb5dd Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 14:51:05 +0000 Subject: [PATCH 16/29] Import and export datatable utilities from public module --- src/textual/widgets/_data_table.py | 15 +++++++-------- src/textual/widgets/data_table.py | 22 ++++++++++++++++++++-- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index aedd06180..a607b9a67 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -39,15 +39,14 @@ from ..render import measure from ..scroll_view import ScrollView from ..strip import Strip -CellCacheKey: TypeAlias = "tuple[RowKey, ColumnKey, Style, bool, bool, int]" -LineCacheKey: TypeAlias = ( +_CellCacheKey: TypeAlias = "tuple[RowKey, ColumnKey, Style, bool, bool, int]" +_LineCacheKey: TypeAlias = ( "tuple[int, int, int, int, Coordinate, Coordinate, Style, CursorType, bool, int]" ) -RowCacheKey: TypeAlias = ( +_RowCacheKey: TypeAlias = ( "tuple[RowKey, int, Style, Coordinate, Coordinate, CursorType, bool, bool, int]" ) CursorType = Literal["cell", "row", "column", "none"] -CELL: CursorType = "cell" CellType = TypeVar("CellType") @@ -237,7 +236,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): zebra_stripes = Reactive(False) header_height = Reactive(1) show_cursor = Reactive(True) - cursor_type = Reactive(CELL) + cursor_type = Reactive("cell") cursor_coordinate: Reactive[Coordinate] = Reactive( Coordinate(0, 0), repaint=False, always_update=True @@ -428,14 +427,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): """Maps column keys to column indices which represent column order.""" self._row_render_cache: LRUCache[ - RowCacheKey, tuple[SegmentLines, SegmentLines] + _RowCacheKey, tuple[SegmentLines, SegmentLines] ] = LRUCache(1000) """For each row (a row can have a height of multiple lines), we maintain a cache of the fixed and scrollable lines within that row to minimise how often we need to re-render it.""" - self._cell_render_cache: LRUCache[CellCacheKey, SegmentLines] = LRUCache(10000) + self._cell_render_cache: LRUCache[_CellCacheKey, SegmentLines] = LRUCache(10000) """Cache for individual cells.""" - self._line_cache: LRUCache[LineCacheKey, Strip] = LRUCache(1000) + self._line_cache: LRUCache[_LineCacheKey, Strip] = LRUCache(1000) """Cache for lines within rows.""" self._require_update_dimensions: bool = False diff --git a/src/textual/widgets/data_table.py b/src/textual/widgets/data_table.py index d0316f387..429724361 100644 --- a/src/textual/widgets/data_table.py +++ b/src/textual/widgets/data_table.py @@ -1,5 +1,23 @@ """Make non-widget DataTable support classes available.""" -from ._data_table import Column, Row +from ._data_table import ( + Column, + Row, + RowKey, + ColumnKey, + CellKey, + CursorType, + CellType, + CellDoesNotExist, +) -__all__ = ["Column", "Row"] +__all__ = [ + "Column", + "Row", + "RowKey", + "ColumnKey", + "CellKey", + "CursorType", + "CellType", + "CellDoesNotExist", +] From 3f463cb0ef4cae4d31e8f8f2f26f96bac96c141d Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 14:54:33 +0000 Subject: [PATCH 17/29] Store strings as strings --- src/textual/widgets/_data_table.py | 2 +- tests/test_data_table.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index a607b9a67..d70fce77d 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -919,7 +919,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # Map the key of this row to its current index self._row_locations[row_key] = row_index self.data[row_key] = { - column.key: Text(cell) if isinstance(cell, str) else cell + column.key: cell for column, cell in zip_longest(self.ordered_columns, cells) } self.rows[row_key] = Row(row_key, height) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 28de49c75..23004419e 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -155,14 +155,14 @@ async def test_add_rows_user_defined_keys(): assert isinstance(algernon_key, RowKey) # Ensure the data in the table is mapped as expected - first_row = {key_a: Text(ROWS[0][0]), key_b: Text(ROWS[0][1])} + first_row = {key_a: ROWS[0][0], key_b: ROWS[0][1]} assert table.data[algernon_key] == first_row assert table.data["algernon"] == first_row - second_row = {key_a: Text(ROWS[1][0]), key_b: Text(ROWS[1][1])} + second_row = {key_a: ROWS[1][0], key_b: ROWS[1][1]} assert table.data["charlie"] == second_row - third_row = {key_a: Text(ROWS[2][0]), key_b: Text(ROWS[2][1])} + third_row = {key_a: ROWS[2][0], key_b: ROWS[2][1]} assert table.data[auto_key] == third_row first_row = Row(algernon_key, height=1) @@ -253,7 +253,7 @@ async def test_get_cell_value_returns_value_at_cell(): table = app.query_one(DataTable) table.add_columns("A", "B") table.add_rows(ROWS) - assert table.get_value_at(Coordinate(0, 0)) == Text("0/0") + assert table.get_value_at(Coordinate(0, 0)) == "0/0" async def test_get_cell_value_exception(): From c84ae53395a88609b1efabcf24541aba16882724 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 15:07:38 +0000 Subject: [PATCH 18/29] Fix docstring indentation to fix mkdocs rendering --- src/textual/widgets/_data_table.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index d70fce77d..66a59498b 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -947,8 +947,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Returns: A list of the keys for the columns that were added. See - the `add_column` method docstring for more information on how - these keys are used. + the `add_column` method docstring for more information on how + these keys are used. """ column_keys = [] for label in labels: @@ -964,8 +964,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Returns: A list of the keys for the rows that were added. See - the `add_row` method docstring for more information on how - these keys are used. + the `add_row` method docstring for more information on how + these keys are used. """ row_keys = [] for row in rows: From 43c2696ccfe54b7345b496d3f63f44dd8386e603 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 15:25:13 +0000 Subject: [PATCH 19/29] Small rename in DataTable utility types --- src/textual/widgets/_data_table.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 66a59498b..d54b96fc6 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -39,11 +39,11 @@ from ..render import measure from ..scroll_view import ScrollView from ..strip import Strip -_CellCacheKey: TypeAlias = "tuple[RowKey, ColumnKey, Style, bool, bool, int]" -_LineCacheKey: TypeAlias = ( +CellCacheKey: TypeAlias = "tuple[RowKey, ColumnKey, Style, bool, bool, int]" +LineCacheKey: TypeAlias = ( "tuple[int, int, int, int, Coordinate, Coordinate, Style, CursorType, bool, int]" ) -_RowCacheKey: TypeAlias = ( +RowCacheKey: TypeAlias = ( "tuple[RowKey, int, Style, Coordinate, Coordinate, CursorType, bool, bool, int]" ) CursorType = Literal["cell", "row", "column", "none"] @@ -427,14 +427,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): """Maps column keys to column indices which represent column order.""" self._row_render_cache: LRUCache[ - _RowCacheKey, tuple[SegmentLines, SegmentLines] + RowCacheKey, tuple[SegmentLines, SegmentLines] ] = LRUCache(1000) """For each row (a row can have a height of multiple lines), we maintain a cache of the fixed and scrollable lines within that row to minimise how often we need to re-render it.""" - self._cell_render_cache: LRUCache[_CellCacheKey, SegmentLines] = LRUCache(10000) + self._cell_render_cache: LRUCache[CellCacheKey, SegmentLines] = LRUCache(10000) """Cache for individual cells.""" - self._line_cache: LRUCache[_LineCacheKey, Strip] = LRUCache(1000) + self._line_cache: LRUCache[LineCacheKey, Strip] = LRUCache(1000) """Cache for lines within rows.""" self._require_update_dimensions: bool = False @@ -974,7 +974,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): return row_keys def on_idle(self) -> None: - """Runs when the message pump is empty, and so we use this for + """Runs when the message pump is empty. We use this for some expensive calculations like re-computing dimensions of the whole DataTable and re-computing column widths after some cells have been updated. This is more efficient in the case of high From fd4e13c988a3649c071c0f36d88665131dfcdfa9 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 15:43:36 +0000 Subject: [PATCH 20/29] Add tests for DataTable.get_cell_value --- tests/test_data_table.py | 53 ++++++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 23004419e..378de0a17 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -5,8 +5,7 @@ from textual.app import App from textual.coordinate import Coordinate from textual.message import Message from textual.widgets import DataTable -from textual.widgets._data_table import ( - StringKey, +from textual.widgets.data_table import ( CellDoesNotExist, RowKey, Row, @@ -248,6 +247,35 @@ async def test_column_widths() -> None: async def test_get_cell_value_returns_value_at_cell(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_row("TargetValue", key="R1") + assert table.get_cell_value("R1", "C1") == "TargetValue" + + +async def test_get_cell_value_invalid_row_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_row("TargetValue", key="R1") + with pytest.raises(CellDoesNotExist): + table.get_cell_value("INVALID_ROW", "C1") + + +async def test_get_cell_value_invalid_column_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_row("TargetValue", key="R1") + with pytest.raises(CellDoesNotExist): + table.get_cell_value("R1", "INVALID_COLUMN") + + +async def test_get_value_at_returns_value_at_cell(): app = DataTableApp() async with app.run_test(): table = app.query_one(DataTable) @@ -256,7 +284,7 @@ async def test_get_cell_value_returns_value_at_cell(): assert table.get_value_at(Coordinate(0, 0)) == "0/0" -async def test_get_cell_value_exception(): +async def test_get_value_at_exception(): app = DataTableApp() async with app.run_test(): table = app.query_one(DataTable) @@ -266,15 +294,24 @@ async def test_get_cell_value_exception(): table.get_value_at(Coordinate(9999, 0)) +# async def test_update_cell_cell_exists(): +# app = DataTableApp() +# async with app.run_test(): +# table = app.query_one(DataTable) +# table.add_column("A", key="A") +# table.add_row("1", key="1") +# assert table.get_cell_value() + + def test_key_equals_equivalent_string(): text = "Hello" - key = StringKey(text) + key = RowKey(text) assert key == text assert hash(key) == hash(text) def test_key_doesnt_match_non_equal_string(): - key = StringKey("123") + key = ColumnKey("123") text = "laksjdlaskjd" assert key != text assert hash(key) != hash(text) @@ -293,9 +330,9 @@ def test_key_string_lookup(): # in tests how we intend for the keys to work for cache lookups. dictionary = { "foo": "bar", - StringKey("hello"): "world", + RowKey("hello"): "world", } assert dictionary["foo"] == "bar" - assert dictionary[StringKey("foo")] == "bar" + assert dictionary[RowKey("foo")] == "bar" assert dictionary["hello"] == "world" - assert dictionary[StringKey("hello")] == "world" + assert dictionary[RowKey("hello")] == "world" From 23a34030cd1dcec13b21cd3d451eca42d27d06bf Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 17:10:59 +0000 Subject: [PATCH 21/29] Measuring string cells correctly --- src/textual/render.py | 4 ++++ src/textual/widgets/_data_table.py | 11 ++++------- tests/test_data_table.py | 21 ++++++++++++--------- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/textual/render.py b/src/textual/render.py index 8911c4263..c1003b062 100644 --- a/src/textual/render.py +++ b/src/textual/render.py @@ -1,5 +1,6 @@ from __future__ import annotations +from rich.cells import cell_len from rich.console import Console, RenderableType from rich.protocol import rich_cast @@ -22,6 +23,9 @@ def measure( Returns: Width in cells """ + if isinstance(renderable, str): + return cell_len(renderable) + width = default renderable = rich_cast(renderable) get_console_width = getattr(renderable, "__rich_measure__", None) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index d54b96fc6..a9a00fbc3 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -524,8 +524,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): update_width: Whether to resize the column width to accommodate for the new cell content. """ - value = Text.from_markup(value) if isinstance(value, str) else value - self.data[row_key][column_key] = value self._update_count += 1 @@ -752,6 +750,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): label_width = measure(console, column.label, 1) content_width = column.content_width cell_value = self.data[row_key][column_key] + new_content_width = measure(console, cell_value, 1) if new_content_width < content_width: @@ -866,15 +865,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): of its current location in the DataTable (it could have moved after being added due to sorting or insertion/deletion of other columns). """ - text_label = Text.from_markup(label) if isinstance(label, str) else label - column_key = ColumnKey(key) column_index = len(self.columns) - content_width = measure(self.app.console, text_label, 1) + content_width = measure(self.app.console, label, 1) if width is None: column = Column( column_key, - text_label, + label, content_width, content_width=content_width, auto_width=True, @@ -882,7 +879,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): else: column = Column( column_key, - text_label, + label, width, content_width=content_width, ) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 378de0a17..23499e3eb 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -178,7 +178,6 @@ async def test_add_columns(): assert len(table.columns) == 3 -# TODO: Ensure we can use the key to retrieve the column. async def test_add_columns_user_defined_keys(): app = DataTableApp() async with app.run_test(): @@ -223,7 +222,7 @@ async def test_column_labels() -> None: table = app.query_one(DataTable) table.add_columns("1", "2", "3") actual_labels = [col.label for col in table.columns.values()] - expected_labels = [Text("1"), Text("2"), Text("3")] + expected_labels = ["1", "2", "3"] assert actual_labels == expected_labels @@ -294,13 +293,17 @@ async def test_get_value_at_exception(): table.get_value_at(Coordinate(9999, 0)) -# async def test_update_cell_cell_exists(): -# app = DataTableApp() -# async with app.run_test(): -# table = app.query_one(DataTable) -# table.add_column("A", key="A") -# table.add_row("1", key="1") -# assert table.get_cell_value() +async def test_update_cell_cell_exists(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("A", key="A") + table.add_row("1", key="1") + table.update_cell("1", "A", "NEW_VALUE") + assert table.get_cell_value("1", "A") == "NEW_VALUE" + + +# TODO: Test update coordinate def test_key_equals_equivalent_string(): From 77b94b005ce3251f741f30423091ab3cf1013fe8 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 1 Feb 2023 17:34:03 +0000 Subject: [PATCH 22/29] Testing case where you try to update cells which dont exist --- src/textual/widgets/_data_table.py | 12 +++++++++++- tests/test_data_table.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index a9a00fbc3..0668630c9 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -523,8 +523,17 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): value: The new value to put inside the cell. update_width: Whether to resize the column width to accommodate for the new cell content. + + Raises: + CellDoesNotExist: When the supplied `row_key` and `column_key` + cannot be found in the table. """ - self.data[row_key][column_key] = value + try: + self.data[row_key][column_key] = value + except KeyError: + raise CellDoesNotExist( + f"No cell exists for row_key={row_key!r}, column_key={column_key!r}." + ) from None self._update_count += 1 # Recalculate widths if necessary @@ -545,6 +554,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): update_width: Whether to resize the column width to accommodate for the new cell content. """ + # TODO: Validate coordinate and raise exception row_key, column_key = self.coordinate_to_cell_key(coordinate) self.update_cell(row_key, column_key, value, update_width=update_width) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 23499e3eb..b4f20f811 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -303,6 +303,16 @@ async def test_update_cell_cell_exists(): assert table.get_cell_value("1", "A") == "NEW_VALUE" +async def test_update_cell_cell_doesnt_exist(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("A", key="A") + table.add_row("1", key="1") + with pytest.raises(CellDoesNotExist): + table.update_cell("INVALID", "CELL", "Value") + + # TODO: Test update coordinate From 990a6311bccbf2539bfcac209a0564bae7d206ba Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Thu, 2 Feb 2023 13:09:11 +0000 Subject: [PATCH 23/29] Extract common coordinate validation logic into method in DataTable --- src/textual/widgets/_data_table.py | 96 ++++++++++++++++++++---------- tests/test_data_table.py | 19 +++++- 2 files changed, 84 insertions(+), 31 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 0668630c9..e82b76384 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -86,8 +86,8 @@ class ColumnKey(StringKey): class CellKey(NamedTuple): - row_key: RowKey - column_key: ColumnKey + row_key: RowKey | str + column_key: ColumnKey | str def default_cell_formatter(obj: object) -> RenderableType | None: @@ -554,7 +554,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): update_width: Whether to resize the column width to accommodate for the new cell content. """ - # TODO: Validate coordinate and raise exception + if not self.is_valid_coordinate(coordinate): + raise CellDoesNotExist() + row_key, column_key = self.coordinate_to_cell_key(coordinate) self.update_cell(row_key, column_key, value, update_width=update_width) @@ -645,8 +647,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._clear_caches() def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None: - self.refresh_cell(*old) - self.refresh_cell(*value) + self.refresh_coordinate(old) + self.refresh_coordinate(value) def watch_cursor_coordinate( self, old_coordinate: Coordinate, new_coordinate: Coordinate @@ -655,7 +657,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # Refresh the old and the new cell, and emit the appropriate # message to tell users of the newly highlighted row/cell/column. if self.cursor_type == "cell": - self.refresh_cell(*old_coordinate) + self.refresh_coordinate(old_coordinate) self._highlight_coordinate(new_coordinate) elif self.cursor_type == "row": self.refresh_row(old_coordinate.row) @@ -666,7 +668,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _highlight_coordinate(self, coordinate: Coordinate) -> None: """Apply highlighting to the cell at the coordinate, and emit event.""" - self.refresh_cell(*coordinate) + self.refresh_coordinate(coordinate) try: cell_value = self.get_value_at(coordinate) except CellDoesNotExist: @@ -690,6 +692,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Returns: The key of the cell currently occupying this coordinate. """ + if not self.is_valid_coordinate(coordinate): + raise CellDoesNotExist() row_index, column_index = coordinate row_key = self._row_locations.get_key(row_index) column_key = self._column_locations.get_key(column_index) @@ -729,12 +733,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # Refresh cells that were previously impacted by the cursor # but may no longer be. - row_index, column_index = self.cursor_coordinate if old == "cell": - self.refresh_cell(row_index, column_index) + self.refresh_coordinate(self.cursor_coordinate) elif old == "row": + row_index, _ = self.cursor_coordinate self.refresh_row(row_index) elif old == "column": + _, column_index = self.cursor_coordinate self.refresh_column(column_index) self._scroll_cursor_into_view() @@ -790,14 +795,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._total_row_height + header_height, ) - def _get_cell_region(self, row_index: int, column_index: int) -> Region: + def _get_cell_region(self, coordinate: Coordinate) -> Region: """Get the region of the cell at the given spatial coordinate.""" - valid_row = 0 <= row_index < len(self.rows) - valid_column = 0 <= column_index < len(self.columns) - valid_cell = valid_row and valid_column - if not valid_cell: + if not self.is_valid_coordinate(coordinate): return Region(0, 0, 0, 0) + row_index, column_index = coordinate row_key = self._row_locations.get_key(row_index) row = self.rows[row_key] @@ -814,11 +817,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_row_region(self, row_index: int) -> Region: """Get the region of the row at the given index.""" - rows = self.rows - valid_row = 0 <= row_index < len(rows) - if not valid_row: + if not self.is_valid_row_index(row_index): return Region(0, 0, 0, 0) + rows = self.rows row_key = self._row_locations.get_key(row_index) row = rows[row_key] row_width = sum(column.render_width for column in self.columns.values()) @@ -830,11 +832,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_column_region(self, column_index: int) -> Region: """Get the region of the column at the given index.""" - columns = self.columns - valid_column = 0 <= column_index < len(columns) - if not valid_column: + if not self.is_valid_column_index(column_index): return Region(0, 0, 0, 0) + columns = self.columns x = sum(column.render_width for column in self.ordered_columns[:column_index]) column_key = self._column_locations.get_key(column_index) width = columns[column_key].render_width @@ -1001,16 +1002,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._updated_cells.clear() self._update_column_widths(updated_columns) - def refresh_cell(self, row_index: int, column_index: int) -> None: - """Refresh a cell. + def refresh_coordinate(self, coordinate: Coordinate) -> None: + """Refresh the cell at a coordinate. Args: - row_index: Row index. - column_index: Column index. + coordinate: The coordinate to refresh. """ - if row_index < 0 or column_index < 0: + if not self.is_valid_coordinate(coordinate): return - region = self._get_cell_region(row_index, column_index) + region = self._get_cell_region(coordinate) self._refresh_region(region) def refresh_row(self, row_index: int) -> None: @@ -1019,7 +1019,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Args: row_index: The index of the row to refresh. """ - if row_index < 0 or row_index >= len(self.rows): + if not self.is_valid_row_index(row_index): return region = self._get_row_region(row_index) @@ -1031,7 +1031,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Args: column_index: The index of the column to refresh. """ - if column_index < 0 or column_index >= len(self.columns): + if not self.is_valid_column_index(column_index): return region = self._get_column_region(column_index) @@ -1046,6 +1046,42 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): region = region.translate(-self.scroll_offset) self.refresh(region) + def is_valid_row_index(self, row_index: int) -> bool: + """Return a boolean indicating whether the row_index is within table bounds. + + Args: + row_index: The row index to check. + + Returns: + True if the row index is within the bounds of the table. + """ + return 0 <= row_index < len(self.rows) + + def is_valid_column_index(self, column_index: int) -> bool: + """Return a boolean indicating whether the column_index is within table bounds. + + Args: + column_index: The column index to check. + + Returns: + True if the column index is within the bounds of the table. + """ + return 0 <= column_index < len(self.columns) + + def is_valid_coordinate(self, coordinate: Coordinate) -> bool: + """Return a boolean indicating whether the given coordinate is within table bounds. + + Args: + coordinate: The coordinate to validate. + + Returns: + True if the coordinate is within the bounds of the table. + """ + row_index, column_index = coordinate + return self.is_valid_row_index(row_index) and self.is_valid_column_index( + column_index + ) + @property def ordered_columns(self) -> list[Column]: """The list of Columns in the DataTable, ordered as they currently appear on screen.""" @@ -1443,7 +1479,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): x, y, width, height = self._get_column_region(self.cursor_column) region = Region(x, int(self.scroll_y) + top, width, height - top) else: - region = self._get_cell_region(self.cursor_row, self.cursor_column) + region = self._get_cell_region(self.cursor_coordinate) self.scroll_to_region(region, animate=animate, spacing=fixed_offset) @@ -1463,7 +1499,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): elif cursor_type == "row": self.refresh_row(self.hover_row) elif cursor_type == "cell": - self.refresh_cell(*self.hover_coordinate) + self.refresh_coordinate(self.hover_coordinate) def on_click(self, event: events.Click) -> None: self._set_hover_cursor(True) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index b4f20f811..40d399f93 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -313,7 +313,24 @@ async def test_update_cell_cell_doesnt_exist(): table.update_cell("INVALID", "CELL", "Value") -# TODO: Test update coordinate +async def test_update_coordinate_coordinate_exists(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column_0, column_1 = table.add_columns("A", "B") + row_0, *_ = table.add_rows(ROWS) + table.update_coordinate(Coordinate(0, 1), "newvalue") + assert table.get_cell_value(row_0, column_1) == "newvalue" + + +async def test_update_coordinate_coordinate_doesnt_exist(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_columns("A", "B") + table.add_rows(ROWS) + with pytest.raises(CellDoesNotExist): + table.update_coordinate(Coordinate(999, 999), "newvalue") def test_key_equals_equivalent_string(): From 7748b69e954f107767ad2615e6c36dbd0ba9c3f5 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Thu, 2 Feb 2023 14:12:14 +0000 Subject: [PATCH 24/29] Initial unit tests around column width updates --- tests/test_data_table.py | 40 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 40d399f93..fccd4e275 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -1,6 +1,9 @@ +import asyncio + import pytest from rich.text import Text +from textual._wait import wait_for_idle from textual.app import App from textual.coordinate import Coordinate from textual.message import Message @@ -226,21 +229,21 @@ async def test_column_labels() -> None: assert actual_labels == expected_labels -async def test_column_widths() -> None: +async def test_initial_column_widths() -> None: app = DataTableApp() - async with app.run_test() as pilot: + async with app.run_test(): table = app.query_one(DataTable) foo, bar = table.add_columns("foo", "bar") assert table.columns[foo].width == 3 assert table.columns[bar].width == 3 table.add_row("Hello", "World!") - await pilot.pause() + await wait_for_idle() assert table.columns[foo].content_width == 5 assert table.columns[bar].content_width == 6 table.add_row("Hello World!!!", "fo") - await pilot.pause() + await wait_for_idle() assert table.columns[foo].content_width == 14 assert table.columns[bar].content_width == 6 @@ -319,6 +322,10 @@ async def test_update_coordinate_coordinate_exists(): table = app.query_one(DataTable) column_0, column_1 = table.add_columns("A", "B") row_0, *_ = table.add_rows(ROWS) + + columns = table.columns + column = columns.get(column_1) + table.update_coordinate(Coordinate(0, 1), "newvalue") assert table.get_cell_value(row_0, column_1) == "newvalue" @@ -333,6 +340,31 @@ async def test_update_coordinate_coordinate_doesnt_exist(): table.update_coordinate(Coordinate(999, 999), "newvalue") +@pytest.mark.parametrize( + "label,new_value,new_content_width", + [ + # We update the value of a cell to a value longer than the initial value, + # but shorter than the column label. The column label width should be used. + ("1234567", "1234", 7), + # We update the value of a cell to a value larger than the initial value, + # so the width of the column should be increased to accommodate on idle. + ("1234567", "123456789", 9), + ], +) +async def test_update_coordinate_column_width(label, new_value, new_content_width): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + key, _ = table.add_columns(label, "Column2") + table.add_rows(ROWS) + first_column = table.columns.get(key) + + table.update_coordinate(Coordinate(0, 0), new_value, update_width=True) + await wait_for_idle() + assert first_column.content_width == new_content_width + assert first_column.render_width == new_content_width + 2 + + def test_key_equals_equivalent_string(): text = "Hello" key = RowKey(text) From 134ceffd110ea070630c54ed9611dd82bf629622 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Thu, 2 Feb 2023 14:20:33 +0000 Subject: [PATCH 25/29] Testing to ensure column size calculated correctly --- tests/test_data_table.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index fccd4e275..80caf04ee 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -343,12 +343,16 @@ async def test_update_coordinate_coordinate_doesnt_exist(): @pytest.mark.parametrize( "label,new_value,new_content_width", [ - # We update the value of a cell to a value longer than the initial value, - # but shorter than the column label. The column label width should be used. + # Initial cell values are length 3. Let's update cell content and ensure + # that the width of the column is calculated given the new cell width. + # Shorter than initial cell value, larger than label => width remains same + ("A", "BB", 3), + # Larger than initial cell value, shorter than label => width remains that of label ("1234567", "1234", 7), - # We update the value of a cell to a value larger than the initial value, - # so the width of the column should be increased to accommodate on idle. - ("1234567", "123456789", 9), + # Shorter than initial cell value, shorter than label => width remains same + ("12345", "123", 5), + # Larger than initial cell value, larger than label => width updates to new cell value + ("12345", "123456789", 9), ], ) async def test_update_coordinate_column_width(label, new_value, new_content_width): From 87808c63b284d59f520d591d0e71355c89920558 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Thu, 2 Feb 2023 15:29:26 +0000 Subject: [PATCH 26/29] Tidying some tests --- tests/test_data_table.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 80caf04ee..fed155f8b 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -323,9 +323,6 @@ async def test_update_coordinate_coordinate_exists(): column_0, column_1 = table.add_columns("A", "B") row_0, *_ = table.add_rows(ROWS) - columns = table.columns - column = columns.get(column_1) - table.update_coordinate(Coordinate(0, 1), "newvalue") assert table.get_cell_value(row_0, column_1) == "newvalue" @@ -343,8 +340,6 @@ async def test_update_coordinate_coordinate_doesnt_exist(): @pytest.mark.parametrize( "label,new_value,new_content_width", [ - # Initial cell values are length 3. Let's update cell content and ensure - # that the width of the column is calculated given the new cell width. # Shorter than initial cell value, larger than label => width remains same ("A", "BB", 3), # Larger than initial cell value, shorter than label => width remains that of label @@ -356,6 +351,9 @@ async def test_update_coordinate_coordinate_doesnt_exist(): ], ) async def test_update_coordinate_column_width(label, new_value, new_content_width): + # Initial cell values are length 3. Let's update cell content and ensure + # that the width of the column is correct given the new cell content widths + # and the label of the column the cell is in. app = DataTableApp() async with app.run_test(): table = app.query_one(DataTable) From 62fb9d58bdf72fba52fa4611d3afce029380078d Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Thu, 2 Feb 2023 15:40:24 +0000 Subject: [PATCH 27/29] Testing conversion of coordinate to cell_key --- tests/test_data_table.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index fed155f8b..49d490a06 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -8,6 +8,7 @@ from textual.app import App from textual.coordinate import Coordinate from textual.message import Message from textual.widgets import DataTable +from textual.widgets._data_table import CellKey from textual.widgets.data_table import ( CellDoesNotExist, RowKey, @@ -367,6 +368,17 @@ async def test_update_coordinate_column_width(label, new_value, new_content_widt assert first_column.render_width == new_content_width + 2 +async def test_coordinate_to_cell_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column_key, _ = table.add_columns("Column0", "Column1") + row_key = table.add_row("A", "B") + + cell_key = table.coordinate_to_cell_key(Coordinate(0, 0)) + assert cell_key == CellKey(row_key, column_key) + + def test_key_equals_equivalent_string(): text = "Hello" key = RowKey(text) From 18aaeaa2844ae586e14a041a9f134acb52cb0374 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Thu, 2 Feb 2023 15:41:24 +0000 Subject: [PATCH 28/29] Add explanatory message to an exception in DataTable --- src/textual/widgets/_data_table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index e82b76384..ae0c30e2d 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -693,7 +693,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): The key of the cell currently occupying this coordinate. """ if not self.is_valid_coordinate(coordinate): - raise CellDoesNotExist() + raise CellDoesNotExist(f"No cell exists at {coordinate!r}.") row_index, column_index = coordinate row_key = self._row_locations.get_key(row_index) column_key = self._column_locations.get_key(column_index) From 998ee9b8a217fd6204b39e08defca81cb8ca952a Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Thu, 2 Feb 2023 15:44:25 +0000 Subject: [PATCH 29/29] Test to ensure correct exception raised when converting to cell key from coordinate in DataTable --- tests/test_data_table.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 49d490a06..1555646de 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -379,6 +379,14 @@ async def test_coordinate_to_cell_key(): assert cell_key == CellKey(row_key, column_key) +async def test_coordinate_to_cell_key_invalid_coordinate(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + with pytest.raises(CellDoesNotExist): + table.coordinate_to_cell_key(Coordinate(9999, 9999)) + + def test_key_equals_equivalent_string(): text = "Hello" key = RowKey(text)