From fcdff48f0a7198dd269cc7834aedf1fb6c78fc9e Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Tue, 14 Feb 2023 12:47:49 +0000 Subject: [PATCH] Testing invalid index and keys in DataTable.get_row* --- src/textual/widgets/_data_table.py | 8 ++++++-- src/textual/widgets/data_table.py | 18 ++++++++++------- tests/test_data_table.py | 32 ++++++++++++++++++++++++++++-- 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 7ce7df10d..086f15eb7 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -644,7 +644,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): row_key, column_key = self.coordinate_to_cell_key(coordinate) return self.get_cell(row_key, column_key) - def get_row(self, row_key: RowKey) -> list[CellType]: + def get_row(self, row_key: RowKey | str) -> list[CellType]: """Get the values from the row identified by the given row key. Args: @@ -656,6 +656,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Raises: RowDoesNotExist: When there is no row corresponding to the key. """ + if row_key not in self._row_locations: + raise RowDoesNotExist(f"Row key {row_key!r} is not valid.") cell_mapping: dict[ColumnKey, CellType] = self._data.get(row_key, {}) ordered_row: list[CellType] = [] for column in self.ordered_columns: @@ -677,10 +679,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Raises: RowDoesNotExist: If there is no row with the given index. """ + if not self.is_valid_row_index(row_index): + raise RowDoesNotExist(f"Row index {row_index!r} is not valid.") row_key = self._row_locations.get_key(row_index) return self.get_row(row_key) - def get_column(self, column_key: ColumnKey) -> list[CellType]: + def get_column(self, column_key: ColumnKey | str) -> list[CellType]: """Get the values from the column identified by the given column key. Args: diff --git a/src/textual/widgets/data_table.py b/src/textual/widgets/data_table.py index a923e6b86..0bb18f87f 100644 --- a/src/textual/widgets/data_table.py +++ b/src/textual/widgets/data_table.py @@ -5,21 +5,25 @@ from ._data_table import ( CellKey, CellType, Column, + ColumnDoesNotExist, ColumnKey, CursorType, DuplicateKey, Row, + RowDoesNotExist, RowKey, ) __all__ = [ - "Column", - "Row", - "RowKey", - "ColumnKey", - "CellKey", - "CursorType", - "CellType", "CellDoesNotExist", + "CellKey", + "CellType", + "Column", + "ColumnDoesNotExist", + "ColumnKey", + "CursorType", "DuplicateKey", + "Row", + "RowDoesNotExist", + "RowKey", ] diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 37264a455..f83c27b1b 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -11,8 +11,16 @@ from textual.events import Click, MouseMove from textual.message import Message from textual.message_pump import MessagePump from textual.widgets import DataTable -from textual.widgets._data_table import DuplicateKey -from textual.widgets.data_table import CellDoesNotExist, CellKey, ColumnKey, Row, RowKey +from textual.widgets.data_table import ( + CellDoesNotExist, + CellKey, + ColumnDoesNotExist, + ColumnKey, + DuplicateKey, + Row, + RowDoesNotExist, + RowKey, +) ROWS = [["0/0", "0/1"], ["1/0", "1/1"], ["2/0", "2/1"]] @@ -363,6 +371,14 @@ async def test_get_row(): assert table.get_row(second_row) == [3, 2, 1] +async def test_get_row_invalid_row_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + with pytest.raises(RowDoesNotExist): + table.get_row("abc") + + async def test_get_row_at(): app = DataTableApp() async with app.run_test(): @@ -381,6 +397,18 @@ async def test_get_row_at(): assert table.get_row_at(1) == [2, 4, 1] +@pytest.mark.parametrize("index", (-1, 2)) +async def test_get_row_at_invalid_index(index): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_columns("A", "B", "C") + table.add_row(2, 4, 1) + table.add_row(3, 2, 1) + with pytest.raises(RowDoesNotExist): + table.get_row_at(index) + + async def test_update_cell_cell_exists(): app = DataTableApp() async with app.run_test():