mirror of
https://github.com/Textualize/textual.git
synced 2025-10-17 02:38:12 +03:00
Extract common coordinate validation logic into method in DataTable
This commit is contained in:
@@ -86,8 +86,8 @@ class ColumnKey(StringKey):
|
|||||||
|
|
||||||
|
|
||||||
class CellKey(NamedTuple):
|
class CellKey(NamedTuple):
|
||||||
row_key: RowKey
|
row_key: RowKey | str
|
||||||
column_key: ColumnKey
|
column_key: ColumnKey | str
|
||||||
|
|
||||||
|
|
||||||
def default_cell_formatter(obj: object) -> RenderableType | None:
|
def default_cell_formatter(obj: object) -> RenderableType | None:
|
||||||
@@ -554,7 +554,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
update_width: Whether to resize the column width to accommodate
|
update_width: Whether to resize the column width to accommodate
|
||||||
for the new cell content.
|
for the new cell content.
|
||||||
"""
|
"""
|
||||||
# TODO: Validate coordinate and raise exception
|
if not self.is_valid_coordinate(coordinate):
|
||||||
|
raise CellDoesNotExist()
|
||||||
|
|
||||||
row_key, column_key = self.coordinate_to_cell_key(coordinate)
|
row_key, column_key = self.coordinate_to_cell_key(coordinate)
|
||||||
self.update_cell(row_key, column_key, value, update_width=update_width)
|
self.update_cell(row_key, column_key, value, update_width=update_width)
|
||||||
|
|
||||||
@@ -645,8 +647,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
self._clear_caches()
|
self._clear_caches()
|
||||||
|
|
||||||
def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None:
|
def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None:
|
||||||
self.refresh_cell(*old)
|
self.refresh_coordinate(old)
|
||||||
self.refresh_cell(*value)
|
self.refresh_coordinate(value)
|
||||||
|
|
||||||
def watch_cursor_coordinate(
|
def watch_cursor_coordinate(
|
||||||
self, old_coordinate: Coordinate, new_coordinate: Coordinate
|
self, old_coordinate: Coordinate, new_coordinate: Coordinate
|
||||||
@@ -655,7 +657,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
# Refresh the old and the new cell, and emit the appropriate
|
# Refresh the old and the new cell, and emit the appropriate
|
||||||
# message to tell users of the newly highlighted row/cell/column.
|
# message to tell users of the newly highlighted row/cell/column.
|
||||||
if self.cursor_type == "cell":
|
if self.cursor_type == "cell":
|
||||||
self.refresh_cell(*old_coordinate)
|
self.refresh_coordinate(old_coordinate)
|
||||||
self._highlight_coordinate(new_coordinate)
|
self._highlight_coordinate(new_coordinate)
|
||||||
elif self.cursor_type == "row":
|
elif self.cursor_type == "row":
|
||||||
self.refresh_row(old_coordinate.row)
|
self.refresh_row(old_coordinate.row)
|
||||||
@@ -666,7 +668,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
|
|
||||||
def _highlight_coordinate(self, coordinate: Coordinate) -> None:
|
def _highlight_coordinate(self, coordinate: Coordinate) -> None:
|
||||||
"""Apply highlighting to the cell at the coordinate, and emit event."""
|
"""Apply highlighting to the cell at the coordinate, and emit event."""
|
||||||
self.refresh_cell(*coordinate)
|
self.refresh_coordinate(coordinate)
|
||||||
try:
|
try:
|
||||||
cell_value = self.get_value_at(coordinate)
|
cell_value = self.get_value_at(coordinate)
|
||||||
except CellDoesNotExist:
|
except CellDoesNotExist:
|
||||||
@@ -690,6 +692,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
Returns:
|
Returns:
|
||||||
The key of the cell currently occupying this coordinate.
|
The key of the cell currently occupying this coordinate.
|
||||||
"""
|
"""
|
||||||
|
if not self.is_valid_coordinate(coordinate):
|
||||||
|
raise CellDoesNotExist()
|
||||||
row_index, column_index = coordinate
|
row_index, column_index = coordinate
|
||||||
row_key = self._row_locations.get_key(row_index)
|
row_key = self._row_locations.get_key(row_index)
|
||||||
column_key = self._column_locations.get_key(column_index)
|
column_key = self._column_locations.get_key(column_index)
|
||||||
@@ -729,12 +733,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
|
|
||||||
# Refresh cells that were previously impacted by the cursor
|
# Refresh cells that were previously impacted by the cursor
|
||||||
# but may no longer be.
|
# but may no longer be.
|
||||||
row_index, column_index = self.cursor_coordinate
|
|
||||||
if old == "cell":
|
if old == "cell":
|
||||||
self.refresh_cell(row_index, column_index)
|
self.refresh_coordinate(self.cursor_coordinate)
|
||||||
elif old == "row":
|
elif old == "row":
|
||||||
|
row_index, _ = self.cursor_coordinate
|
||||||
self.refresh_row(row_index)
|
self.refresh_row(row_index)
|
||||||
elif old == "column":
|
elif old == "column":
|
||||||
|
_, column_index = self.cursor_coordinate
|
||||||
self.refresh_column(column_index)
|
self.refresh_column(column_index)
|
||||||
|
|
||||||
self._scroll_cursor_into_view()
|
self._scroll_cursor_into_view()
|
||||||
@@ -790,14 +795,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
self._total_row_height + header_height,
|
self._total_row_height + header_height,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_cell_region(self, row_index: int, column_index: int) -> Region:
|
def _get_cell_region(self, coordinate: Coordinate) -> Region:
|
||||||
"""Get the region of the cell at the given spatial coordinate."""
|
"""Get the region of the cell at the given spatial coordinate."""
|
||||||
valid_row = 0 <= row_index < len(self.rows)
|
if not self.is_valid_coordinate(coordinate):
|
||||||
valid_column = 0 <= column_index < len(self.columns)
|
|
||||||
valid_cell = valid_row and valid_column
|
|
||||||
if not valid_cell:
|
|
||||||
return Region(0, 0, 0, 0)
|
return Region(0, 0, 0, 0)
|
||||||
|
|
||||||
|
row_index, column_index = coordinate
|
||||||
row_key = self._row_locations.get_key(row_index)
|
row_key = self._row_locations.get_key(row_index)
|
||||||
row = self.rows[row_key]
|
row = self.rows[row_key]
|
||||||
|
|
||||||
@@ -814,11 +817,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
|
|
||||||
def _get_row_region(self, row_index: int) -> Region:
|
def _get_row_region(self, row_index: int) -> Region:
|
||||||
"""Get the region of the row at the given index."""
|
"""Get the region of the row at the given index."""
|
||||||
rows = self.rows
|
if not self.is_valid_row_index(row_index):
|
||||||
valid_row = 0 <= row_index < len(rows)
|
|
||||||
if not valid_row:
|
|
||||||
return Region(0, 0, 0, 0)
|
return Region(0, 0, 0, 0)
|
||||||
|
|
||||||
|
rows = self.rows
|
||||||
row_key = self._row_locations.get_key(row_index)
|
row_key = self._row_locations.get_key(row_index)
|
||||||
row = rows[row_key]
|
row = rows[row_key]
|
||||||
row_width = sum(column.render_width for column in self.columns.values())
|
row_width = sum(column.render_width for column in self.columns.values())
|
||||||
@@ -830,11 +832,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
|
|
||||||
def _get_column_region(self, column_index: int) -> Region:
|
def _get_column_region(self, column_index: int) -> Region:
|
||||||
"""Get the region of the column at the given index."""
|
"""Get the region of the column at the given index."""
|
||||||
columns = self.columns
|
if not self.is_valid_column_index(column_index):
|
||||||
valid_column = 0 <= column_index < len(columns)
|
|
||||||
if not valid_column:
|
|
||||||
return Region(0, 0, 0, 0)
|
return Region(0, 0, 0, 0)
|
||||||
|
|
||||||
|
columns = self.columns
|
||||||
x = sum(column.render_width for column in self.ordered_columns[:column_index])
|
x = sum(column.render_width for column in self.ordered_columns[:column_index])
|
||||||
column_key = self._column_locations.get_key(column_index)
|
column_key = self._column_locations.get_key(column_index)
|
||||||
width = columns[column_key].render_width
|
width = columns[column_key].render_width
|
||||||
@@ -1001,16 +1002,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
self._updated_cells.clear()
|
self._updated_cells.clear()
|
||||||
self._update_column_widths(updated_columns)
|
self._update_column_widths(updated_columns)
|
||||||
|
|
||||||
def refresh_cell(self, row_index: int, column_index: int) -> None:
|
def refresh_coordinate(self, coordinate: Coordinate) -> None:
|
||||||
"""Refresh a cell.
|
"""Refresh the cell at a coordinate.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
row_index: Row index.
|
coordinate: The coordinate to refresh.
|
||||||
column_index: Column index.
|
|
||||||
"""
|
"""
|
||||||
if row_index < 0 or column_index < 0:
|
if not self.is_valid_coordinate(coordinate):
|
||||||
return
|
return
|
||||||
region = self._get_cell_region(row_index, column_index)
|
region = self._get_cell_region(coordinate)
|
||||||
self._refresh_region(region)
|
self._refresh_region(region)
|
||||||
|
|
||||||
def refresh_row(self, row_index: int) -> None:
|
def refresh_row(self, row_index: int) -> None:
|
||||||
@@ -1019,7 +1019,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
Args:
|
Args:
|
||||||
row_index: The index of the row to refresh.
|
row_index: The index of the row to refresh.
|
||||||
"""
|
"""
|
||||||
if row_index < 0 or row_index >= len(self.rows):
|
if not self.is_valid_row_index(row_index):
|
||||||
return
|
return
|
||||||
|
|
||||||
region = self._get_row_region(row_index)
|
region = self._get_row_region(row_index)
|
||||||
@@ -1031,7 +1031,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
Args:
|
Args:
|
||||||
column_index: The index of the column to refresh.
|
column_index: The index of the column to refresh.
|
||||||
"""
|
"""
|
||||||
if column_index < 0 or column_index >= len(self.columns):
|
if not self.is_valid_column_index(column_index):
|
||||||
return
|
return
|
||||||
|
|
||||||
region = self._get_column_region(column_index)
|
region = self._get_column_region(column_index)
|
||||||
@@ -1046,6 +1046,42 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
region = region.translate(-self.scroll_offset)
|
region = region.translate(-self.scroll_offset)
|
||||||
self.refresh(region)
|
self.refresh(region)
|
||||||
|
|
||||||
|
def is_valid_row_index(self, row_index: int) -> bool:
|
||||||
|
"""Return a boolean indicating whether the row_index is within table bounds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
row_index: The row index to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the row index is within the bounds of the table.
|
||||||
|
"""
|
||||||
|
return 0 <= row_index < len(self.rows)
|
||||||
|
|
||||||
|
def is_valid_column_index(self, column_index: int) -> bool:
|
||||||
|
"""Return a boolean indicating whether the column_index is within table bounds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column_index: The column index to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the column index is within the bounds of the table.
|
||||||
|
"""
|
||||||
|
return 0 <= column_index < len(self.columns)
|
||||||
|
|
||||||
|
def is_valid_coordinate(self, coordinate: Coordinate) -> bool:
|
||||||
|
"""Return a boolean indicating whether the given coordinate is within table bounds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
coordinate: The coordinate to validate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the coordinate is within the bounds of the table.
|
||||||
|
"""
|
||||||
|
row_index, column_index = coordinate
|
||||||
|
return self.is_valid_row_index(row_index) and self.is_valid_column_index(
|
||||||
|
column_index
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ordered_columns(self) -> list[Column]:
|
def ordered_columns(self) -> list[Column]:
|
||||||
"""The list of Columns in the DataTable, ordered as they currently appear on screen."""
|
"""The list of Columns in the DataTable, ordered as they currently appear on screen."""
|
||||||
@@ -1443,7 +1479,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
x, y, width, height = self._get_column_region(self.cursor_column)
|
x, y, width, height = self._get_column_region(self.cursor_column)
|
||||||
region = Region(x, int(self.scroll_y) + top, width, height - top)
|
region = Region(x, int(self.scroll_y) + top, width, height - top)
|
||||||
else:
|
else:
|
||||||
region = self._get_cell_region(self.cursor_row, self.cursor_column)
|
region = self._get_cell_region(self.cursor_coordinate)
|
||||||
|
|
||||||
self.scroll_to_region(region, animate=animate, spacing=fixed_offset)
|
self.scroll_to_region(region, animate=animate, spacing=fixed_offset)
|
||||||
|
|
||||||
@@ -1463,7 +1499,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
elif cursor_type == "row":
|
elif cursor_type == "row":
|
||||||
self.refresh_row(self.hover_row)
|
self.refresh_row(self.hover_row)
|
||||||
elif cursor_type == "cell":
|
elif cursor_type == "cell":
|
||||||
self.refresh_cell(*self.hover_coordinate)
|
self.refresh_coordinate(self.hover_coordinate)
|
||||||
|
|
||||||
def on_click(self, event: events.Click) -> None:
|
def on_click(self, event: events.Click) -> None:
|
||||||
self._set_hover_cursor(True)
|
self._set_hover_cursor(True)
|
||||||
|
|||||||
@@ -313,7 +313,24 @@ async def test_update_cell_cell_doesnt_exist():
|
|||||||
table.update_cell("INVALID", "CELL", "Value")
|
table.update_cell("INVALID", "CELL", "Value")
|
||||||
|
|
||||||
|
|
||||||
# TODO: Test update coordinate
|
async def test_update_coordinate_coordinate_exists():
|
||||||
|
app = DataTableApp()
|
||||||
|
async with app.run_test():
|
||||||
|
table = app.query_one(DataTable)
|
||||||
|
column_0, column_1 = table.add_columns("A", "B")
|
||||||
|
row_0, *_ = table.add_rows(ROWS)
|
||||||
|
table.update_coordinate(Coordinate(0, 1), "newvalue")
|
||||||
|
assert table.get_cell_value(row_0, column_1) == "newvalue"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_coordinate_coordinate_doesnt_exist():
|
||||||
|
app = DataTableApp()
|
||||||
|
async with app.run_test():
|
||||||
|
table = app.query_one(DataTable)
|
||||||
|
table.add_columns("A", "B")
|
||||||
|
table.add_rows(ROWS)
|
||||||
|
with pytest.raises(CellDoesNotExist):
|
||||||
|
table.update_coordinate(Coordinate(999, 999), "newvalue")
|
||||||
|
|
||||||
|
|
||||||
def test_key_equals_equivalent_string():
|
def test_key_equals_equivalent_string():
|
||||||
|
|||||||
Reference in New Issue
Block a user