diff --git a/src/textual/render.py b/src/textual/render.py index 8911c4263..c1003b062 100644 --- a/src/textual/render.py +++ b/src/textual/render.py @@ -1,5 +1,6 @@ from __future__ import annotations +from rich.cells import cell_len from rich.console import Console, RenderableType from rich.protocol import rich_cast @@ -22,6 +23,9 @@ def measure( Returns: Width in cells """ + if isinstance(renderable, str): + return cell_len(renderable) + width = default renderable = rich_cast(renderable) get_console_width = getattr(renderable, "__rich_measure__", None) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index d209b2254..5ea9f0a59 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -13,8 +13,6 @@ from typing import ( NamedTuple, Callable, Sequence, - Type, - Optional, ) import rich.repr @@ -48,7 +46,6 @@ RowCacheKey: TypeAlias = ( "tuple[RowKey, int, Style, Coordinate, Coordinate, CursorType, bool, bool, int]" ) CursorType = Literal["cell", "row", "column", "none"] -CELL: CursorType = "cell" CellType = TypeVar("CellType") @@ -88,8 +85,8 @@ class ColumnKey(StringKey): class CellKey(NamedTuple): - row_key: RowKey - column_key: ColumnKey + row_key: RowKey | str + column_key: ColumnKey | str def default_cell_formatter(obj: object) -> RenderableType | None: @@ -238,12 +235,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): zebra_stripes = Reactive(False) header_height = Reactive(1) show_cursor = Reactive(True) - cursor_type = Reactive(CELL) + cursor_type = Reactive("cell") - cursor_cell: Reactive[Coordinate] = Reactive( + cursor_coordinate: Reactive[Coordinate] = Reactive( Coordinate(0, 0), repaint=False, always_update=True ) - hover_cell: Reactive[Coordinate] = Reactive(Coordinate(0, 0), repaint=False) + hover_coordinate: Reactive[Coordinate] = Reactive(Coordinate(0, 0), repaint=False) class CellHighlighted(Message, bubble=True): """Emitted when the cursor moves to highlight a new cell. @@ -256,19 +253,26 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Attributes: value: The value in the highlighted cell. coordinate: The coordinate of the highlighted cell. + cell_key: The key for the highlighted cell. """ def __init__( - self, sender: DataTable, value: CellType, coordinate: Coordinate + self, + sender: DataTable, + value: CellType, + coordinate: Coordinate, + cell_key: CellKey, ) -> None: self.value: CellType = value self.coordinate: Coordinate = coordinate + self.cell_key: CellKey = cell_key super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "value", self.value yield "coordinate", self.coordinate + yield "cell_key", self.cell_key class CellSelected(Message, bubble=True): """Emitted by the `DataTable` widget when a cell is selected. @@ -279,19 +283,26 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Attributes: value: The value in the cell that was selected. coordinate: The coordinate of the cell that was selected. + cell_key: The key for the selected cell. """ def __init__( - self, sender: DataTable, value: CellType, coordinate: Coordinate + self, + sender: DataTable, + value: CellType, + coordinate: Coordinate, + cell_key: CellKey, ) -> None: self.value: CellType = value self.coordinate: Coordinate = coordinate + self.cell_key: CellKey = cell_key super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "value", self.value yield "coordinate", self.coordinate + yield "cell_key", self.cell_key class RowHighlighted(Message, bubble=True): """Emitted when a row is highlighted. This message is only emitted when the @@ -300,15 +311,18 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Attributes: cursor_row: The y-coordinate of the cursor that highlighted the row. + row_key: The key of the row that was highlighted. """ - def __init__(self, sender: DataTable, cursor_row: int) -> None: + def __init__(self, sender: DataTable, cursor_row: int, row_key: RowKey) -> None: self.cursor_row: int = cursor_row + self.row_key: RowKey = row_key super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "cursor_row", self.cursor_row + yield "row_key", self.row_key class RowSelected(Message, bubble=True): """Emitted when a row is selected. This message is only emitted when the @@ -318,15 +332,18 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Attributes: cursor_row: The y-coordinate of the cursor that made the selection. + row_key: The key of the row that was selected. """ - def __init__(self, sender: DataTable, cursor_row: int) -> None: + def __init__(self, sender: DataTable, cursor_row: int, row_key: RowKey) -> None: self.cursor_row: int = cursor_row + self.row_key: RowKey = row_key super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "cursor_row", self.cursor_row + yield "row_key", self.row_key class ColumnHighlighted(Message, bubble=True): """Emitted when a column is highlighted. This message is only emitted when the @@ -336,15 +353,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Attributes: cursor_column: The x-coordinate of the column that was highlighted. + column_key: The key of the column that was highlighted. """ - def __init__(self, sender: DataTable, cursor_column: int) -> None: + def __init__( + self, sender: DataTable, cursor_column: int, column_key: ColumnKey + ) -> None: self.cursor_column: int = cursor_column + self.column_key = column_key super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "cursor_column", self.cursor_column + yield "column_key", self.column_key class ColumnSelected(Message, bubble=True): """Emitted when a column is selected. This message is only emitted when the @@ -354,15 +376,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Attributes: cursor_column: The x-coordinate of the column that was selected. + column_key: The key of the column that was selected. """ - def __init__(self, sender: DataTable, cursor_column: int) -> None: + def __init__( + self, sender: DataTable, cursor_column: int, column_key: ColumnKey + ) -> None: self.cursor_column: int = cursor_column + self.column_key = column_key super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "cursor_column", self.cursor_column + yield "column_key", self.column_key def __init__( self, @@ -379,69 +406,105 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): ) -> None: super().__init__(name=name, id=id, classes=classes) self.data: dict[RowKey, dict[ColumnKey, CellType]] = {} + """Contains the cells of the table, indexed by row key and column key. + The final positioning of a cell on screen cannot be determined solely by this + structure. Instead, we must check _row_locations and _column_locations to find + where each cell currently resides in space.""" - # Metadata on rows and columns in the table self.columns: dict[ColumnKey, Column] = {} + """Metadata about the columns of the table, indexed by their key.""" self.rows: dict[RowKey, Row] = {} + """Metadata about the rows of the table, indexed by their key.""" # Keep tracking of key -> index for rows/cols. These allow us to retrieve, # given a row or column key, the index that row or column is currently present at, - # and mean that rows and columns are location independent - they can move around. - self._column_locations: TwoWayDict[ColumnKey, int] = TwoWayDict({}) + # and mean that rows and columns are location independent - they can move around + # without requiring us to modify the underlying data. self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict({}) + """Maps row keys to row indices which represent row order.""" + self._column_locations: TwoWayDict[ColumnKey, int] = TwoWayDict({}) + """Maps column keys to column indices which represent column order.""" self._row_render_cache: LRUCache[ RowCacheKey, tuple[SegmentLines, SegmentLines] ] = LRUCache(1000) + """For each row (a row can have a height of multiple lines), we maintain a cache + of the fixed and scrollable lines within that row to minimise how often we need to + re-render it.""" self._cell_render_cache: LRUCache[CellCacheKey, SegmentLines] = LRUCache(10000) + """Cache for individual cells.""" self._line_cache: LRUCache[LineCacheKey, Strip] = LRUCache(1000) + """Cache for lines within rows.""" - self._line_no = 0 self._require_update_dimensions: bool = False - + """Set to re-calculate dimensions on idle.""" self._new_rows: set[RowKey] = set() - self._updated_columns: set[ColumnKey] = set() - """Track which cells were updated, so that we can refresh them once on idle""" + """Tracking newly added rows to be used in re-calculation of dimensions on idle.""" + self._updated_cells: set[CellKey] = set() + """Track which cells were updated, so that we can refresh them once on idle.""" self.show_header = show_header - self.fixed_rows = fixed_rows - self.fixed_columns = fixed_columns - self.zebra_stripes = zebra_stripes + """Show/hide the header row (the row of column labels).""" self.header_height = header_height + """The height of the header row (the row of column labels).""" + self.fixed_rows = fixed_rows + """The number of rows to fix (prevented from scrolling).""" + self.fixed_columns = fixed_columns + """The number of columns to fix (prevented from scrolling).""" + self.zebra_stripes = zebra_stripes + """Apply zebra-stripe effect on row backgrounds (light, dark, light, dark, ...).""" self.show_cursor = show_cursor + """Show/hide both the keyboard and hover cursor.""" self._show_hover_cursor = False + """Used to hide the mouse hover cursor when the user uses the keyboard.""" self._update_count = 0 + """The number of update operations that have occurred. Used for cache invalidation.""" self._header_row_key = RowKey() + """The header is a special row which is not part of the data. This key is used to retrieve it.""" @property def hover_row(self) -> int: - return self.hover_cell.row + """The index of the row that the mouse cursor is currently hovering above.""" + return self.hover_coordinate.row @property def hover_column(self) -> int: - return self.hover_cell.column + """The index of the column that the mouse cursor is currently hovering above.""" + return self.hover_coordinate.column @property def cursor_row(self) -> int: - return self.cursor_cell.row + """The index of the row that the DataTable cursor is currently on.""" + return self.cursor_coordinate.row @property def cursor_column(self) -> int: - return self.cursor_cell.column + """The index of the column that the DataTable cursor is currently on.""" + return self.cursor_coordinate.column @property def row_count(self) -> int: + """The number of rows currently present in the DataTable.""" return len(self.rows) @property def _y_offsets(self) -> list[tuple[RowKey, int]]: + """Contains a 2-tuple for each line (not row!) of the DataTable. Given a y-coordinate, + we can index into this list to find which row that y-coordinate lands on, and the + y-offset *within* that row. The length of the returned list is therefore the total + height of all rows within the DataTable.""" y_offsets: list[tuple[RowKey, int]] = [] - for row in self._ordered_rows: + for row in self.ordered_rows: row_key = row.key row_height = row.height y_offsets += [(row_key, y) for y in range(row_height)] return y_offsets + @property + def _total_row_height(self) -> int: + """The total height of all rows within the DataTable""" + return len(self._y_offsets) + def update_cell( self, row_key: RowKey | str, @@ -459,56 +522,82 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): value: The new value to put inside the cell. update_width: Whether to resize the column width to accommodate for the new cell content. - """ - value = Text.from_markup(value) if isinstance(value, str) else value - self.data[row_key][column_key] = value + Raises: + CellDoesNotExist: When the supplied `row_key` and `column_key` + cannot be found in the table. + """ + try: + self.data[row_key][column_key] = value + except KeyError: + raise CellDoesNotExist( + f"No cell exists for row_key={row_key!r}, column_key={column_key!r}." + ) from None self._update_count += 1 # Recalculate widths if necessary if update_width: - self._updated_columns.add(column_key) + self._updated_cells.add(CellKey(row_key, column_key)) + self._require_update_dimensions = True - self._require_update_dimensions = True self.refresh() - def update_coordinate(self, coordinate: Coordinate, value: CellType) -> None: - row, column = coordinate - row_key = self._row_locations.get_key(row) - column_key = self._column_locations.get_key(column) - value = Text.from_markup(value) if isinstance(value, str) else value - self.data[row_key][column_key] = value - self._update_count += 1 - self.refresh_cell(row, column) + def update_coordinate( + self, coordinate: Coordinate, value: CellType, *, update_width: bool = False + ) -> None: + """Update the content inside the cell currently occupying the given coordinate. + + Args: + coordinate: The coordinate to update the cell at. + value: The new value to place inside the cell. + update_width: Whether to resize the column width to accommodate + for the new cell content. + """ + if not self.is_valid_coordinate(coordinate): + raise CellDoesNotExist() + + row_key, column_key = self.coordinate_to_cell_key(coordinate) + self.update_cell(row_key, column_key, value, update_width=update_width) def _get_cells_in_column(self, column_key: ColumnKey) -> Iterable[CellType]: - """For a given column key, return the cells in that column in order""" - for row_metadata in self._ordered_rows: + """For a given column key, return the cells in that column in the + order they currently appear on screen.""" + for row_metadata in self.ordered_rows: row_key = row_metadata.key row = self.data.get(row_key) yield row.get(column_key) - def get_cell_value(self, coordinate: Coordinate) -> CellType: - """Get the value from the cell at the given coordinate. + def get_value_at(self, coordinate: Coordinate) -> CellType: + """Get the value from the cell occupying the given coordinate. Args: coordinate: The coordinate to retrieve the value from. Returns: - The value of the cell. + The value of the cell at the coordinate. Raises: CellDoesNotExist: If there is no cell with the given coordinate. """ - # TODO: Rename to get_value_at()? - # We need to clearly distinguish between coordinates and cell keys - row_index, column_index = coordinate - row_key = self._row_locations.get_key(row_index) - column_key = self._column_locations.get_key(column_index) + row_key, column_key = self.coordinate_to_cell_key(coordinate) + return self.get_cell_value(row_key, column_key) + + def get_cell_value(self, row_key: RowKey, column_key: ColumnKey) -> CellType: + """Given a row key and column key, return the value of the corresponding cell. + + Args: + row_key: The row key of the cell. + column_key: The column key of the cell. + + Returns: + The value of the cell identified by the row and column keys. + """ try: cell_value = self.data[row_key][column_key] except KeyError: - raise CellDoesNotExist(f"No cell exists at {coordinate!r}") from None + raise CellDoesNotExist( + f"No cell exists for row_key={row_key!r}, column_key={column_key!r}." + ) return cell_value def _clear_caches(self) -> None: @@ -518,6 +607,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._styles_cache.clear() def get_row_height(self, row_key: RowKey) -> int: + """Given a row key, return the height of that row in terminal cells. + + Args: + row_key: The key of the row. + + Returns: + The height of the row, measured in terminal character cells. + """ if row_key is self._header_row_key: return self.header_height return self.rows[row_key].height @@ -533,7 +630,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # emit the appropriate [Row|Column|Cell]Highlighted event. self._scroll_cursor_into_view(animate=False) if self.cursor_type == "cell": - self._highlight_cell(self.cursor_cell) + self._highlight_coordinate(self.cursor_coordinate) elif self.cursor_type == "row": self._highlight_row(self.cursor_row) elif self.cursor_type == "column": @@ -548,19 +645,19 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def watch_zebra_stripes(self, zebra_stripes: bool) -> None: self._clear_caches() - def watch_hover_cell(self, old: Coordinate, value: Coordinate) -> None: - self.refresh_cell(*old) - self.refresh_cell(*value) + def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None: + self.refresh_coordinate(old) + self.refresh_coordinate(value) - def watch_cursor_cell( + def watch_cursor_coordinate( self, old_coordinate: Coordinate, new_coordinate: Coordinate ) -> None: if old_coordinate != new_coordinate: # Refresh the old and the new cell, and emit the appropriate # message to tell users of the newly highlighted row/cell/column. if self.cursor_type == "cell": - self.refresh_cell(*old_coordinate) - self._highlight_cell(new_coordinate) + self.refresh_coordinate(old_coordinate) + self._highlight_coordinate(new_coordinate) elif self.cursor_type == "row": self.refresh_row(old_coordinate.row) self._highlight_row(new_coordinate.row) @@ -568,36 +665,62 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.refresh_column(old_coordinate.column) self._highlight_column(new_coordinate.column) - def _highlight_cell(self, coordinate: Coordinate) -> None: + def _highlight_coordinate(self, coordinate: Coordinate) -> None: """Apply highlighting to the cell at the coordinate, and emit event.""" - self.refresh_cell(*coordinate) + self.refresh_coordinate(coordinate) try: - cell_value = self.get_cell_value(coordinate) + cell_value = self.get_value_at(coordinate) except CellDoesNotExist: # The cell may not exist e.g. when the table is cleared. # In that case, there's nothing for us to do here. return else: - self.emit_no_wait(DataTable.CellHighlighted(self, cell_value, coordinate)) + cell_key = self.coordinate_to_cell_key(coordinate) + self.emit_no_wait( + DataTable.CellHighlighted( + self, cell_value, coordinate=coordinate, cell_key=cell_key + ) + ) + + def coordinate_to_cell_key(self, coordinate: Coordinate) -> CellKey: + """Return the key for the cell currently occupying this coordinate in the DataTable + + Args: + coordinate: The coordinate to exam the current cell key of. + + Returns: + The key of the cell currently occupying this coordinate. + """ + if not self.is_valid_coordinate(coordinate): + raise CellDoesNotExist(f"No cell exists at {coordinate!r}.") + row_index, column_index = coordinate + row_key = self._row_locations.get_key(row_index) + column_key = self._column_locations.get_key(column_index) + return CellKey(row_key, column_key) def _highlight_row(self, row_index: int) -> None: """Apply highlighting to the row at the given index, and emit event.""" self.refresh_row(row_index) is_valid_row = row_index < len(self.data) if is_valid_row: - self.emit_no_wait(DataTable.RowHighlighted(self, row_index)) + row_key = self._row_locations.get_key(row_index) + self.emit_no_wait(DataTable.RowHighlighted(self, row_index, row_key)) def _highlight_column(self, column_index: int) -> None: """Apply highlighting to the column at the given index, and emit event.""" self.refresh_column(column_index) if column_index < len(self.columns): - self.emit_no_wait(DataTable.ColumnHighlighted(self, column_index)) + column_key = self._column_locations.get_key(column_index) + self.emit_no_wait( + DataTable.ColumnHighlighted(self, column_index, column_key) + ) - def validate_cursor_cell(self, value: Coordinate) -> Coordinate: - return self._clamp_cursor_cell(value) + def validate_cursor_coordinate(self, value: Coordinate) -> Coordinate: + return self._clamp_cursor_coordinate(value) - def _clamp_cursor_cell(self, cursor_cell: Coordinate) -> Coordinate: - row, column = cursor_cell + def _clamp_cursor_coordinate(self, coordinate: Coordinate) -> Coordinate: + """Clamp a coordinate such that it falls within the boundaries of the table.""" + row, column = coordinate row = clamp(row, 0, self.row_count - 1) column = clamp(column, self.fixed_columns, len(self.columns) - 1) return Coordinate(row, column) @@ -609,34 +732,40 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # Refresh cells that were previously impacted by the cursor # but may no longer be. - row_index, column_index = self.cursor_cell if old == "cell": - self.refresh_cell(row_index, column_index) + self.refresh_coordinate(self.cursor_coordinate) elif old == "row": + row_index, _ = self.cursor_coordinate self.refresh_row(row_index) elif old == "column": + _, column_index = self.cursor_coordinate self.refresh_column(column_index) self._scroll_cursor_into_view() def _highlight_cursor(self) -> None: - row_index, column_index = self.cursor_cell + """Applies the appropriate highlighting and raises the appropriate + [Row|Column|Cell]Highlighted event for the given cursor coordinate + and cursor type.""" + row_index, column_index = self.cursor_coordinate cursor_type = self.cursor_type # Apply the highlighting to the newly relevant cells if cursor_type == "cell": - self._highlight_cell(self.cursor_cell) + self._highlight_coordinate(self.cursor_coordinate) elif cursor_type == "row": self._highlight_row(row_index) elif cursor_type == "column": self._highlight_column(column_index) - def _update_column_widths(self, column_keys: set[ColumnKey]) -> None: - for column_key in column_keys: + def _update_column_widths(self, updated_cells: set[CellKey]) -> None: + for row_key, column_key in updated_cells: column = self.columns.get(column_key) console = self.app.console label_width = measure(console, column.label, 1) content_width = column.content_width - new_content_width = measure(console, value, 1) + cell_value = self.data[row_key][column_key] + + new_content_width = measure(console, cell_value, 1) if new_content_width < content_width: cells_in_column = self._get_cells_in_column(column_key) @@ -652,7 +781,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): if row_index is None: continue for column, renderable in zip( - self._ordered_columns, self._get_row_renderables(row_index) + self.ordered_columns, self._get_row_renderables(row_index) ): content_width = measure(self.app.console, renderable, 1) column.content_width = max(column.content_width, content_width) @@ -662,26 +791,24 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): header_height = self.header_height if self.show_header else 0 self.virtual_size = Size( total_width, - len(self._y_offsets) + header_height, + self._total_row_height + header_height, ) - def _get_cell_region(self, row_index: int, column_index: int) -> Region: + def _get_cell_region(self, coordinate: Coordinate) -> Region: """Get the region of the cell at the given spatial coordinate.""" - valid_row = 0 <= row_index < len(self.rows) - valid_column = 0 <= column_index < len(self.columns) - valid_cell = valid_row and valid_column - if not valid_cell: + if not self.is_valid_coordinate(coordinate): return Region(0, 0, 0, 0) + row_index, column_index = coordinate row_key = self._row_locations.get_key(row_index) row = self.rows[row_key] # The x-coordinate of a cell is the sum of widths of cells to the left. - x = sum(column.render_width for column in self._ordered_columns[:column_index]) + x = sum(column.render_width for column in self.ordered_columns[:column_index]) column_key = self._column_locations.get_key(column_index) width = self.columns[column_key].render_width height = row.height - y = sum(ordered_row.height for ordered_row in self._ordered_rows[:row_index]) + y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index]) if self.show_header: y += self.header_height cell_region = Region(x, y, width, height) @@ -689,15 +816,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_row_region(self, row_index: int) -> Region: """Get the region of the row at the given index.""" - rows = self.rows - valid_row = 0 <= row_index < len(rows) - if not valid_row: + if not self.is_valid_row_index(row_index): return Region(0, 0, 0, 0) + rows = self.rows row_key = self._row_locations.get_key(row_index) row = rows[row_key] row_width = sum(column.render_width for column in self.columns.values()) - y = sum(ordered_row.height for ordered_row in self._ordered_rows[:row_index]) + y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index]) if self.show_header: y += self.header_height row_region = Region(0, y, row_width, row.height) @@ -705,16 +831,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_column_region(self, column_index: int) -> Region: """Get the region of the column at the given index.""" - columns = self.columns - valid_column = 0 <= column_index < len(columns) - if not valid_column: + if not self.is_valid_column_index(column_index): return Region(0, 0, 0, 0) - x = sum(column.render_width for column in self._ordered_columns[:column_index]) + columns = self.columns + x = sum(column.render_width for column in self.ordered_columns[:column_index]) column_key = self._column_locations.get_key(column_index) width = columns[column_key].render_width header_height = self.header_height if self.show_header else 0 - height = len(self._y_offsets) + header_height + height = self._total_row_height + header_height full_column_region = Region(x, 0, width, height) return full_column_region @@ -730,10 +855,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.rows.clear() if columns: self.columns.clear() - self._line_no = 0 self._require_update_dimensions = True - self.cursor_cell = Coordinate(0, 0) - self.hover_cell = Coordinate(0, 0) + self.cursor_coordinate = Coordinate(0, 0) + self.hover_coordinate = Coordinate(0, 0) self.refresh() def add_column( @@ -751,15 +875,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): of its current location in the DataTable (it could have moved after being added due to sorting or insertion/deletion of other columns). """ - text_label = Text.from_markup(label) if isinstance(label, str) else label - column_key = ColumnKey(key) column_index = len(self.columns) - content_width = measure(self.app.console, text_label, 1) + content_width = measure(self.app.console, label, 1) if width is None: column = Column( column_key, - text_label, + label, content_width, content_width=content_width, auto_width=True, @@ -767,7 +889,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): else: column = Column( column_key, - text_label, + label, width, content_width=content_width, ) @@ -804,13 +926,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # Map the key of this row to its current index self._row_locations[row_key] = row_index self.data[row_key] = { - column.key: Text(cell) if isinstance(cell, str) else cell - for column, cell in zip_longest(self._ordered_columns, cells) + column.key: cell + for column, cell in zip_longest(self.ordered_columns, cells) } self.rows[row_key] = Row(row_key, height) self._new_rows.add(row_key) self._require_update_dimensions = True - self.cursor_cell = self.cursor_cell + self.cursor_coordinate = self.cursor_coordinate # If a position has opened for the cursor to appear, where it previously # could not (e.g. when there's no data in the table), then a highlighted @@ -832,8 +954,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Returns: A list of the keys for the columns that were added. See - the `add_column` method docstring for more information on how - these keys are used. + the `add_column` method docstring for more information on how + these keys are used. """ column_keys = [] for label in labels: @@ -849,8 +971,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Returns: A list of the keys for the rows that were added. See - the `add_row` method docstring for more information on how - these keys are used. + the `add_row` method docstring for more information on how + these keys are used. """ row_keys = [] for row in rows: @@ -859,29 +981,35 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): return row_keys def on_idle(self) -> None: + """Runs when the message pump is empty. We use this for + some expensive calculations like re-computing dimensions of the + whole DataTable and re-computing column widths after some cells + have been updated. This is more efficient in the case of high + frequency updates, ensuring we only do expensive computations once.""" if self._require_update_dimensions: + # Add the new rows *before* updating the column widths, since + # cells in a new row may influence the final width of a column self._require_update_dimensions = False new_rows = self._new_rows.copy() self._new_rows.clear() - # Add the new rows *before* updating the column widths, since - # cells in a new row may influence the final width of a column self._update_dimensions(new_rows) - if self._updated_columns: - updated_columns = self._updated_columns.copy() - self._updated_columns.clear() - self._update_column_widths(updated_columns) - self.refresh() - def refresh_cell(self, row_index: int, column_index: int) -> None: - """Refresh a cell. + if self._updated_cells: + # Cell contents have already been updated at this point. + # Now we only need to worry about measuring column widths. + updated_columns = self._updated_cells.copy() + self._updated_cells.clear() + self._update_column_widths(updated_columns) + + def refresh_coordinate(self, coordinate: Coordinate) -> None: + """Refresh the cell at a coordinate. Args: - row_index: Row index. - column_index: Column index. + coordinate: The coordinate to refresh. """ - if row_index < 0 or column_index < 0: + if not self.is_valid_coordinate(coordinate): return - region = self._get_cell_region(row_index, column_index) + region = self._get_cell_region(coordinate) self._refresh_region(region) def refresh_row(self, row_index: int) -> None: @@ -890,7 +1018,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Args: row_index: The index of the row to refresh. """ - if row_index < 0 or row_index >= len(self.rows): + if not self.is_valid_row_index(row_index): return region = self._get_row_region(row_index) @@ -902,7 +1030,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Args: column_index: The index of the column to refresh. """ - if column_index < 0 or column_index >= len(self.columns): + if not self.is_valid_column_index(column_index): return region = self._get_column_region(column_index) @@ -917,8 +1045,45 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): region = region.translate(-self.scroll_offset) self.refresh(region) + def is_valid_row_index(self, row_index: int) -> bool: + """Return a boolean indicating whether the row_index is within table bounds. + + Args: + row_index: The row index to check. + + Returns: + True if the row index is within the bounds of the table. + """ + return 0 <= row_index < len(self.rows) + + def is_valid_column_index(self, column_index: int) -> bool: + """Return a boolean indicating whether the column_index is within table bounds. + + Args: + column_index: The column index to check. + + Returns: + True if the column index is within the bounds of the table. + """ + return 0 <= column_index < len(self.columns) + + def is_valid_coordinate(self, coordinate: Coordinate) -> bool: + """Return a boolean indicating whether the given coordinate is within table bounds. + + Args: + coordinate: The coordinate to validate. + + Returns: + True if the coordinate is within the bounds of the table. + """ + row_index, column_index = coordinate + return self.is_valid_row_index(row_index) and self.is_valid_column_index( + column_index + ) + @property - def _ordered_columns(self) -> list[Column]: + def ordered_columns(self) -> list[Column]: + """The list of Columns in the DataTable, ordered as they currently appear on screen.""" column_indices = range(len(self.columns)) column_keys = [ self._column_locations.get_key(index) for index in column_indices @@ -927,7 +1092,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): return ordered_columns @property - def _ordered_rows(self) -> list[Row]: + def ordered_rows(self) -> list[Row]: + """The list of Rows in the DataTable, ordered as they currently appear on screen.""" row_indices = range(self.row_count) ordered_rows = [] for row_index in row_indices: @@ -947,7 +1113,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): """ # TODO: We have quite a few back and forward key/index conversions, could probably reduce them - ordered_columns = self._ordered_columns + ordered_columns = self.ordered_columns if row_index == -1: row = [column.label for column in ordered_columns] return row @@ -1013,8 +1179,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): if is_fixed_style: style += self.get_component_styles("datatable--cursor-fixed").rich_style - # TODO: We can hoist `row_key` lookup waaay up to do it inside `_get_offsets` - # then just pass it through to here instead of the row_index. row_key = self._row_locations.get_key(row_index) column_key = self._column_locations.get_key(column_index) cell_cache_key = (row_key, column_key, style, cursor, hover, self._update_count) @@ -1039,12 +1203,12 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): cursor_location: Coordinate, hover_location: Coordinate, ) -> tuple[SegmentLines, SegmentLines]: - """Render a row in to lines for each cell. + """Render a single line from a row in the DataTable. Args: row_key: The identifying key for this row. line_no: Line number (y-coordinate) within row. 0 is the first strip of - cells in the row, line_no=1 is the next, and so on... + cells in the row, line_no=1 is the next line in the row, and so on... base_style: Base style of row. cursor_location: The location of the cursor in the DataTable. hover_location: The location of the hover cursor in the DataTable. @@ -1098,7 +1262,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): fixed_style += Style.from_meta({"fixed": True}) fixed_row = [] for column_index, column in enumerate( - self._ordered_columns[: self.fixed_columns] + self.ordered_columns[: self.fixed_columns] ): cell_location = Coordinate(row_index, column_index) fixed_cell_lines = render_cell( @@ -1115,7 +1279,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): else: fixed_row = [] - if row_key is None: + is_header_row = row_key is self._header_row_key + if is_header_row: row_style = self.get_component_styles("datatable--header").rich_style else: if self.zebra_stripes: @@ -1127,7 +1292,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): row_style = base_style scrollable_row = [] - for column_index, column in enumerate(self._ordered_columns): + for column_index, column in enumerate(self.ordered_columns): cell_location = Coordinate(row_index, column_index) cell_lines = render_cell( row_index, @@ -1143,7 +1308,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._row_render_cache[cache_key] = row_pair return row_pair - def _get_offsets(self, y: int) -> tuple[RowKey | None, int]: + def _get_offsets(self, y: int) -> tuple[RowKey, int]: """Get row key and line offset for a given line. Args: @@ -1156,7 +1321,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): y_offsets = self._y_offsets if self.show_header: if y < header_height: - return None, y + return self._header_row_key, y y -= header_height if y > len(y_offsets): raise LookupError("Y coord {y!r} is greater than total height") @@ -1164,7 +1329,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): return y_offsets[y] def _render_line(self, y: int, x1: int, x2: int, base_style: Style) -> Strip: - """Render a line in to a list of segments. + """Render a (possibly cropped) line in to a Strip (a list of segments + representing a horizontal line). Args: y: Y coordinate of line @@ -1173,7 +1339,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): base_style: Style to apply to line. Returns: - List of segments for rendering. + The Strip which represents this cropped line. """ width = self.size.width @@ -1188,8 +1354,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): x1, x2, width, - self.cursor_cell, - self.hover_cell, + self.cursor_coordinate, + self.hover_coordinate, base_style, self.cursor_type, self._show_hover_cursor, @@ -1202,12 +1368,11 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): row_key, y_offset_in_row, base_style, - cursor_location=self.cursor_cell, - hover_location=self.hover_cell, + cursor_location=self.cursor_coordinate, + hover_location=self.hover_coordinate, ) fixed_width = sum( - column.render_width - for column in self._ordered_columns[: self.fixed_columns] + column.render_width for column in self.ordered_columns[: self.fixed_columns] ) fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else [] @@ -1240,15 +1405,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): return self._render_line(y, scroll_x, scroll_x + width, self.rich_style) def on_mouse_move(self, event: events.MouseMove): + """If the hover cursor is visible, display it by extracting the row + and column metadata from the segments present in the cells.""" self._set_hover_cursor(True) meta = event.style.meta if meta and self.show_cursor and self.cursor_type != "none": try: - self.hover_cell = Coordinate(meta["row"], meta["column"]) + self.hover_coordinate = Coordinate(meta["row"], meta["column"]) except KeyError: pass def _get_fixed_offset(self) -> Spacing: + """Calculate the "fixed offset", that is the space to the top and left + that is occupied by fixed rows and columns respectively. Fixed rows and columns + are rows and columns that do not participate in scrolling.""" top = self.header_height if self.show_header else 0 top += sum( self.rows[self._row_locations.get_key(row_index)].height @@ -1256,8 +1426,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): if row_index in self.rows ) left = sum( - column.render_width - for column in self._ordered_columns[: self.fixed_columns] + column.render_width for column in self.ordered_columns[: self.fixed_columns] ) return Spacing(top, 0, 0, left) @@ -1297,6 +1466,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.refresh() def _scroll_cursor_into_view(self, animate: bool = False) -> None: + """When the cursor is at a boundary of the DataTable and moves out + of view, this method handles scrolling to ensure it remains visible.""" fixed_offset = self._get_fixed_offset() top, _, _, left = fixed_offset @@ -1307,7 +1478,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): x, y, width, height = self._get_column_region(self.cursor_column) region = Region(x, int(self.scroll_y) + top, width, height - top) else: - region = self._get_cell_region(self.cursor_row, self.cursor_column) + region = self._get_cell_region(self.cursor_coordinate) self.scroll_to_region(region, animate=animate, spacing=fixed_offset) @@ -1327,7 +1498,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): elif cursor_type == "row": self.refresh_row(self.hover_row) elif cursor_type == "cell": - self.refresh_cell(*self.hover_cell) + self.refresh_coordinate(self.hover_coordinate) def on_click(self, event: events.Click) -> None: self._set_hover_cursor(True) @@ -1336,7 +1507,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._emit_selected_message() meta = self.get_style_at(event.x, event.y).meta if meta: - self.cursor_cell = Coordinate(meta["row"], meta["column"]) + self.cursor_coordinate = Coordinate(meta["row"], meta["column"]) self._scroll_cursor_into_view(animate=True) event.stop() @@ -1344,7 +1515,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._set_hover_cursor(False) cursor_type = self.cursor_type if self.show_cursor and (cursor_type == "cell" or cursor_type == "row"): - self.cursor_cell = self.cursor_cell.up() + self.cursor_coordinate = self.cursor_coordinate.up() self._scroll_cursor_into_view() else: # If the cursor doesn't move up (e.g. column cursor can't go up), @@ -1355,7 +1526,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._set_hover_cursor(False) cursor_type = self.cursor_type if self.show_cursor and (cursor_type == "cell" or cursor_type == "row"): - self.cursor_cell = self.cursor_cell.down() + self.cursor_coordinate = self.cursor_coordinate.down() self._scroll_cursor_into_view() else: super().action_scroll_down() @@ -1364,7 +1535,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._set_hover_cursor(False) cursor_type = self.cursor_type if self.show_cursor and (cursor_type == "cell" or cursor_type == "column"): - self.cursor_cell = self.cursor_cell.right() + self.cursor_coordinate = self.cursor_coordinate.right() self._scroll_cursor_into_view(animate=True) else: super().action_scroll_right() @@ -1373,7 +1544,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._set_hover_cursor(False) cursor_type = self.cursor_type if self.show_cursor and (cursor_type == "cell" or cursor_type == "column"): - self.cursor_cell = self.cursor_cell.left() + self.cursor_coordinate = self.cursor_coordinate.left() self._scroll_cursor_into_view(animate=True) else: super().action_scroll_left() @@ -1385,19 +1556,23 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _emit_selected_message(self): """Emit the appropriate message for a selection based on the `cursor_type`.""" - cursor_cell = self.cursor_cell + cursor_coordinate = self.cursor_coordinate cursor_type = self.cursor_type + cell_key = self.coordinate_to_cell_key(cursor_coordinate) if cursor_type == "cell": self.emit_no_wait( DataTable.CellSelected( self, - self.get_cell_value(cursor_cell), - cursor_cell, + self.get_value_at(cursor_coordinate), + coordinate=cursor_coordinate, + cell_key=cell_key, ) ) elif cursor_type == "row": - row, _ = cursor_cell - self.emit_no_wait(DataTable.RowSelected(self, row)) + row_index, _ = cursor_coordinate + row_key, _ = cell_key + self.emit_no_wait(DataTable.RowSelected(self, row_index, row_key)) elif cursor_type == "column": - _, column = cursor_cell - self.emit_no_wait(DataTable.ColumnSelected(self, column)) + _, column_index = cursor_coordinate + _, column_key = cell_key + self.emit_no_wait(DataTable.ColumnSelected(self, column_index, column_key)) diff --git a/src/textual/widgets/data_table.py b/src/textual/widgets/data_table.py index d0316f387..429724361 100644 --- a/src/textual/widgets/data_table.py +++ b/src/textual/widgets/data_table.py @@ -1,5 +1,23 @@ """Make non-widget DataTable support classes available.""" -from ._data_table import Column, Row +from ._data_table import ( + Column, + Row, + RowKey, + ColumnKey, + CellKey, + CursorType, + CellType, + CellDoesNotExist, +) -__all__ = ["Column", "Row"] +__all__ = [ + "Column", + "Row", + "RowKey", + "ColumnKey", + "CellKey", + "CursorType", + "CellType", + "CellDoesNotExist", +] diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 1c53f9b6a..1555646de 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -1,12 +1,15 @@ +import asyncio + import pytest from rich.text import Text +from textual._wait import wait_for_idle from textual.app import App from textual.coordinate import Coordinate from textual.message import Message from textual.widgets import DataTable -from textual.widgets._data_table import ( - StringKey, +from textual.widgets._data_table import CellKey +from textual.widgets.data_table import ( CellDoesNotExist, RowKey, Row, @@ -155,14 +158,14 @@ async def test_add_rows_user_defined_keys(): assert isinstance(algernon_key, RowKey) # Ensure the data in the table is mapped as expected - first_row = {key_a: Text(ROWS[0][0]), key_b: Text(ROWS[0][1])} + first_row = {key_a: ROWS[0][0], key_b: ROWS[0][1]} assert table.data[algernon_key] == first_row assert table.data["algernon"] == first_row - second_row = {key_a: Text(ROWS[1][0]), key_b: Text(ROWS[1][1])} + second_row = {key_a: ROWS[1][0], key_b: ROWS[1][1]} assert table.data["charlie"] == second_row - third_row = {key_a: Text(ROWS[2][0]), key_b: Text(ROWS[2][1])} + third_row = {key_a: ROWS[2][0], key_b: ROWS[2][1]} assert table.data[auto_key] == third_row first_row = Row(algernon_key, height=1) @@ -179,7 +182,6 @@ async def test_add_columns(): assert len(table.columns) == 3 -# TODO: Ensure we can use the key to retrieve the column. async def test_add_columns_user_defined_keys(): app = DataTableApp() async with app.run_test(): @@ -193,19 +195,19 @@ async def test_clear(): app = DataTableApp() async with app.run_test(): table = app.query_one(DataTable) - assert table.cursor_cell == Coordinate(0, 0) - assert table.hover_cell == Coordinate(0, 0) + assert table.cursor_coordinate == Coordinate(0, 0) + assert table.hover_coordinate == Coordinate(0, 0) # Add some data and update cursor positions table.add_column("Column0") table.add_rows([["Row0"], ["Row1"], ["Row2"]]) - table.cursor_cell = Coordinate(1, 0) - table.hover_cell = Coordinate(2, 0) + table.cursor_coordinate = Coordinate(1, 0) + table.hover_coordinate = Coordinate(2, 0) # Ensure the cursor positions are reset to origin on clear() table.clear() - assert table.cursor_cell == Coordinate(0, 0) - assert table.hover_cell == Coordinate(0, 0) + assert table.cursor_coordinate == Coordinate(0, 0) + assert table.hover_coordinate == Coordinate(0, 0) # Ensure that the table has been cleared assert table.data == {} @@ -224,57 +226,176 @@ async def test_column_labels() -> None: table = app.query_one(DataTable) table.add_columns("1", "2", "3") actual_labels = [col.label for col in table.columns.values()] - expected_labels = [Text("1"), Text("2"), Text("3")] + expected_labels = ["1", "2", "3"] assert actual_labels == expected_labels -async def test_column_widths() -> None: +async def test_initial_column_widths() -> None: app = DataTableApp() - async with app.run_test() as pilot: + async with app.run_test(): table = app.query_one(DataTable) foo, bar = table.add_columns("foo", "bar") assert table.columns[foo].width == 3 assert table.columns[bar].width == 3 table.add_row("Hello", "World!") - await pilot.pause() + await wait_for_idle() assert table.columns[foo].content_width == 5 assert table.columns[bar].content_width == 6 table.add_row("Hello World!!!", "fo") - await pilot.pause() + await wait_for_idle() assert table.columns[foo].content_width == 14 assert table.columns[bar].content_width == 6 async def test_get_cell_value_returns_value_at_cell(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_row("TargetValue", key="R1") + assert table.get_cell_value("R1", "C1") == "TargetValue" + + +async def test_get_cell_value_invalid_row_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_row("TargetValue", key="R1") + with pytest.raises(CellDoesNotExist): + table.get_cell_value("INVALID_ROW", "C1") + + +async def test_get_cell_value_invalid_column_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_row("TargetValue", key="R1") + with pytest.raises(CellDoesNotExist): + table.get_cell_value("R1", "INVALID_COLUMN") + + +async def test_get_value_at_returns_value_at_cell(): app = DataTableApp() async with app.run_test(): table = app.query_one(DataTable) table.add_columns("A", "B") table.add_rows(ROWS) - assert table.get_cell_value(Coordinate(0, 0)) == Text("0/0") + assert table.get_value_at(Coordinate(0, 0)) == "0/0" -async def test_get_cell_value_exception(): +async def test_get_value_at_exception(): app = DataTableApp() async with app.run_test(): table = app.query_one(DataTable) table.add_columns("A", "B") table.add_rows(ROWS) with pytest.raises(CellDoesNotExist): - table.get_cell_value(Coordinate(9999, 0)) + table.get_value_at(Coordinate(9999, 0)) + + +async def test_update_cell_cell_exists(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("A", key="A") + table.add_row("1", key="1") + table.update_cell("1", "A", "NEW_VALUE") + assert table.get_cell_value("1", "A") == "NEW_VALUE" + + +async def test_update_cell_cell_doesnt_exist(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("A", key="A") + table.add_row("1", key="1") + with pytest.raises(CellDoesNotExist): + table.update_cell("INVALID", "CELL", "Value") + + +async def test_update_coordinate_coordinate_exists(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column_0, column_1 = table.add_columns("A", "B") + row_0, *_ = table.add_rows(ROWS) + + table.update_coordinate(Coordinate(0, 1), "newvalue") + assert table.get_cell_value(row_0, column_1) == "newvalue" + + +async def test_update_coordinate_coordinate_doesnt_exist(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_columns("A", "B") + table.add_rows(ROWS) + with pytest.raises(CellDoesNotExist): + table.update_coordinate(Coordinate(999, 999), "newvalue") + + +@pytest.mark.parametrize( + "label,new_value,new_content_width", + [ + # Shorter than initial cell value, larger than label => width remains same + ("A", "BB", 3), + # Larger than initial cell value, shorter than label => width remains that of label + ("1234567", "1234", 7), + # Shorter than initial cell value, shorter than label => width remains same + ("12345", "123", 5), + # Larger than initial cell value, larger than label => width updates to new cell value + ("12345", "123456789", 9), + ], +) +async def test_update_coordinate_column_width(label, new_value, new_content_width): + # Initial cell values are length 3. Let's update cell content and ensure + # that the width of the column is correct given the new cell content widths + # and the label of the column the cell is in. + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + key, _ = table.add_columns(label, "Column2") + table.add_rows(ROWS) + first_column = table.columns.get(key) + + table.update_coordinate(Coordinate(0, 0), new_value, update_width=True) + await wait_for_idle() + assert first_column.content_width == new_content_width + assert first_column.render_width == new_content_width + 2 + + +async def test_coordinate_to_cell_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column_key, _ = table.add_columns("Column0", "Column1") + row_key = table.add_row("A", "B") + + cell_key = table.coordinate_to_cell_key(Coordinate(0, 0)) + assert cell_key == CellKey(row_key, column_key) + + +async def test_coordinate_to_cell_key_invalid_coordinate(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + with pytest.raises(CellDoesNotExist): + table.coordinate_to_cell_key(Coordinate(9999, 9999)) def test_key_equals_equivalent_string(): text = "Hello" - key = StringKey(text) + key = RowKey(text) assert key == text assert hash(key) == hash(text) def test_key_doesnt_match_non_equal_string(): - key = StringKey("123") + key = ColumnKey("123") text = "laksjdlaskjd" assert key != text assert hash(key) != hash(text) @@ -293,9 +414,9 @@ def test_key_string_lookup(): # in tests how we intend for the keys to work for cache lookups. dictionary = { "foo": "bar", - StringKey("hello"): "world", + RowKey("hello"): "world", } assert dictionary["foo"] == "bar" - assert dictionary[StringKey("foo")] == "bar" + assert dictionary[RowKey("foo")] == "bar" assert dictionary["hello"] == "world" - assert dictionary[StringKey("hello")] == "world" + assert dictionary[RowKey("hello")] == "world"