Extract common coordinate validation logic into method in DataTable

This commit is contained in:
Darren Burns
2023-02-02 13:09:11 +00:00
parent 77b94b005c
commit 990a6311bc
2 changed files with 84 additions and 31 deletions

View File

@@ -86,8 +86,8 @@ class ColumnKey(StringKey):
class CellKey(NamedTuple):
row_key: RowKey
column_key: ColumnKey
row_key: RowKey | str
column_key: ColumnKey | str
def default_cell_formatter(obj: object) -> RenderableType | None:
@@ -554,7 +554,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
update_width: Whether to resize the column width to accommodate
for the new cell content.
"""
# TODO: Validate coordinate and raise exception
if not self.is_valid_coordinate(coordinate):
raise CellDoesNotExist()
row_key, column_key = self.coordinate_to_cell_key(coordinate)
self.update_cell(row_key, column_key, value, update_width=update_width)
@@ -645,8 +647,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._clear_caches()
def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None:
self.refresh_cell(*old)
self.refresh_cell(*value)
self.refresh_coordinate(old)
self.refresh_coordinate(value)
def watch_cursor_coordinate(
self, old_coordinate: Coordinate, new_coordinate: Coordinate
@@ -655,7 +657,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
# Refresh the old and the new cell, and emit the appropriate
# message to tell users of the newly highlighted row/cell/column.
if self.cursor_type == "cell":
self.refresh_cell(*old_coordinate)
self.refresh_coordinate(old_coordinate)
self._highlight_coordinate(new_coordinate)
elif self.cursor_type == "row":
self.refresh_row(old_coordinate.row)
@@ -666,7 +668,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _highlight_coordinate(self, coordinate: Coordinate) -> None:
"""Apply highlighting to the cell at the coordinate, and emit event."""
self.refresh_cell(*coordinate)
self.refresh_coordinate(coordinate)
try:
cell_value = self.get_value_at(coordinate)
except CellDoesNotExist:
@@ -690,6 +692,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
Returns:
The key of the cell currently occupying this coordinate.
"""
if not self.is_valid_coordinate(coordinate):
raise CellDoesNotExist()
row_index, column_index = coordinate
row_key = self._row_locations.get_key(row_index)
column_key = self._column_locations.get_key(column_index)
@@ -729,12 +733,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
# Refresh cells that were previously impacted by the cursor
# but may no longer be.
row_index, column_index = self.cursor_coordinate
if old == "cell":
self.refresh_cell(row_index, column_index)
self.refresh_coordinate(self.cursor_coordinate)
elif old == "row":
row_index, _ = self.cursor_coordinate
self.refresh_row(row_index)
elif old == "column":
_, column_index = self.cursor_coordinate
self.refresh_column(column_index)
self._scroll_cursor_into_view()
@@ -790,14 +795,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._total_row_height + header_height,
)
def _get_cell_region(self, row_index: int, column_index: int) -> Region:
def _get_cell_region(self, coordinate: Coordinate) -> Region:
"""Get the region of the cell at the given spatial coordinate."""
valid_row = 0 <= row_index < len(self.rows)
valid_column = 0 <= column_index < len(self.columns)
valid_cell = valid_row and valid_column
if not valid_cell:
if not self.is_valid_coordinate(coordinate):
return Region(0, 0, 0, 0)
row_index, column_index = coordinate
row_key = self._row_locations.get_key(row_index)
row = self.rows[row_key]
@@ -814,11 +817,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _get_row_region(self, row_index: int) -> Region:
"""Get the region of the row at the given index."""
rows = self.rows
valid_row = 0 <= row_index < len(rows)
if not valid_row:
if not self.is_valid_row_index(row_index):
return Region(0, 0, 0, 0)
rows = self.rows
row_key = self._row_locations.get_key(row_index)
row = rows[row_key]
row_width = sum(column.render_width for column in self.columns.values())
@@ -830,11 +832,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _get_column_region(self, column_index: int) -> Region:
"""Get the region of the column at the given index."""
columns = self.columns
valid_column = 0 <= column_index < len(columns)
if not valid_column:
if not self.is_valid_column_index(column_index):
return Region(0, 0, 0, 0)
columns = self.columns
x = sum(column.render_width for column in self.ordered_columns[:column_index])
column_key = self._column_locations.get_key(column_index)
width = columns[column_key].render_width
@@ -1001,16 +1002,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._updated_cells.clear()
self._update_column_widths(updated_columns)
def refresh_cell(self, row_index: int, column_index: int) -> None:
"""Refresh a cell.
def refresh_coordinate(self, coordinate: Coordinate) -> None:
"""Refresh the cell at a coordinate.
Args:
row_index: Row index.
column_index: Column index.
coordinate: The coordinate to refresh.
"""
if row_index < 0 or column_index < 0:
if not self.is_valid_coordinate(coordinate):
return
region = self._get_cell_region(row_index, column_index)
region = self._get_cell_region(coordinate)
self._refresh_region(region)
def refresh_row(self, row_index: int) -> None:
@@ -1019,7 +1019,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
Args:
row_index: The index of the row to refresh.
"""
if row_index < 0 or row_index >= len(self.rows):
if not self.is_valid_row_index(row_index):
return
region = self._get_row_region(row_index)
@@ -1031,7 +1031,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
Args:
column_index: The index of the column to refresh.
"""
if column_index < 0 or column_index >= len(self.columns):
if not self.is_valid_column_index(column_index):
return
region = self._get_column_region(column_index)
@@ -1046,6 +1046,42 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
region = region.translate(-self.scroll_offset)
self.refresh(region)
def is_valid_row_index(self, row_index: int) -> bool:
"""Return a boolean indicating whether the row_index is within table bounds.
Args:
row_index: The row index to check.
Returns:
True if the row index is within the bounds of the table.
"""
return 0 <= row_index < len(self.rows)
def is_valid_column_index(self, column_index: int) -> bool:
"""Return a boolean indicating whether the column_index is within table bounds.
Args:
column_index: The column index to check.
Returns:
True if the column index is within the bounds of the table.
"""
return 0 <= column_index < len(self.columns)
def is_valid_coordinate(self, coordinate: Coordinate) -> bool:
"""Return a boolean indicating whether the given coordinate is within table bounds.
Args:
coordinate: The coordinate to validate.
Returns:
True if the coordinate is within the bounds of the table.
"""
row_index, column_index = coordinate
return self.is_valid_row_index(row_index) and self.is_valid_column_index(
column_index
)
@property
def ordered_columns(self) -> list[Column]:
"""The list of Columns in the DataTable, ordered as they currently appear on screen."""
@@ -1443,7 +1479,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
x, y, width, height = self._get_column_region(self.cursor_column)
region = Region(x, int(self.scroll_y) + top, width, height - top)
else:
region = self._get_cell_region(self.cursor_row, self.cursor_column)
region = self._get_cell_region(self.cursor_coordinate)
self.scroll_to_region(region, animate=animate, spacing=fixed_offset)
@@ -1463,7 +1499,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
elif cursor_type == "row":
self.refresh_row(self.hover_row)
elif cursor_type == "cell":
self.refresh_cell(*self.hover_coordinate)
self.refresh_coordinate(self.hover_coordinate)
def on_click(self, event: events.Click) -> None:
self._set_hover_cursor(True)

View File

@@ -313,7 +313,24 @@ async def test_update_cell_cell_doesnt_exist():
table.update_cell("INVALID", "CELL", "Value")
# TODO: Test update coordinate
async def test_update_coordinate_coordinate_exists():
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
column_0, column_1 = table.add_columns("A", "B")
row_0, *_ = table.add_rows(ROWS)
table.update_coordinate(Coordinate(0, 1), "newvalue")
assert table.get_cell_value(row_0, column_1) == "newvalue"
async def test_update_coordinate_coordinate_doesnt_exist():
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
table.add_columns("A", "B")
table.add_rows(ROWS)
with pytest.raises(CellDoesNotExist):
table.update_coordinate(Coordinate(999, 999), "newvalue")
def test_key_equals_equivalent_string():