From 7ebc95fb542c5dcd5c07e926ef4b17cf0ba30371 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Thu, 9 Feb 2023 11:16:14 +0000 Subject: [PATCH] Updating tests for DataTable --- src/textual/widgets/_data_table.py | 5 +-- tests/test_data_table.py | 2 +- tests/test_two_way_dict.py | 51 ++++++++++++------------------ 3 files changed, 24 insertions(+), 34 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 775f8d3fd..c6424d92a 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -12,6 +12,7 @@ from typing import ( cast, NamedTuple, Any, + Sequence, ) import rich.repr @@ -1138,12 +1139,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): """ ordered_columns = self.ordered_columns if row_index == -1: - row = [column.label for column in ordered_columns] + row: list[RenderableType] = [column.label for column in ordered_columns] return row # Ensure we order the cells in the row based on current column ordering row_key = self._row_locations.get_key(row_index) - cell_mapping: dict[ColumnKey, CellType] = self.data.get(row_key) + cell_mapping: dict[ColumnKey, CellType] = self.data.get(row_key, {}) ordered_row: list[CellType] = [] for column in ordered_columns: diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 5ae669407..53c4aa779 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -261,7 +261,7 @@ async def test_column_labels() -> None: async with app.run_test(): table = app.query_one(DataTable) table.add_columns("1", "2", "3") - actual_labels = [col.label for col in table.columns.values()] + actual_labels = [col.label.plain for col in table.columns.values()] expected_labels = ["1", "2", "3"] assert actual_labels == expected_labels diff --git a/tests/test_two_way_dict.py b/tests/test_two_way_dict.py index 26e1cb58e..9178f6fdb 100644 --- a/tests/test_two_way_dict.py +++ b/tests/test_two_way_dict.py @@ -4,7 +4,7 @@ from textual._two_way_dict import TwoWayDict @pytest.fixture -def map(): +def two_way_dict(): return TwoWayDict( { 1: 10, @@ -14,43 +14,32 @@ def map(): ) -def test_get(map): - assert map.get(1) == 10 +def test_get(two_way_dict): + assert two_way_dict.get(1) == 10 -def test_get_default_none(map): - assert map.get(9999) is None +def test_get_key(two_way_dict): + assert two_way_dict.get_key(30) == 3 -def test_get_default_supplied(map): - assert map.get(9999, -123) == -123 +def test_set_item(two_way_dict): + two_way_dict[40] = 400 + assert two_way_dict.get(40) == 400 + assert two_way_dict.get_key(400) == 40 -def test_get_key(map): - assert map.get_key(30) == 3 +def test_len(two_way_dict): + assert len(two_way_dict) == 3 -def test_get_key_default_none(map): - assert map.get_key(9999) is None +def test_delitem(two_way_dict): + assert two_way_dict.get(3) == 30 + assert two_way_dict.get_key(30) == 3 + del two_way_dict[3] + assert two_way_dict.get(3) is None + assert two_way_dict.get_key(30) is None -def test_get_key_default_supplied(map): - assert map.get_key(9999, -123) == -123 - - -def test_set_item(map): - map[40] = 400 - assert map.get(40) == 400 - assert map.get_key(400) == 40 - - -def test_len(map): - assert len(map) == 3 - - -def test_delitem(map): - assert map.get(3) == 30 - assert map.get_key(30) == 3 - del map[3] - assert map.get(3) is None - assert map.get_key(30) is None +def test_contains(two_way_dict): + assert 1 in two_way_dict + assert 10 not in two_way_dict