Merge pull request #1786 from Textualize/datatable-private-data

DataTable - make `data` attribute private & expand APIs for reading data
This commit is contained in:
Will McGugan
2023-02-14 16:37:50 +00:00
committed by GitHub
4 changed files with 230 additions and 43 deletions

View File

@@ -23,7 +23,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Added `DataTable.get_cell` to retrieve a cell by column/row keys https://github.com/Textualize/textual/pull/1638
- Added `DataTable.get_cell_at` to retrieve a cell by coordinate https://github.com/Textualize/textual/pull/1638
- Added `DataTable.update_cell` to update a cell by column/row keys https://github.com/Textualize/textual/pull/1638
- Added `DataTable.update_cell_at`to update a cell at a coordinate https://github.com/Textualize/textual/pull/1638
- Added `DataTable.update_cell_at` to update a cell at a coordinate https://github.com/Textualize/textual/pull/1638
- Added `DataTable.ordered_rows` property to retrieve `Row`s as they're currently ordered https://github.com/Textualize/textual/pull/1638
- Added `DataTable.ordered_columns` property to retrieve `Column`s as they're currently ordered https://github.com/Textualize/textual/pull/1638
- Added `DataTable.coordinate_to_cell_key` to find the key for the cell at a coordinate https://github.com/Textualize/textual/pull/1638
@@ -31,6 +31,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Added `DataTable.is_valid_row_index` https://github.com/Textualize/textual/pull/1638
- Added `DataTable.is_valid_column_index` https://github.com/Textualize/textual/pull/1638
- Added attributes to events emitted from `DataTable` indicating row/column/cell keys https://github.com/Textualize/textual/pull/1638
- Added `DataTable.get_row` to retrieve the values from a row by key https://github.com/Textualize/textual/pull/1786
- Added `DataTable.get_row_at` to retrieve the values from a row by index https://github.com/Textualize/textual/pull/1786
- Added `DataTable.get_column` to retrieve the values from a column by key https://github.com/Textualize/textual/pull/1786
- Added `DataTable.get_column_at` to retrieve the values from a column by index https://github.com/Textualize/textual/pull/1786
- Added `DOMNode.watch` and `DOMNode.is_attached` methods https://github.com/Textualize/textual/pull/1750
- Added `DOMNode.css_tree` which is a renderable that shows the DOM and CSS https://github.com/Textualize/textual/pull/1778
- Added `DOMNode.children_view` which is a view on to a nodes children list, use for querying https://github.com/Textualize/textual/pull/1778
@@ -52,6 +56,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Breaking change: `DataTable.data` structure changed, and will be made private in upcoming release https://github.com/Textualize/textual/pull/1638
- Breaking change: `DataTable.refresh_cell` was renamed to `DataTable.refresh_coordinate` https://github.com/Textualize/textual/pull/1638
- Breaking change: `DataTable.get_row_height` now takes a `RowKey` argument instead of a row index https://github.com/Textualize/textual/pull/1638
- Breaking change: `DataTable.data` renamed to `DataTable._data` (it's now private) https://github.com/Textualize/textual/pull/1786
- The `_filter` module was made public (now called `filter`) https://github.com/Textualize/textual/pull/1638
- Breaking change: renamed `Checkbox` to `Switch` https://github.com/Textualize/textual/issues/1746
- `App.install_screen` name is no longer optional https://github.com/Textualize/textual/pull/1778

View File

@@ -47,6 +47,16 @@ class CellDoesNotExist(Exception):
do not exist in the DataTable."""
class RowDoesNotExist(Exception):
"""Raised when the user supplies a row index or row key which does
not exist in the DataTable (e.g. out of bounds index, invalid key)"""
class ColumnDoesNotExist(Exception):
"""Raised when the user supplies a column index or column key which does
not exist in the DataTable (e.g. out of bounds index, invalid key)"""
class DuplicateKey(Exception):
"""The key supplied already exists.
@@ -442,7 +452,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
classes: str | None = None,
) -> None:
super().__init__(name=name, id=id, classes=classes)
self.data: dict[RowKey, dict[ColumnKey, CellType]] = {}
self._data: dict[RowKey, dict[ColumnKey, CellType]] = {}
"""Contains the cells of the table, indexed by row key and column key.
The final positioning of a cell on screen cannot be determined solely by this
structure. Instead, we must check _row_locations and _column_locations to find
@@ -576,7 +586,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
column_key = ColumnKey(column_key)
try:
self.data[row_key][column_key] = value
self._data[row_key][column_key] = value
except KeyError:
raise CellDoesNotExist(
f"No cell exists for row_key={row_key!r}, column_key={column_key!r}."
@@ -607,14 +617,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
row_key, column_key = self.coordinate_to_cell_key(coordinate)
self.update_cell(row_key, column_key, value, update_width=update_width)
def _get_cells_in_column(self, column_key: ColumnKey) -> Iterable[CellType]:
"""For a given column key, return the cells in that column in the
order they currently appear on screen."""
for row_metadata in self.ordered_rows:
row_key = row_metadata.key
row = self.data[row_key]
yield row[column_key]
def get_cell(self, row_key: RowKey, column_key: ColumnKey) -> CellType:
"""Given a row key and column key, return the value of the corresponding cell.
@@ -626,7 +628,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
The value of the cell identified by the row and column keys.
"""
try:
cell_value = self.data[row_key][column_key]
cell_value = self._data[row_key][column_key]
except KeyError:
raise CellDoesNotExist(
f"No cell exists for row_key={row_key!r}, column_key={column_key!r}."
@@ -648,6 +650,83 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
row_key, column_key = self.coordinate_to_cell_key(coordinate)
return self.get_cell(row_key, column_key)
def get_row(self, row_key: RowKey | str) -> list[CellType]:
"""Get the values from the row identified by the given row key.
Args:
row_key: The key of the row.
Returns:
A list of the values contained within the row.
Raises:
RowDoesNotExist: When there is no row corresponding to the key.
"""
if row_key not in self._row_locations:
raise RowDoesNotExist(f"Row key {row_key!r} is not valid.")
cell_mapping: dict[ColumnKey, CellType] = self._data.get(row_key, {})
ordered_row: list[CellType] = [
cell_mapping[column.key] for column in self.ordered_columns
]
return ordered_row
def get_row_at(self, row_index: int) -> list[CellType]:
"""Get the values from the cells in a row at a given index. This will
return the values from a row based on the rows _current position_ in
the table.
Args:
row_index: The index of the row.
Returns:
A list of the values contained in the row.
Raises:
RowDoesNotExist: If there is no row with the given index.
"""
if not self.is_valid_row_index(row_index):
raise RowDoesNotExist(f"Row index {row_index!r} is not valid.")
row_key = self._row_locations.get_key(row_index)
return self.get_row(row_key)
def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]:
"""Get the values from the column identified by the given column key.
Args:
column_key: The key of the column.
Returns:
A generator which yields the cells in the column.
Raises:
ColumnDoesNotExist: If there is no column corresponding to the key.
"""
if column_key not in self._column_locations:
raise ColumnDoesNotExist(f"Column key {column_key!r} is not valid.")
data = self._data
for row_metadata in self.ordered_rows:
row_key = row_metadata.key
yield data[row_key][column_key]
def get_column_at(self, column_index: int) -> Iterable[CellType]:
"""Get the values from the column at a given index.
Args:
column_index: The index of the column.
Returns:
A generator which yields the cells in the column.
Raises:
ColumnDoesNotExist: If there is no column with the given index.
"""
if not self.is_valid_column_index(column_index):
raise ColumnDoesNotExist(f"Column index {column_index!r} is not valid.")
column_key = self._column_locations.get_key(column_index)
yield from self.get_column(column_key)
def _clear_caches(self) -> None:
self._row_render_cache.clear()
self._cell_render_cache.clear()
@@ -752,7 +831,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _highlight_row(self, row_index: int) -> None:
"""Apply highlighting to the row at the given index, and post event."""
self.refresh_row(row_index)
is_valid_row = row_index < len(self.data)
is_valid_row = row_index < len(self._data)
if is_valid_row:
row_key = self._row_locations.get_key(row_index)
self.post_message_no_wait(
@@ -819,12 +898,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
console = self.app.console
label_width = measure(console, column.label, 1)
content_width = column.content_width
cell_value = self.data[row_key][column_key]
cell_value = self._data[row_key][column_key]
new_content_width = measure(console, default_cell_formatter(cell_value), 1)
if new_content_width < content_width:
cells_in_column = self._get_cells_in_column(column_key)
cells_in_column = self.get_column(column_key)
cell_widths = [
measure(console, default_cell_formatter(cell), 1)
for cell in cells_in_column
@@ -910,7 +989,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
"""
self._clear_caches()
self._y_offsets.clear()
self.data.clear()
self._data.clear()
self.rows.clear()
if columns:
self.columns.clear()
@@ -991,7 +1070,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
row_index = self.row_count
# Map the key of this row to its current index
self._row_locations[row_key] = row_index
self.data[row_key] = {
self._data[row_key] = {
column.key: cell
for column, cell in zip_longest(self.ordered_columns, cells)
}
@@ -1190,15 +1269,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
row: list[RenderableType] = [column.label for column in ordered_columns]
return row
# Ensure we order the cells in the row based on current column ordering
row_key = self._row_locations.get_key(row_index)
cell_mapping: dict[ColumnKey, CellType] = self.data.get(row_key, {})
ordered_row: list[CellType] = []
for column in ordered_columns:
cell = cell_mapping[column.key]
ordered_row.append(cell)
ordered_row = self.get_row_at(row_index)
empty = Text()
return [
Text() if datum is None else default_cell_formatter(datum) or empty
@@ -1527,7 +1598,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
return result
ordered_rows = sorted(
self.data.items(), key=sort_by_column_keys, reverse=reverse
self._data.items(), key=sort_by_column_keys, reverse=reverse
)
self._row_locations = TwoWayDict(
{key: new_index for new_index, (key, _) in enumerate(ordered_rows)}

View File

@@ -5,21 +5,25 @@ from ._data_table import (
CellKey,
CellType,
Column,
ColumnDoesNotExist,
ColumnKey,
CursorType,
DuplicateKey,
Row,
RowDoesNotExist,
RowKey,
)
__all__ = [
"Column",
"Row",
"RowKey",
"ColumnKey",
"CellKey",
"CursorType",
"CellType",
"CellDoesNotExist",
"CellKey",
"CellType",
"Column",
"ColumnDoesNotExist",
"ColumnKey",
"CursorType",
"DuplicateKey",
"Row",
"RowDoesNotExist",
"RowKey",
]

View File

@@ -11,8 +11,16 @@ from textual.events import Click, MouseMove
from textual.message import Message
from textual.message_pump import MessagePump
from textual.widgets import DataTable
from textual.widgets._data_table import DuplicateKey
from textual.widgets.data_table import CellDoesNotExist, CellKey, ColumnKey, Row, RowKey
from textual.widgets.data_table import (
CellDoesNotExist,
CellKey,
ColumnDoesNotExist,
ColumnKey,
DuplicateKey,
Row,
RowDoesNotExist,
RowKey,
)
ROWS = [["0/0", "0/1"], ["1/0", "1/1"], ["2/0", "2/1"]]
@@ -157,10 +165,10 @@ async def test_add_rows():
row_keys = table.add_rows(ROWS)
# We're given a key for each row
assert len(row_keys) == len(ROWS)
assert len(row_keys) == len(table.data)
assert len(row_keys) == len(table._data)
assert table.row_count == len(ROWS)
# Each key can be used to fetch a row from the DataTable
assert all(key in table.data for key in row_keys)
assert all(key in table._data for key in row_keys)
async def test_add_rows_user_defined_keys():
@@ -179,14 +187,14 @@ async def test_add_rows_user_defined_keys():
# Ensure the data in the table is mapped as expected
first_row = {key_a: ROWS[0][0], key_b: ROWS[0][1]}
assert table.data[algernon_key] == first_row
assert table.data["algernon"] == first_row
assert table._data[algernon_key] == first_row
assert table._data["algernon"] == first_row
second_row = {key_a: ROWS[1][0], key_b: ROWS[1][1]}
assert table.data["charlie"] == second_row
assert table._data["charlie"] == second_row
third_row = {key_a: ROWS[2][0], key_b: ROWS[2][1]}
assert table.data[auto_key] == third_row
assert table._data[auto_key] == third_row
first_row = Row(algernon_key, height=1)
assert table.rows[algernon_key] == first_row
@@ -260,7 +268,7 @@ async def test_clear():
assert table.hover_coordinate == Coordinate(0, 0)
# Ensure that the table has been cleared
assert table.data == {}
assert table._data == {}
assert table.rows == {}
assert table.row_count == 0
assert len(table.columns) == 1
@@ -347,6 +355,105 @@ async def test_get_cell_at_exception():
table.get_cell_at(Coordinate(9999, 0))
async def test_get_row():
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
a, b, c = table.add_columns("A", "B", "C")
first_row = table.add_row(2, 4, 1)
second_row = table.add_row(3, 2, 1)
assert table.get_row(first_row) == [2, 4, 1]
assert table.get_row(second_row) == [3, 2, 1]
# Even if row positions change, keys should always refer to same rows.
table.sort(b)
assert table.get_row(first_row) == [2, 4, 1]
assert table.get_row(second_row) == [3, 2, 1]
async def test_get_row_invalid_row_key():
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
with pytest.raises(RowDoesNotExist):
table.get_row("INVALID")
async def test_get_row_at():
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
a, b, c = table.add_columns("A", "B", "C")
table.add_row(2, 4, 1)
table.add_row(3, 2, 1)
assert table.get_row_at(0) == [2, 4, 1]
assert table.get_row_at(1) == [3, 2, 1]
# If we sort, then the rows present at the indices *do* change!
table.sort(b)
# Since we sorted on column "B", the rows at indices 0 and 1 are swapped.
assert table.get_row_at(0) == [3, 2, 1]
assert table.get_row_at(1) == [2, 4, 1]
@pytest.mark.parametrize("index", (-1, 2))
async def test_get_row_at_invalid_index(index):
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
table.add_columns("A", "B", "C")
table.add_row(2, 4, 1)
table.add_row(3, 2, 1)
with pytest.raises(RowDoesNotExist):
table.get_row_at(index)
async def test_get_column():
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
a, b = table.add_columns("A", "B")
table.add_rows(ROWS)
cells = table.get_column(a)
assert next(cells) == ROWS[0][0]
assert next(cells) == ROWS[1][0]
assert next(cells) == ROWS[2][0]
with pytest.raises(StopIteration):
next(cells)
async def test_get_column_invalid_key():
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
with pytest.raises(ColumnDoesNotExist):
list(table.get_column("INVALID"))
async def test_get_column_at():
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
table.add_columns("A", "B")
table.add_rows(ROWS)
first_column = list(table.get_column_at(0))
assert first_column == [ROWS[0][0], ROWS[1][0], ROWS[2][0]]
second_column = list(table.get_column_at(1))
assert second_column == [ROWS[0][1], ROWS[1][1], ROWS[2][1]]
@pytest.mark.parametrize("index", [-1, 5])
async def test_get_column_at_invalid_index(index):
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
with pytest.raises(ColumnDoesNotExist):
list(table.get_column_at(index))
async def test_update_cell_cell_exists():
app = DataTableApp()
async with app.run_test():