Extract common coordinate validation logic into method in DataTable

This commit is contained in:
Darren Burns
2023-02-02 13:09:11 +00:00
parent 77b94b005c
commit 990a6311bc
2 changed files with 84 additions and 31 deletions

View File

@@ -86,8 +86,8 @@ class ColumnKey(StringKey):
class CellKey(NamedTuple): class CellKey(NamedTuple):
row_key: RowKey row_key: RowKey | str
column_key: ColumnKey column_key: ColumnKey | str
def default_cell_formatter(obj: object) -> RenderableType | None: def default_cell_formatter(obj: object) -> RenderableType | None:
@@ -554,7 +554,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
update_width: Whether to resize the column width to accommodate update_width: Whether to resize the column width to accommodate
for the new cell content. for the new cell content.
""" """
# TODO: Validate coordinate and raise exception if not self.is_valid_coordinate(coordinate):
raise CellDoesNotExist()
row_key, column_key = self.coordinate_to_cell_key(coordinate) row_key, column_key = self.coordinate_to_cell_key(coordinate)
self.update_cell(row_key, column_key, value, update_width=update_width) self.update_cell(row_key, column_key, value, update_width=update_width)
@@ -645,8 +647,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._clear_caches() self._clear_caches()
def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None: def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None:
self.refresh_cell(*old) self.refresh_coordinate(old)
self.refresh_cell(*value) self.refresh_coordinate(value)
def watch_cursor_coordinate( def watch_cursor_coordinate(
self, old_coordinate: Coordinate, new_coordinate: Coordinate self, old_coordinate: Coordinate, new_coordinate: Coordinate
@@ -655,7 +657,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
# Refresh the old and the new cell, and emit the appropriate # Refresh the old and the new cell, and emit the appropriate
# message to tell users of the newly highlighted row/cell/column. # message to tell users of the newly highlighted row/cell/column.
if self.cursor_type == "cell": if self.cursor_type == "cell":
self.refresh_cell(*old_coordinate) self.refresh_coordinate(old_coordinate)
self._highlight_coordinate(new_coordinate) self._highlight_coordinate(new_coordinate)
elif self.cursor_type == "row": elif self.cursor_type == "row":
self.refresh_row(old_coordinate.row) self.refresh_row(old_coordinate.row)
@@ -666,7 +668,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _highlight_coordinate(self, coordinate: Coordinate) -> None: def _highlight_coordinate(self, coordinate: Coordinate) -> None:
"""Apply highlighting to the cell at the coordinate, and emit event.""" """Apply highlighting to the cell at the coordinate, and emit event."""
self.refresh_cell(*coordinate) self.refresh_coordinate(coordinate)
try: try:
cell_value = self.get_value_at(coordinate) cell_value = self.get_value_at(coordinate)
except CellDoesNotExist: except CellDoesNotExist:
@@ -690,6 +692,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
Returns: Returns:
The key of the cell currently occupying this coordinate. The key of the cell currently occupying this coordinate.
""" """
if not self.is_valid_coordinate(coordinate):
raise CellDoesNotExist()
row_index, column_index = coordinate row_index, column_index = coordinate
row_key = self._row_locations.get_key(row_index) row_key = self._row_locations.get_key(row_index)
column_key = self._column_locations.get_key(column_index) column_key = self._column_locations.get_key(column_index)
@@ -729,12 +733,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
# Refresh cells that were previously impacted by the cursor # Refresh cells that were previously impacted by the cursor
# but may no longer be. # but may no longer be.
row_index, column_index = self.cursor_coordinate
if old == "cell": if old == "cell":
self.refresh_cell(row_index, column_index) self.refresh_coordinate(self.cursor_coordinate)
elif old == "row": elif old == "row":
row_index, _ = self.cursor_coordinate
self.refresh_row(row_index) self.refresh_row(row_index)
elif old == "column": elif old == "column":
_, column_index = self.cursor_coordinate
self.refresh_column(column_index) self.refresh_column(column_index)
self._scroll_cursor_into_view() self._scroll_cursor_into_view()
@@ -790,14 +795,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._total_row_height + header_height, self._total_row_height + header_height,
) )
def _get_cell_region(self, row_index: int, column_index: int) -> Region: def _get_cell_region(self, coordinate: Coordinate) -> Region:
"""Get the region of the cell at the given spatial coordinate.""" """Get the region of the cell at the given spatial coordinate."""
valid_row = 0 <= row_index < len(self.rows) if not self.is_valid_coordinate(coordinate):
valid_column = 0 <= column_index < len(self.columns)
valid_cell = valid_row and valid_column
if not valid_cell:
return Region(0, 0, 0, 0) return Region(0, 0, 0, 0)
row_index, column_index = coordinate
row_key = self._row_locations.get_key(row_index) row_key = self._row_locations.get_key(row_index)
row = self.rows[row_key] row = self.rows[row_key]
@@ -814,11 +817,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _get_row_region(self, row_index: int) -> Region: def _get_row_region(self, row_index: int) -> Region:
"""Get the region of the row at the given index.""" """Get the region of the row at the given index."""
rows = self.rows if not self.is_valid_row_index(row_index):
valid_row = 0 <= row_index < len(rows)
if not valid_row:
return Region(0, 0, 0, 0) return Region(0, 0, 0, 0)
rows = self.rows
row_key = self._row_locations.get_key(row_index) row_key = self._row_locations.get_key(row_index)
row = rows[row_key] row = rows[row_key]
row_width = sum(column.render_width for column in self.columns.values()) row_width = sum(column.render_width for column in self.columns.values())
@@ -830,11 +832,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _get_column_region(self, column_index: int) -> Region: def _get_column_region(self, column_index: int) -> Region:
"""Get the region of the column at the given index.""" """Get the region of the column at the given index."""
columns = self.columns if not self.is_valid_column_index(column_index):
valid_column = 0 <= column_index < len(columns)
if not valid_column:
return Region(0, 0, 0, 0) return Region(0, 0, 0, 0)
columns = self.columns
x = sum(column.render_width for column in self.ordered_columns[:column_index]) x = sum(column.render_width for column in self.ordered_columns[:column_index])
column_key = self._column_locations.get_key(column_index) column_key = self._column_locations.get_key(column_index)
width = columns[column_key].render_width width = columns[column_key].render_width
@@ -1001,16 +1002,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._updated_cells.clear() self._updated_cells.clear()
self._update_column_widths(updated_columns) self._update_column_widths(updated_columns)
def refresh_cell(self, row_index: int, column_index: int) -> None: def refresh_coordinate(self, coordinate: Coordinate) -> None:
"""Refresh a cell. """Refresh the cell at a coordinate.
Args: Args:
row_index: Row index. coordinate: The coordinate to refresh.
column_index: Column index.
""" """
if row_index < 0 or column_index < 0: if not self.is_valid_coordinate(coordinate):
return return
region = self._get_cell_region(row_index, column_index) region = self._get_cell_region(coordinate)
self._refresh_region(region) self._refresh_region(region)
def refresh_row(self, row_index: int) -> None: def refresh_row(self, row_index: int) -> None:
@@ -1019,7 +1019,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
Args: Args:
row_index: The index of the row to refresh. row_index: The index of the row to refresh.
""" """
if row_index < 0 or row_index >= len(self.rows): if not self.is_valid_row_index(row_index):
return return
region = self._get_row_region(row_index) region = self._get_row_region(row_index)
@@ -1031,7 +1031,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
Args: Args:
column_index: The index of the column to refresh. column_index: The index of the column to refresh.
""" """
if column_index < 0 or column_index >= len(self.columns): if not self.is_valid_column_index(column_index):
return return
region = self._get_column_region(column_index) region = self._get_column_region(column_index)
@@ -1046,6 +1046,42 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
region = region.translate(-self.scroll_offset) region = region.translate(-self.scroll_offset)
self.refresh(region) self.refresh(region)
def is_valid_row_index(self, row_index: int) -> bool:
"""Return a boolean indicating whether the row_index is within table bounds.
Args:
row_index: The row index to check.
Returns:
True if the row index is within the bounds of the table.
"""
return 0 <= row_index < len(self.rows)
def is_valid_column_index(self, column_index: int) -> bool:
"""Return a boolean indicating whether the column_index is within table bounds.
Args:
column_index: The column index to check.
Returns:
True if the column index is within the bounds of the table.
"""
return 0 <= column_index < len(self.columns)
def is_valid_coordinate(self, coordinate: Coordinate) -> bool:
"""Return a boolean indicating whether the given coordinate is within table bounds.
Args:
coordinate: The coordinate to validate.
Returns:
True if the coordinate is within the bounds of the table.
"""
row_index, column_index = coordinate
return self.is_valid_row_index(row_index) and self.is_valid_column_index(
column_index
)
@property @property
def ordered_columns(self) -> list[Column]: def ordered_columns(self) -> list[Column]:
"""The list of Columns in the DataTable, ordered as they currently appear on screen.""" """The list of Columns in the DataTable, ordered as they currently appear on screen."""
@@ -1443,7 +1479,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
x, y, width, height = self._get_column_region(self.cursor_column) x, y, width, height = self._get_column_region(self.cursor_column)
region = Region(x, int(self.scroll_y) + top, width, height - top) region = Region(x, int(self.scroll_y) + top, width, height - top)
else: else:
region = self._get_cell_region(self.cursor_row, self.cursor_column) region = self._get_cell_region(self.cursor_coordinate)
self.scroll_to_region(region, animate=animate, spacing=fixed_offset) self.scroll_to_region(region, animate=animate, spacing=fixed_offset)
@@ -1463,7 +1499,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
elif cursor_type == "row": elif cursor_type == "row":
self.refresh_row(self.hover_row) self.refresh_row(self.hover_row)
elif cursor_type == "cell": elif cursor_type == "cell":
self.refresh_cell(*self.hover_coordinate) self.refresh_coordinate(self.hover_coordinate)
def on_click(self, event: events.Click) -> None: def on_click(self, event: events.Click) -> None:
self._set_hover_cursor(True) self._set_hover_cursor(True)

View File

@@ -313,7 +313,24 @@ async def test_update_cell_cell_doesnt_exist():
table.update_cell("INVALID", "CELL", "Value") table.update_cell("INVALID", "CELL", "Value")
# TODO: Test update coordinate async def test_update_coordinate_coordinate_exists():
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
column_0, column_1 = table.add_columns("A", "B")
row_0, *_ = table.add_rows(ROWS)
table.update_coordinate(Coordinate(0, 1), "newvalue")
assert table.get_cell_value(row_0, column_1) == "newvalue"
async def test_update_coordinate_coordinate_doesnt_exist():
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
table.add_columns("A", "B")
table.add_rows(ROWS)
with pytest.raises(CellDoesNotExist):
table.update_coordinate(Coordinate(999, 999), "newvalue")
def test_key_equals_equivalent_string(): def test_key_equals_equivalent_string():