Add methods for retrieving column values to DataTable

This commit is contained in:
Darren Burns
2023-02-14 13:06:48 +00:00
parent fcdff48f0a
commit 16c9f15ab5

View File

@@ -603,14 +603,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
row_key, column_key = self.coordinate_to_cell_key(coordinate)
self.update_cell(row_key, column_key, value, update_width=update_width)
def _get_cells_in_column(self, column_key: ColumnKey) -> Iterable[CellType]:
"""For a given column key, return the cells in that column in the
order they currently appear on screen."""
for row_metadata in self.ordered_rows:
row_key = row_metadata.key
row = self._data[row_key]
yield row[column_key]
def get_cell(self, row_key: RowKey, column_key: ColumnKey) -> CellType:
"""Given a row key and column key, return the value of the corresponding cell.
@@ -684,31 +676,43 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
row_key = self._row_locations.get_key(row_index)
return self.get_row(row_key)
def get_column(self, column_key: ColumnKey | str) -> list[CellType]:
def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]:
"""Get the values from the column identified by the given column key.
Args:
column_key: The key of the column.
Returns:
A list of values in the column
A generator which yields the cells in the column.
Raises:
ColumnDoesNotExist: If there is no column corresponding to the key.
"""
if column_key not in self._column_locations:
raise ColumnDoesNotExist(f"Column key {column_key!r} is not valid.")
def get_column_at(self, column_index: int) -> list[CellType]:
for row_metadata in self.ordered_rows:
row_key = row_metadata.key
row = self._data[row_key]
yield row[column_key]
def get_column_at(self, column_index: int) -> Iterable[CellType]:
"""Get the values from the column at a given index.
Args:
column_index: The index of the column.
Returns:
A list containing the values in the column.
A generator which yields the cells in the column.
Raises:
ColumnDoesNotExist: If there is no column with the given index.
"""
if not self.is_valid_column_index(column_index):
raise ColumnDoesNotExist(f"Column index {column_index!r} is not valid.")
column_key = self._column_locations.get_key(column_index)
yield from self.get_column(column_key)
def _clear_caches(self) -> None:
self._row_render_cache.clear()
@@ -886,7 +890,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
new_content_width = measure(console, default_cell_formatter(cell_value), 1)
if new_content_width < content_width:
cells_in_column = self._get_cells_in_column(column_key)
cells_in_column = self.get_column(column_key)
cell_widths = [
measure(console, default_cell_formatter(cell), 1)
for cell in cells_in_column