diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 12d1aadf1..7e9d86416 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -217,7 +217,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.columns: list[Column] = [] self.rows: dict[RowKey, Row] = {} self.data: dict[RowKey, list[CellType]] = {} - self.row_count = 0 # Keep tracking of key -> index for rows/cols. # For a given key, what is the current location of the corresponding row/col? @@ -271,6 +270,10 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def cursor_column(self) -> int: return self.cursor_cell.column + @property + def row_count(self) -> int: + return len(self.rows) + def get_cell_value(self, coordinate: Coordinate) -> CellType: """Get the value from the cell at the given coordinate. @@ -485,7 +488,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Args: columns: Also clear the columns. Defaults to False. """ - self.row_count = 0 self._clear_caches() self._y_offsets.clear() self.data.clear() @@ -568,7 +570,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): for line_no in range(height): self._y_offsets.append((row_key, line_no)) - self.row_count += 1 self._line_no += height self._new_rows.add(row_index) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 16567785b..29413497c 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -1,10 +1,11 @@ import pytest +from rich.text import Text from textual.app import App from textual.coordinate import Coordinate from textual.message import Message from textual.widgets import DataTable -from textual.widgets._data_table import StringKey, CellDoesNotExist +from textual.widgets._data_table import StringKey, CellDoesNotExist, RowKey, Row ROWS = [["0/0", "0/1"], ["1/0", "1/1"], ["2/0", "2/1"]] @@ -126,6 +127,44 @@ async def test_datatable_message_emission(): assert messages == expected_messages +async def test_add_rows(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + row_keys = table.add_rows(ROWS) + # We're given a key for each row + assert len(row_keys) == len(ROWS) + assert len(row_keys) == len(table.data) + assert table.row_count == len(ROWS) + # Each key can be used to fetch a row from the DataTable + assert all(key in table.data for key in row_keys) + # Ensure the keys are returned *in order*, and there's one for each row + for key, row in zip(row_keys, range(len(ROWS))): + assert table.rows[key].index == row + + +async def test_add_data_user_defined_keys(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + algernon_key = table.add_row(*ROWS[0], key="algernon") + table.add_row(*ROWS[1], key="charlie") + auto_key = table.add_row(*ROWS[2]) + + assert algernon_key == "algernon" + # We get a RowKey object back, but we can use our own string *or* this object + # to find the row we're looking for, they're considered equivalent for lookups. + assert isinstance(algernon_key, RowKey) + assert table.data[algernon_key] == ROWS[0] + assert table.data["algernon"] == ROWS[0] + assert table.data["charlie"] == ROWS[1] + assert table.data[auto_key] == ROWS[2] + + first_row = Row(algernon_key, index=0, height=1, y=0) + assert table.rows[algernon_key] == first_row + assert table.rows["algernon"] == first_row + + async def test_clear(): app = DataTableApp() async with app.run_test(): @@ -147,6 +186,7 @@ async def test_clear(): # Ensure that the table has been cleared assert table.data == {} assert table.rows == {} + assert table.row_count == 0 assert len(table.columns) == 1 # Clearing the columns too @@ -154,13 +194,31 @@ async def test_clear(): assert len(table.columns) == 0 -def test_add_rows_generates_keys(): - table = DataTable() - keys = table.add_rows(ROWS) +async def test_column_labels() -> None: + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_columns("1", "2", "3") + assert [col.label for col in table.columns] == [Text("1"), Text("2"), Text("3")] - # Ensure the keys are returned in order, and there's one for each row - for key, row in zip(keys, range(len(ROWS))): - assert table.rows[key].index == row + +async def test_row_widths() -> None: + app = DataTableApp() + async with app.run_test() as pilot: + table = app.query_one(DataTable) + table.add_columns("foo", "bar") + + assert table.columns[0].width == 3 + assert table.columns[1].width == 3 + table.add_row("Hello", "World!") + await pilot.pause() + assert table.columns[0].content_width == 5 + assert table.columns[1].content_width == 6 + + table.add_row("Hello World!!!", "fo") + await pilot.pause() + assert table.columns[0].content_width == 14 + assert table.columns[1].content_width == 6 def test_get_cell_value_returns_value_at_cell(): diff --git a/tests/test_table.py b/tests/test_table.py deleted file mode 100644 index d57be5a60..000000000 --- a/tests/test_table.py +++ /dev/null @@ -1,67 +0,0 @@ -import asyncio - -from rich.text import Text - -from textual.app import App, ComposeResult -from textual.widgets import DataTable - - -class TableApp(App): - def compose(self) -> ComposeResult: - yield DataTable() - - -async def test_table_clear() -> None: - """Check DataTable.clear""" - - app = TableApp() - async with app.run_test() as pilot: - table = app.query_one(DataTable) - table.add_columns("foo", "bar") - assert table.row_count == 0 - row_key = table.add_row("Hello", "World!") - assert [col.label for col in table.columns] == [Text("foo"), Text("bar")] - assert table.data == {row_key: ["Hello", "World!"]} - assert table.row_count == 1 - table.clear() - assert [col.label for col in table.columns] == [Text("foo"), Text("bar")] - assert table.data == {} - assert table.row_count == 0 - - -async def test_table_clear_with_columns() -> None: - """Check DataTable.clear(columns=True)""" - - app = TableApp() - async with app.run_test() as pilot: - table = app.query_one(DataTable) - table.add_columns("foo", "bar") - assert table.row_count == 0 - row_key = table.add_row("Hello", "World!") - assert [col.label for col in table.columns] == [Text("foo"), Text("bar")] - assert table.data == {row_key: ["Hello", "World!"]} - assert table.row_count == 1 - table.clear(columns=True) - assert [col.label for col in table.columns] == [] - assert table.data == {} - assert table.row_count == 0 - - -async def test_table_add_row() -> None: - - app = TableApp() - async with app.run_test(): - table = app.query_one(DataTable) - table.add_columns("foo", "bar") - - assert table.columns[0].width == 3 - assert table.columns[1].width == 3 - table.add_row("Hello", "World!") - await asyncio.sleep(0) - assert table.columns[0].content_width == 5 - assert table.columns[1].content_width == 6 - - table.add_row("Hello World!!!", "fo") - await asyncio.sleep(0) - assert table.columns[0].content_width == 14 - assert table.columns[1].content_width == 6