diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 0668630c9..e82b76384 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -86,8 +86,8 @@ class ColumnKey(StringKey): class CellKey(NamedTuple): - row_key: RowKey - column_key: ColumnKey + row_key: RowKey | str + column_key: ColumnKey | str def default_cell_formatter(obj: object) -> RenderableType | None: @@ -554,7 +554,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): update_width: Whether to resize the column width to accommodate for the new cell content. """ - # TODO: Validate coordinate and raise exception + if not self.is_valid_coordinate(coordinate): + raise CellDoesNotExist() + row_key, column_key = self.coordinate_to_cell_key(coordinate) self.update_cell(row_key, column_key, value, update_width=update_width) @@ -645,8 +647,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._clear_caches() def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None: - self.refresh_cell(*old) - self.refresh_cell(*value) + self.refresh_coordinate(old) + self.refresh_coordinate(value) def watch_cursor_coordinate( self, old_coordinate: Coordinate, new_coordinate: Coordinate @@ -655,7 +657,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # Refresh the old and the new cell, and emit the appropriate # message to tell users of the newly highlighted row/cell/column. if self.cursor_type == "cell": - self.refresh_cell(*old_coordinate) + self.refresh_coordinate(old_coordinate) self._highlight_coordinate(new_coordinate) elif self.cursor_type == "row": self.refresh_row(old_coordinate.row) @@ -666,7 +668,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _highlight_coordinate(self, coordinate: Coordinate) -> None: """Apply highlighting to the cell at the coordinate, and emit event.""" - self.refresh_cell(*coordinate) + self.refresh_coordinate(coordinate) try: cell_value = self.get_value_at(coordinate) except CellDoesNotExist: @@ -690,6 +692,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Returns: The key of the cell currently occupying this coordinate. """ + if not self.is_valid_coordinate(coordinate): + raise CellDoesNotExist() row_index, column_index = coordinate row_key = self._row_locations.get_key(row_index) column_key = self._column_locations.get_key(column_index) @@ -729,12 +733,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # Refresh cells that were previously impacted by the cursor # but may no longer be. - row_index, column_index = self.cursor_coordinate if old == "cell": - self.refresh_cell(row_index, column_index) + self.refresh_coordinate(self.cursor_coordinate) elif old == "row": + row_index, _ = self.cursor_coordinate self.refresh_row(row_index) elif old == "column": + _, column_index = self.cursor_coordinate self.refresh_column(column_index) self._scroll_cursor_into_view() @@ -790,14 +795,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._total_row_height + header_height, ) - def _get_cell_region(self, row_index: int, column_index: int) -> Region: + def _get_cell_region(self, coordinate: Coordinate) -> Region: """Get the region of the cell at the given spatial coordinate.""" - valid_row = 0 <= row_index < len(self.rows) - valid_column = 0 <= column_index < len(self.columns) - valid_cell = valid_row and valid_column - if not valid_cell: + if not self.is_valid_coordinate(coordinate): return Region(0, 0, 0, 0) + row_index, column_index = coordinate row_key = self._row_locations.get_key(row_index) row = self.rows[row_key] @@ -814,11 +817,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_row_region(self, row_index: int) -> Region: """Get the region of the row at the given index.""" - rows = self.rows - valid_row = 0 <= row_index < len(rows) - if not valid_row: + if not self.is_valid_row_index(row_index): return Region(0, 0, 0, 0) + rows = self.rows row_key = self._row_locations.get_key(row_index) row = rows[row_key] row_width = sum(column.render_width for column in self.columns.values()) @@ -830,11 +832,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_column_region(self, column_index: int) -> Region: """Get the region of the column at the given index.""" - columns = self.columns - valid_column = 0 <= column_index < len(columns) - if not valid_column: + if not self.is_valid_column_index(column_index): return Region(0, 0, 0, 0) + columns = self.columns x = sum(column.render_width for column in self.ordered_columns[:column_index]) column_key = self._column_locations.get_key(column_index) width = columns[column_key].render_width @@ -1001,16 +1002,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._updated_cells.clear() self._update_column_widths(updated_columns) - def refresh_cell(self, row_index: int, column_index: int) -> None: - """Refresh a cell. + def refresh_coordinate(self, coordinate: Coordinate) -> None: + """Refresh the cell at a coordinate. Args: - row_index: Row index. - column_index: Column index. + coordinate: The coordinate to refresh. """ - if row_index < 0 or column_index < 0: + if not self.is_valid_coordinate(coordinate): return - region = self._get_cell_region(row_index, column_index) + region = self._get_cell_region(coordinate) self._refresh_region(region) def refresh_row(self, row_index: int) -> None: @@ -1019,7 +1019,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Args: row_index: The index of the row to refresh. """ - if row_index < 0 or row_index >= len(self.rows): + if not self.is_valid_row_index(row_index): return region = self._get_row_region(row_index) @@ -1031,7 +1031,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Args: column_index: The index of the column to refresh. """ - if column_index < 0 or column_index >= len(self.columns): + if not self.is_valid_column_index(column_index): return region = self._get_column_region(column_index) @@ -1046,6 +1046,42 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): region = region.translate(-self.scroll_offset) self.refresh(region) + def is_valid_row_index(self, row_index: int) -> bool: + """Return a boolean indicating whether the row_index is within table bounds. + + Args: + row_index: The row index to check. + + Returns: + True if the row index is within the bounds of the table. + """ + return 0 <= row_index < len(self.rows) + + def is_valid_column_index(self, column_index: int) -> bool: + """Return a boolean indicating whether the column_index is within table bounds. + + Args: + column_index: The column index to check. + + Returns: + True if the column index is within the bounds of the table. + """ + return 0 <= column_index < len(self.columns) + + def is_valid_coordinate(self, coordinate: Coordinate) -> bool: + """Return a boolean indicating whether the given coordinate is within table bounds. + + Args: + coordinate: The coordinate to validate. + + Returns: + True if the coordinate is within the bounds of the table. + """ + row_index, column_index = coordinate + return self.is_valid_row_index(row_index) and self.is_valid_column_index( + column_index + ) + @property def ordered_columns(self) -> list[Column]: """The list of Columns in the DataTable, ordered as they currently appear on screen.""" @@ -1443,7 +1479,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): x, y, width, height = self._get_column_region(self.cursor_column) region = Region(x, int(self.scroll_y) + top, width, height - top) else: - region = self._get_cell_region(self.cursor_row, self.cursor_column) + region = self._get_cell_region(self.cursor_coordinate) self.scroll_to_region(region, animate=animate, spacing=fixed_offset) @@ -1463,7 +1499,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): elif cursor_type == "row": self.refresh_row(self.hover_row) elif cursor_type == "cell": - self.refresh_cell(*self.hover_coordinate) + self.refresh_coordinate(self.hover_coordinate) def on_click(self, event: events.Click) -> None: self._set_hover_cursor(True) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index b4f20f811..40d399f93 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -313,7 +313,24 @@ async def test_update_cell_cell_doesnt_exist(): table.update_cell("INVALID", "CELL", "Value") -# TODO: Test update coordinate +async def test_update_coordinate_coordinate_exists(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column_0, column_1 = table.add_columns("A", "B") + row_0, *_ = table.add_rows(ROWS) + table.update_coordinate(Coordinate(0, 1), "newvalue") + assert table.get_cell_value(row_0, column_1) == "newvalue" + + +async def test_update_coordinate_coordinate_doesnt_exist(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_columns("A", "B") + table.add_rows(ROWS) + with pytest.raises(CellDoesNotExist): + table.update_coordinate(Coordinate(999, 999), "newvalue") def test_key_equals_equivalent_string():