mirror of
https://github.com/Textualize/textual.git
synced 2025-10-17 02:38:12 +03:00
Extract common coordinate validation logic into method in DataTable
This commit is contained in:
@@ -86,8 +86,8 @@ class ColumnKey(StringKey):
|
||||
|
||||
|
||||
class CellKey(NamedTuple):
|
||||
row_key: RowKey
|
||||
column_key: ColumnKey
|
||||
row_key: RowKey | str
|
||||
column_key: ColumnKey | str
|
||||
|
||||
|
||||
def default_cell_formatter(obj: object) -> RenderableType | None:
|
||||
@@ -554,7 +554,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
update_width: Whether to resize the column width to accommodate
|
||||
for the new cell content.
|
||||
"""
|
||||
# TODO: Validate coordinate and raise exception
|
||||
if not self.is_valid_coordinate(coordinate):
|
||||
raise CellDoesNotExist()
|
||||
|
||||
row_key, column_key = self.coordinate_to_cell_key(coordinate)
|
||||
self.update_cell(row_key, column_key, value, update_width=update_width)
|
||||
|
||||
@@ -645,8 +647,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
self._clear_caches()
|
||||
|
||||
def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None:
|
||||
self.refresh_cell(*old)
|
||||
self.refresh_cell(*value)
|
||||
self.refresh_coordinate(old)
|
||||
self.refresh_coordinate(value)
|
||||
|
||||
def watch_cursor_coordinate(
|
||||
self, old_coordinate: Coordinate, new_coordinate: Coordinate
|
||||
@@ -655,7 +657,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
# Refresh the old and the new cell, and emit the appropriate
|
||||
# message to tell users of the newly highlighted row/cell/column.
|
||||
if self.cursor_type == "cell":
|
||||
self.refresh_cell(*old_coordinate)
|
||||
self.refresh_coordinate(old_coordinate)
|
||||
self._highlight_coordinate(new_coordinate)
|
||||
elif self.cursor_type == "row":
|
||||
self.refresh_row(old_coordinate.row)
|
||||
@@ -666,7 +668,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
|
||||
def _highlight_coordinate(self, coordinate: Coordinate) -> None:
|
||||
"""Apply highlighting to the cell at the coordinate, and emit event."""
|
||||
self.refresh_cell(*coordinate)
|
||||
self.refresh_coordinate(coordinate)
|
||||
try:
|
||||
cell_value = self.get_value_at(coordinate)
|
||||
except CellDoesNotExist:
|
||||
@@ -690,6 +692,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
Returns:
|
||||
The key of the cell currently occupying this coordinate.
|
||||
"""
|
||||
if not self.is_valid_coordinate(coordinate):
|
||||
raise CellDoesNotExist()
|
||||
row_index, column_index = coordinate
|
||||
row_key = self._row_locations.get_key(row_index)
|
||||
column_key = self._column_locations.get_key(column_index)
|
||||
@@ -729,12 +733,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
|
||||
# Refresh cells that were previously impacted by the cursor
|
||||
# but may no longer be.
|
||||
row_index, column_index = self.cursor_coordinate
|
||||
if old == "cell":
|
||||
self.refresh_cell(row_index, column_index)
|
||||
self.refresh_coordinate(self.cursor_coordinate)
|
||||
elif old == "row":
|
||||
row_index, _ = self.cursor_coordinate
|
||||
self.refresh_row(row_index)
|
||||
elif old == "column":
|
||||
_, column_index = self.cursor_coordinate
|
||||
self.refresh_column(column_index)
|
||||
|
||||
self._scroll_cursor_into_view()
|
||||
@@ -790,14 +795,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
self._total_row_height + header_height,
|
||||
)
|
||||
|
||||
def _get_cell_region(self, row_index: int, column_index: int) -> Region:
|
||||
def _get_cell_region(self, coordinate: Coordinate) -> Region:
|
||||
"""Get the region of the cell at the given spatial coordinate."""
|
||||
valid_row = 0 <= row_index < len(self.rows)
|
||||
valid_column = 0 <= column_index < len(self.columns)
|
||||
valid_cell = valid_row and valid_column
|
||||
if not valid_cell:
|
||||
if not self.is_valid_coordinate(coordinate):
|
||||
return Region(0, 0, 0, 0)
|
||||
|
||||
row_index, column_index = coordinate
|
||||
row_key = self._row_locations.get_key(row_index)
|
||||
row = self.rows[row_key]
|
||||
|
||||
@@ -814,11 +817,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
|
||||
def _get_row_region(self, row_index: int) -> Region:
|
||||
"""Get the region of the row at the given index."""
|
||||
rows = self.rows
|
||||
valid_row = 0 <= row_index < len(rows)
|
||||
if not valid_row:
|
||||
if not self.is_valid_row_index(row_index):
|
||||
return Region(0, 0, 0, 0)
|
||||
|
||||
rows = self.rows
|
||||
row_key = self._row_locations.get_key(row_index)
|
||||
row = rows[row_key]
|
||||
row_width = sum(column.render_width for column in self.columns.values())
|
||||
@@ -830,11 +832,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
|
||||
def _get_column_region(self, column_index: int) -> Region:
|
||||
"""Get the region of the column at the given index."""
|
||||
columns = self.columns
|
||||
valid_column = 0 <= column_index < len(columns)
|
||||
if not valid_column:
|
||||
if not self.is_valid_column_index(column_index):
|
||||
return Region(0, 0, 0, 0)
|
||||
|
||||
columns = self.columns
|
||||
x = sum(column.render_width for column in self.ordered_columns[:column_index])
|
||||
column_key = self._column_locations.get_key(column_index)
|
||||
width = columns[column_key].render_width
|
||||
@@ -1001,16 +1002,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
self._updated_cells.clear()
|
||||
self._update_column_widths(updated_columns)
|
||||
|
||||
def refresh_cell(self, row_index: int, column_index: int) -> None:
|
||||
"""Refresh a cell.
|
||||
def refresh_coordinate(self, coordinate: Coordinate) -> None:
|
||||
"""Refresh the cell at a coordinate.
|
||||
|
||||
Args:
|
||||
row_index: Row index.
|
||||
column_index: Column index.
|
||||
coordinate: The coordinate to refresh.
|
||||
"""
|
||||
if row_index < 0 or column_index < 0:
|
||||
if not self.is_valid_coordinate(coordinate):
|
||||
return
|
||||
region = self._get_cell_region(row_index, column_index)
|
||||
region = self._get_cell_region(coordinate)
|
||||
self._refresh_region(region)
|
||||
|
||||
def refresh_row(self, row_index: int) -> None:
|
||||
@@ -1019,7 +1019,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
Args:
|
||||
row_index: The index of the row to refresh.
|
||||
"""
|
||||
if row_index < 0 or row_index >= len(self.rows):
|
||||
if not self.is_valid_row_index(row_index):
|
||||
return
|
||||
|
||||
region = self._get_row_region(row_index)
|
||||
@@ -1031,7 +1031,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
Args:
|
||||
column_index: The index of the column to refresh.
|
||||
"""
|
||||
if column_index < 0 or column_index >= len(self.columns):
|
||||
if not self.is_valid_column_index(column_index):
|
||||
return
|
||||
|
||||
region = self._get_column_region(column_index)
|
||||
@@ -1046,6 +1046,42 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
region = region.translate(-self.scroll_offset)
|
||||
self.refresh(region)
|
||||
|
||||
def is_valid_row_index(self, row_index: int) -> bool:
|
||||
"""Return a boolean indicating whether the row_index is within table bounds.
|
||||
|
||||
Args:
|
||||
row_index: The row index to check.
|
||||
|
||||
Returns:
|
||||
True if the row index is within the bounds of the table.
|
||||
"""
|
||||
return 0 <= row_index < len(self.rows)
|
||||
|
||||
def is_valid_column_index(self, column_index: int) -> bool:
|
||||
"""Return a boolean indicating whether the column_index is within table bounds.
|
||||
|
||||
Args:
|
||||
column_index: The column index to check.
|
||||
|
||||
Returns:
|
||||
True if the column index is within the bounds of the table.
|
||||
"""
|
||||
return 0 <= column_index < len(self.columns)
|
||||
|
||||
def is_valid_coordinate(self, coordinate: Coordinate) -> bool:
|
||||
"""Return a boolean indicating whether the given coordinate is within table bounds.
|
||||
|
||||
Args:
|
||||
coordinate: The coordinate to validate.
|
||||
|
||||
Returns:
|
||||
True if the coordinate is within the bounds of the table.
|
||||
"""
|
||||
row_index, column_index = coordinate
|
||||
return self.is_valid_row_index(row_index) and self.is_valid_column_index(
|
||||
column_index
|
||||
)
|
||||
|
||||
@property
|
||||
def ordered_columns(self) -> list[Column]:
|
||||
"""The list of Columns in the DataTable, ordered as they currently appear on screen."""
|
||||
@@ -1443,7 +1479,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
x, y, width, height = self._get_column_region(self.cursor_column)
|
||||
region = Region(x, int(self.scroll_y) + top, width, height - top)
|
||||
else:
|
||||
region = self._get_cell_region(self.cursor_row, self.cursor_column)
|
||||
region = self._get_cell_region(self.cursor_coordinate)
|
||||
|
||||
self.scroll_to_region(region, animate=animate, spacing=fixed_offset)
|
||||
|
||||
@@ -1463,7 +1499,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
||||
elif cursor_type == "row":
|
||||
self.refresh_row(self.hover_row)
|
||||
elif cursor_type == "cell":
|
||||
self.refresh_cell(*self.hover_coordinate)
|
||||
self.refresh_coordinate(self.hover_coordinate)
|
||||
|
||||
def on_click(self, event: events.Click) -> None:
|
||||
self._set_hover_cursor(True)
|
||||
|
||||
@@ -313,7 +313,24 @@ async def test_update_cell_cell_doesnt_exist():
|
||||
table.update_cell("INVALID", "CELL", "Value")
|
||||
|
||||
|
||||
# TODO: Test update coordinate
|
||||
async def test_update_coordinate_coordinate_exists():
|
||||
app = DataTableApp()
|
||||
async with app.run_test():
|
||||
table = app.query_one(DataTable)
|
||||
column_0, column_1 = table.add_columns("A", "B")
|
||||
row_0, *_ = table.add_rows(ROWS)
|
||||
table.update_coordinate(Coordinate(0, 1), "newvalue")
|
||||
assert table.get_cell_value(row_0, column_1) == "newvalue"
|
||||
|
||||
|
||||
async def test_update_coordinate_coordinate_doesnt_exist():
|
||||
app = DataTableApp()
|
||||
async with app.run_test():
|
||||
table = app.query_one(DataTable)
|
||||
table.add_columns("A", "B")
|
||||
table.add_rows(ROWS)
|
||||
with pytest.raises(CellDoesNotExist):
|
||||
table.update_coordinate(Coordinate(999, 999), "newvalue")
|
||||
|
||||
|
||||
def test_key_equals_equivalent_string():
|
||||
|
||||
Reference in New Issue
Block a user