diff --git a/.coveragerc b/.coveragerc index d16dd221a..087a1674f 100644 --- a/.coveragerc +++ b/.coveragerc @@ -7,3 +7,4 @@ exclude_lines = if TYPE_CHECKING: if __name__ == "__main__": @overload + __rich_repr__ diff --git a/CHANGELOG.md b/CHANGELOG.md index 6df3da39e..e08383776 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,23 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Added Shift+scroll wheel and ctrl+scroll wheel to scroll horizontally - Added `Tree.action_toggle_node` to toggle a node without selecting, and bound it to Space https://github.com/Textualize/textual/issues/1433 - Added `Tree.reset` to fully reset a `Tree` https://github.com/Textualize/textual/issues/1437 +- Added `DataTable.sort` to sort rows https://github.com/Textualize/textual/pull/1638 +- Added `DataTable.get_cell` to retrieve a cell by column/row keys https://github.com/Textualize/textual/pull/1638 +- Added `DataTable.get_cell_at` to retrieve a cell by coordinate https://github.com/Textualize/textual/pull/1638 +- Added `DataTable.update_cell` to update a cell by column/row keys https://github.com/Textualize/textual/pull/1638 +- Added `DataTable.update_cell_at` to update a cell at a coordinate https://github.com/Textualize/textual/pull/1638 +- Added `DataTable.ordered_rows` property to retrieve `Row`s as they're currently ordered https://github.com/Textualize/textual/pull/1638 +- Added `DataTable.ordered_columns` property to retrieve `Column`s as they're currently ordered https://github.com/Textualize/textual/pull/1638 +- Added `DataTable.coordinate_to_cell_key` to find the key for the cell at a coordinate https://github.com/Textualize/textual/pull/1638 +- Added `DataTable.is_valid_coordinate` https://github.com/Textualize/textual/pull/1638 +- Added `DataTable.is_valid_row_index` https://github.com/Textualize/textual/pull/1638 +- Added `DataTable.is_valid_column_index` https://github.com/Textualize/textual/pull/1638 +- Added attributes to events emitted from `DataTable` indicating row/column/cell keys https://github.com/Textualize/textual/pull/1638 +- Added `DataTable.get_row` to retrieve the values from a row by key https://github.com/Textualize/textual/pull/1786 +- Added `DataTable.get_row_at` to retrieve the values from a row by index https://github.com/Textualize/textual/pull/1786 +- Added `DataTable.get_column` to retrieve the values from a column by key https://github.com/Textualize/textual/pull/1786 +- Added `DataTable.get_column_at` to retrieve the values from a column by index https://github.com/Textualize/textual/pull/1786 +- Added `DataTable.HeaderSelected` which is posted when header label clicked https://github.com/Textualize/textual/pull/1788 - Added `DOMNode.watch` and `DOMNode.is_attached` methods https://github.com/Textualize/textual/pull/1750 - Added `DOMNode.css_tree` which is a renderable that shows the DOM and CSS https://github.com/Textualize/textual/pull/1778 - Added `DOMNode.children_view` which is a view on to a nodes children list, use for querying https://github.com/Textualize/textual/pull/1778 @@ -27,6 +44,21 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Breaking change: `TreeNode` can no longer be imported from `textual.widgets`; it is now available via `from textual.widgets.tree import TreeNode`. https://github.com/Textualize/textual/pull/1637 - `Tree` now shows a (subdued) cursor for a highlighted node when focus has moved elsewhere https://github.com/Textualize/textual/issues/1471 +- `DataTable.add_row` now accepts `key` argument to uniquely identify the row https://github.com/Textualize/textual/pull/1638 +- `DataTable.add_column` now accepts `key` argument to uniquely identify the column https://github.com/Textualize/textual/pull/1638 +- `DataTable.add_row` and `DataTable.add_column` now return lists of keys identifying the added rows/columns https://github.com/Textualize/textual/pull/1638 +- Breaking change: `DataTable.get_cell_value` renamed to `DataTable.get_value_at` https://github.com/Textualize/textual/pull/1638 +- `DataTable.row_count` is now a property https://github.com/Textualize/textual/pull/1638 +- Breaking change: `DataTable.cursor_cell` renamed to `DataTable.cursor_coordinate` https://github.com/Textualize/textual/pull/1638 + - The method `validate_cursor_cell` was renamed to `validate_cursor_coordinate`. + - The method `watch_cursor_cell` was renamed to `watch_cursor_coordinate`. +- Breaking change: `DataTable.hover_cell` renamed to `DataTable.hover_coordinate` https://github.com/Textualize/textual/pull/1638 + - The method `validate_hover_cell` was renamed to `validate_hover_coordinate`. +- Breaking change: `DataTable.data` structure changed, and will be made private in upcoming release https://github.com/Textualize/textual/pull/1638 +- Breaking change: `DataTable.refresh_cell` was renamed to `DataTable.refresh_coordinate` https://github.com/Textualize/textual/pull/1638 +- Breaking change: `DataTable.get_row_height` now takes a `RowKey` argument instead of a row index https://github.com/Textualize/textual/pull/1638 +- Breaking change: `DataTable.data` renamed to `DataTable._data` (it's now private) https://github.com/Textualize/textual/pull/1786 +- The `_filter` module was made public (now called `filter`) https://github.com/Textualize/textual/pull/1638 - Breaking change: renamed `Checkbox` to `Switch` https://github.com/Textualize/textual/issues/1746 - `App.install_screen` name is no longer optional https://github.com/Textualize/textual/pull/1778 - `App.query` now only includes the current screen https://github.com/Textualize/textual/pull/1778 @@ -47,6 +79,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Fixed issue with renderable width calculation https://github.com/Textualize/textual/issues/1685 - Fixed issue with app not processing Paste event https://github.com/Textualize/textual/issues/1666 - Fixed glitch with view position with auto width inputs https://github.com/Textualize/textual/issues/1693 +- Fixed `DataTable` "selected" events containing wrong coordinates when mouse was used https://github.com/Textualize/textual/issues/1723 ### Removed diff --git a/docs/examples/widgets/data_table.py b/docs/examples/widgets/data_table.py index 74d9b76e4..f409bb45c 100644 --- a/docs/examples/widgets/data_table.py +++ b/docs/examples/widgets/data_table.py @@ -1,18 +1,18 @@ -import csv -import io - from textual.app import App, ComposeResult from textual.widgets import DataTable -CSV = """lane,swimmer,country,time -4,Joseph Schooling,Singapore,50.39 -2,Michael Phelps,United States,51.14 -5,Chad le Clos,South Africa,51.14 -6,László Cseh,Hungary,51.14 -3,Li Zhuhao,China,51.26 -8,Mehdy Metella,France,51.58 -7,Tom Shields,United States,51.73 -1,Aleksandr Sadovnikov,Russia,51.84""" +ROWS = [ + ("lane", "swimmer", "country", "time"), + (4, "Joseph Schooling", "Singapore", 50.39), + (2, "Michael Phelps", "United States", 51.14), + (5, "Chad le Clos", "South Africa", 51.14), + (6, "László Cseh", "Hungary", 51.14), + (3, "Li Zhuhao", "China", 51.26), + (8, "Mehdy Metella", "France", 51.58), + (7, "Tom Shields", "United States", 51.73), + (1, "Aleksandr Sadovnikov", "Russia", 51.84), + (10, "Darren Burns", "Scotland", 51.84), +] class TableApp(App): @@ -21,11 +21,11 @@ class TableApp(App): def on_mount(self) -> None: table = self.query_one(DataTable) - rows = csv.reader(io.StringIO(CSV)) + rows = iter(ROWS) table.add_columns(*next(rows)) table.add_rows(rows) +app = TableApp() if __name__ == "__main__": - app = TableApp() app.run() diff --git a/docs/widgets/data_table.md b/docs/widgets/data_table.md index 9dd918d9d..7cc0a00d6 100644 --- a/docs/widgets/data_table.md +++ b/docs/widgets/data_table.md @@ -22,17 +22,17 @@ The example below populates a table with CSV data. ## Reactive Attributes -| Name | Type | Default | Description | -|-----------------|---------------------------------------------|--------------------|---------------------------------------------------------| -| `show_header` | `bool` | `True` | Show the table header | -| `fixed_rows` | `int` | `0` | Number of fixed rows (rows which do not scroll) | -| `fixed_columns` | `int` | `0` | Number of fixed columns (columns which do not scroll) | -| `zebra_stripes` | `bool` | `False` | Display alternating colors on rows | -| `header_height` | `int` | `1` | Height of header row | -| `show_cursor` | `bool` | `True` | Show the cursor | -| `cursor_type` | `str` | `"cell"` | One of `"cell"`, `"row"`, `"column"`, or `"none"` | -| `cursor_cell` | [Coordinate][textual.coordinate.Coordinate] | `Coordinate(0, 0)` | The coordinates of the cell the cursor is currently on | -| `hover_cell` | [Coordinate][textual.coordinate.Coordinate] | `Coordinate(0, 0)` | The coordinates of the cell the _mouse_ cursor is above | +| Name | Type | Default | Description | +|---------------------|---------------------------------------------|--------------------|-------------------------------------------------------| +| `show_header` | `bool` | `True` | Show the table header | +| `fixed_rows` | `int` | `0` | Number of fixed rows (rows which do not scroll) | +| `fixed_columns` | `int` | `0` | Number of fixed columns (columns which do not scroll) | +| `zebra_stripes` | `bool` | `False` | Display alternating colors on rows | +| `header_height` | `int` | `1` | Height of header row | +| `show_cursor` | `bool` | `True` | Show the cursor | +| `cursor_type` | `str` | `"cell"` | One of `"cell"`, `"row"`, `"column"`, or `"none"` | +| `cursor_coordinate` | [Coordinate][textual.coordinate.Coordinate] | `Coordinate(0, 0)` | The current coordinate of the cursor | +| `hover_coordinate` | [Coordinate][textual.coordinate.Coordinate] | `Coordinate(0, 0)` | The coordinate the _mouse_ cursor is above | ## Messages @@ -48,6 +48,8 @@ The example below populates a table with CSV data. ### ::: textual.widgets.DataTable.ColumnSelected +### ::: textual.widgets.DataTable.HeaderSelected + ## Bindings The data table widget defines directly the following bindings: diff --git a/src/textual/_styles_cache.py b/src/textual/_styles_cache.py index 77cbfb4de..c760fcfe3 100644 --- a/src/textual/_styles_cache.py +++ b/src/textual/_styles_cache.py @@ -8,7 +8,7 @@ from rich.segment import Segment from rich.style import Style from ._border import get_box, render_row -from ._filter import LineFilter +from .filter import LineFilter from ._opacity import _apply_opacity from ._segment_tools import line_pad, line_trim from .color import Color diff --git a/src/textual/_two_way_dict.py b/src/textual/_two_way_dict.py new file mode 100644 index 000000000..d733edcdc --- /dev/null +++ b/src/textual/_two_way_dict.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import Generic, TypeVar + +Key = TypeVar("Key") +Value = TypeVar("Value") + + +class TwoWayDict(Generic[Key, Value]): + """ + A two-way mapping offering O(1) access in both directions. + + Wraps two dictionaries and uses them to provide efficient access to + both values (given keys) and keys (given values). + """ + + def __init__(self, initial: dict[Key, Value]) -> None: + self._forward: dict[Key, Value] = initial + self._reverse: dict[Value, Key] = {value: key for key, value in initial.items()} + + def __setitem__(self, key: Key, value: Value) -> None: + # TODO: Duplicate values need to be managed to ensure consistency, + # decide on best approach. + self._forward.__setitem__(key, value) + self._reverse.__setitem__(value, key) + + def __delitem__(self, key: Key) -> None: + value = self._forward[key] + self._forward.__delitem__(key) + self._reverse.__delitem__(value) + + def get(self, key: Key) -> Value: + """Given a key, efficiently lookup and return the associated value. + + Args: + key: The key + + Returns: + The value + """ + return self._forward.get(key) + + def get_key(self, value: Value) -> Key: + """Given a value, efficiently lookup and return the associated key. + + Args: + value: The value + + Returns: + The key + """ + return self._reverse.get(value) + + def contains_value(self, value: Value) -> bool: + """Check if `value` is a value within this TwoWayDict. + + Args: + value: The value to check. + + Returns: + True if the value is within the values of this dict. + """ + return value in self._reverse + + def __len__(self): + return len(self._forward) + + def __contains__(self, item: Key) -> bool: + return item in self._forward diff --git a/src/textual/app.py b/src/textual/app.py index 0a425c839..6bca62824 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -48,7 +48,6 @@ from ._asyncio import create_task from ._callback import invoke from ._context import active_app from ._event_broker import NoHandler, extract_handler_actions -from ._filter import LineFilter, Monochrome from ._path import _make_path_object_relative from ._wait import wait_for_idle from .actions import SkipAction @@ -62,6 +61,7 @@ from .driver import Driver from .drivers.headless_driver import HeadlessDriver from .features import FeatureFlag, parse_features from .file_monitor import FileMonitor +from .filter import LineFilter, Monochrome from .geometry import Offset, Region, Size from .keys import REPLACED_KEYS, _get_key_display from .messages import CallbackType diff --git a/src/textual/_filter.py b/src/textual/filter.py similarity index 100% rename from src/textual/_filter.py rename to src/textual/filter.py diff --git a/src/textual/message.py b/src/textual/message.py index 5f46cc97a..7ed72789b 100644 --- a/src/textual/message.py +++ b/src/textual/message.py @@ -10,7 +10,6 @@ from .case import camel_to_snake if TYPE_CHECKING: from .message_pump import MessagePump - from .widget import Widget @rich.repr.auto diff --git a/src/textual/render.py b/src/textual/render.py index 8911c4263..c1003b062 100644 --- a/src/textual/render.py +++ b/src/textual/render.py @@ -1,5 +1,6 @@ from __future__ import annotations +from rich.cells import cell_len from rich.console import Console, RenderableType from rich.protocol import rich_cast @@ -22,6 +23,9 @@ def measure( Returns: Width in cells """ + if isinstance(renderable, str): + return cell_len(renderable) + width = default renderable = rich_cast(renderable) get_console_width = getattr(renderable, "__rich_measure__", None) diff --git a/src/textual/strip.py b/src/textual/strip.py index c10a649a9..53d8a8e9f 100644 --- a/src/textual/strip.py +++ b/src/textual/strip.py @@ -9,7 +9,7 @@ from rich.segment import Segment from rich.style import Style, StyleType from ._cache import FIFOCache -from ._filter import LineFilter +from .filter import LineFilter from ._segment_tools import index_to_cell_position diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 4cc540c94..b9e48d326 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -1,8 +1,10 @@ from __future__ import annotations -from dataclasses import dataclass, field +import functools +from dataclasses import dataclass from itertools import chain, zip_longest -from typing import Generic, Iterable, cast +from operator import itemgetter +from typing import Any, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast import rich.repr from rich.console import RenderableType @@ -11,11 +13,12 @@ from rich.protocol import is_renderable from rich.segment import Segment from rich.style import Style from rich.text import Text, TextType -from typing_extensions import ClassVar, Literal, TypeVar +from typing_extensions import Literal, TypeAlias -from .. import events, messages +from .. import events from .._cache import LRUCache from .._segment_tools import line_crop +from .._two_way_dict import TwoWayDict from .._types import SegmentLines from ..binding import Binding, BindingType from ..coordinate import Coordinate @@ -26,40 +29,140 @@ from ..render import measure from ..scroll_view import ScrollView from ..strip import Strip +CellCacheKey: TypeAlias = "tuple[RowKey, ColumnKey, Style, bool, bool, int]" +LineCacheKey: TypeAlias = ( + "tuple[int, int, int, int, Coordinate, Coordinate, Style, CursorType, bool, int]" +) +RowCacheKey: TypeAlias = ( + "tuple[RowKey, int, Style, Coordinate, Coordinate, CursorType, bool, bool, int]" +) CursorType = Literal["cell", "row", "column", "none"] -CELL: CursorType = "cell" CellType = TypeVar("CellType") class CellDoesNotExist(Exception): - pass + """The cell key/index was invalid. + + Raised when the user supplies coordinates or cell keys which + do not exist in the DataTable.""" -def default_cell_formatter(obj: object) -> RenderableType | None: - """Format a cell in to a renderable. +class RowDoesNotExist(Exception): + """Raised when the user supplies a row index or row key which does + not exist in the DataTable (e.g. out of bounds index, invalid key)""" + + +class ColumnDoesNotExist(Exception): + """Raised when the user supplies a column index or column key which does + not exist in the DataTable (e.g. out of bounds index, invalid key)""" + + +class DuplicateKey(Exception): + """The key supplied already exists. + + Raised when the RowKey or ColumnKey provided already refers to + an existing row or column in the DataTable. Keys must be unique.""" + + +@functools.total_ordering +class StringKey: + """An object used as a key in a mapping. + + It can optionally wrap a string, + and lookups into a map using the object behave the same as lookups using + the string itself.""" + + value: str | None + + def __init__(self, value: str | None = None): + self.value = value + + def __hash__(self): + # If a string is supplied, we use the hash of the string. If no string was + # supplied, we use the default hash to ensure uniqueness amongst instances. + return hash(self.value) if self.value is not None else id(self) + + def __eq__(self, other: object) -> bool: + # Strings will match Keys containing the same string value. + # Otherwise, you'll need to supply the exact same key object. + if isinstance(other, str): + return self.value == other + elif isinstance(other, StringKey): + if self.value is not None and other.value is not None: + return self.value == other.value + else: + return hash(self) == hash(other) + else: + raise NotImplemented + + def __lt__(self, other): + if isinstance(other, str): + return self.value < other + elif isinstance(other, StringKey): + return self.value < other.value + else: + raise NotImplemented + + def __rich_repr__(self): + yield "value", self.value + + +class RowKey(StringKey): + """Uniquely identifies a row in the DataTable. + + Even if the visual location + of the row changes due to sorting or other modifications, a key will always + refer to the same row.""" + + +class ColumnKey(StringKey): + """Uniquely identifies a column in the DataTable. + + Even if the visual location + of the column changes due to sorting or other modifications, a key will always + refer to the same column.""" + + +class CellKey(NamedTuple): + """A unique identifier for a cell in the DataTable. + + Even if the cell changes + visual location (i.e. moves to a different coordinate in the table), this key + can still be used to retrieve it, regardless of where it currently is.""" + + row_key: RowKey + column_key: ColumnKey + + def __rich_repr__(self): + yield "row_key", self.row_key + yield "column_key", self.column_key + + +def default_cell_formatter(obj: object) -> RenderableType: + """Convert a cell into a Rich renderable for display. Args: obj: Data for a cell. Returns: - A renderable or None if the object could not be rendered. + A renderable to be displayed which represents the data. """ if isinstance(obj, str): return Text.from_markup(obj) + if isinstance(obj, float): + return f"{obj:.2f}" if not is_renderable(obj): - return None + return str(obj) return cast(RenderableType, obj) @dataclass class Column: - """Table column.""" + """Metadata for a column in the DataTable.""" + key: ColumnKey label: Text width: int = 0 - visible: bool = False - index: int = 0 - content_width: int = 0 auto_width: bool = False @@ -75,12 +178,10 @@ class Column: @dataclass class Row: - """Table row.""" + """Metadata for a row in the DataTable.""" - index: int + key: RowKey height: int - y: int - cell_renderables: list[RenderableType] = field(default_factory=list) class DataTable(ScrollView, Generic[CellType], can_focus=True): @@ -182,131 +283,183 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): zebra_stripes = Reactive(False) header_height = Reactive(1) show_cursor = Reactive(True) - cursor_type = Reactive(CELL) + cursor_type = Reactive("cell") - cursor_cell: Reactive[Coordinate] = Reactive( + cursor_coordinate: Reactive[Coordinate] = Reactive( Coordinate(0, 0), repaint=False, always_update=True ) - hover_cell: Reactive[Coordinate] = Reactive(Coordinate(0, 0), repaint=False) + hover_coordinate: Reactive[Coordinate] = Reactive(Coordinate(0, 0), repaint=False) class CellHighlighted(Message, bubble=True): """Posted when the cursor moves to highlight a new cell. - It's only relevant when the `cursor_type` is `"cell"`. - It's also posted when the cell cursor is re-enabled (by setting `show_cursor=True`), - and when the cursor type is changed to `"cell"`. Can be handled using - `on_data_table_cell_highlighted` in a subclass of `DataTable` or in a parent - widget in the DOM. - Attributes: - value: The value in the highlighted cell. - coordinate: The coordinate of the highlighted cell. + This is only relevant when the `cursor_type` is `"cell"`. + It's also posted when the cell cursor is + re-enabled (by setting `show_cursor=True`), and when the cursor type is + changed to `"cell"`. Can be handled using `on_data_table_cell_highlighted` in + a subclass of `DataTable` or in a parent widget in the DOM. """ def __init__( - self, sender: DataTable, value: CellType, coordinate: Coordinate + self, + sender: DataTable, + value: CellType, + coordinate: Coordinate, + cell_key: CellKey, ) -> None: self.value: CellType = value + """The value in the highlighted cell.""" self.coordinate: Coordinate = coordinate + """The coordinate of the highlighted cell.""" + self.cell_key: CellKey = cell_key + """The key for the highlighted cell.""" super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "value", self.value yield "coordinate", self.coordinate + yield "cell_key", self.cell_key class CellSelected(Message, bubble=True): """Posted by the `DataTable` widget when a cell is selected. - It's only relevant when the `cursor_type` is `"cell"`. Can be handled using + + This is only relevant when the `cursor_type` is `"cell"`. Can be handled using `on_data_table_cell_selected` in a subclass of `DataTable` or in a parent widget in the DOM. - - Attributes: - value: The value in the cell that was selected. - coordinate: The coordinate of the cell that was selected. """ def __init__( - self, sender: DataTable, value: CellType, coordinate: Coordinate + self, + sender: DataTable, + value: CellType, + coordinate: Coordinate, + cell_key: CellKey, ) -> None: self.value: CellType = value + """The value in the cell that was selected.""" self.coordinate: Coordinate = coordinate + """The coordinate of the cell that was selected.""" + self.cell_key: CellKey = cell_key + """The key for the selected cell.""" super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "value", self.value yield "coordinate", self.coordinate + yield "cell_key", self.cell_key class RowHighlighted(Message, bubble=True): - """Posted when a row is highlighted. This message is only posted when the - `cursor_type` is set to `"row"`. Can be handled using `on_data_table_row_highlighted` - in a subclass of `DataTable` or in a parent widget in the DOM. + """Posted when a row is highlighted. - Attributes: - cursor_row: The y-coordinate of the cursor that highlighted the row. + This message is only posted when the + `cursor_type` is set to `"row"`. Can be handled using + `on_data_table_row_highlighted` in a subclass of `DataTable` or in a parent + widget in the DOM. """ - def __init__(self, sender: DataTable, cursor_row: int) -> None: + def __init__(self, sender: DataTable, cursor_row: int, row_key: RowKey) -> None: self.cursor_row: int = cursor_row + """The y-coordinate of the cursor that highlighted the row.""" + self.row_key: RowKey = row_key + """The key of the row that was highlighted.""" super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "cursor_row", self.cursor_row + yield "row_key", self.row_key class RowSelected(Message, bubble=True): - """Posted when a row is selected. This message is only posted when the + """Posted when a row is selected. + + This message is only posted when the `cursor_type` is set to `"row"`. Can be handled using `on_data_table_row_selected` in a subclass of `DataTable` or in a parent widget in the DOM. - - Attributes: - cursor_row: The y-coordinate of the cursor that made the selection. """ - def __init__(self, sender: DataTable, cursor_row: int) -> None: + def __init__(self, sender: DataTable, cursor_row: int, row_key: RowKey) -> None: self.cursor_row: int = cursor_row + """The y-coordinate of the cursor that made the selection.""" + self.row_key: RowKey = row_key + """The key of the row that was selected.""" super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "cursor_row", self.cursor_row + yield "row_key", self.row_key class ColumnHighlighted(Message, bubble=True): - """Posted when a column is highlighted. This message is only posted when the + """Posted when a column is highlighted. + + This message is only posted when the `cursor_type` is set to `"column"`. Can be handled using `on_data_table_column_highlighted` in a subclass of `DataTable` or in a parent widget in the DOM. - - Attributes: - cursor_column: The x-coordinate of the column that was highlighted. """ - def __init__(self, sender: DataTable, cursor_column: int) -> None: + def __init__( + self, sender: DataTable, cursor_column: int, column_key: ColumnKey + ) -> None: self.cursor_column: int = cursor_column + """The x-coordinate of the column that was highlighted.""" + self.column_key = column_key + """The key of the column that was highlighted.""" super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "cursor_column", self.cursor_column + yield "column_key", self.column_key class ColumnSelected(Message, bubble=True): - """Posted when a column is selected. This message is only posted when the + """Posted when a column is selected. + + This message is only posted when the `cursor_type` is set to `"column"`. Can be handled using `on_data_table_column_selected` in a subclass of `DataTable` or in a parent widget in the DOM. - - Attributes: - cursor_column: The x-coordinate of the column that was selected. """ - def __init__(self, sender: DataTable, cursor_column: int) -> None: + def __init__( + self, sender: DataTable, cursor_column: int, column_key: ColumnKey + ) -> None: self.cursor_column: int = cursor_column + """The x-coordinate of the column that was selected.""" + self.column_key = column_key + """The key of the column that was selected.""" super().__init__(sender) def __rich_repr__(self) -> rich.repr.Result: yield "sender", self.sender yield "cursor_column", self.cursor_column + yield "column_key", self.column_key + + class HeaderSelected(Message, bubble=True): + """Posted when a column header/label is clicked.""" + + def __init__( + self, + sender: DataTable, + column_key: ColumnKey, + column_index: int, + label: Text, + ): + self.column_key = column_key + """The key for the column.""" + self.column_index = column_index + """The index for the column.""" + self.label = label + """The text of the label.""" + super().__init__(sender) + + def __rich_repr__(self) -> rich.repr.Result: + yield "sender", self.sender + yield "column_key", self.column_key + yield "label", self.label.plain def __init__( self, @@ -322,69 +475,280 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): classes: str | None = None, ) -> None: super().__init__(name=name, id=id, classes=classes) + self._data: dict[RowKey, dict[ColumnKey, CellType]] = {} + """Contains the cells of the table, indexed by row key and column key. + The final positioning of a cell on screen cannot be determined solely by this + structure. Instead, we must check _row_locations and _column_locations to find + where each cell currently resides in space.""" + + self.columns: dict[ColumnKey, Column] = {} + """Metadata about the columns of the table, indexed by their key.""" + self.rows: dict[RowKey, Row] = {} + """Metadata about the rows of the table, indexed by their key.""" + + # Keep tracking of key -> index for rows/cols. These allow us to retrieve, + # given a row or column key, the index that row or column is currently + # present at, and mean that rows and columns are location independent - they + # can move around without requiring us to modify the underlying data. + self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict({}) + """Maps row keys to row indices which represent row order.""" + self._column_locations: TwoWayDict[ColumnKey, int] = TwoWayDict({}) + """Maps column keys to column indices which represent column order.""" - self.columns: list[Column] = [] - self.rows: dict[int, Row] = {} - self.data: dict[int, list[CellType]] = {} - self.row_count = 0 - self._y_offsets: list[tuple[int, int]] = [] self._row_render_cache: LRUCache[ - tuple[int, int, Style, int, int], tuple[SegmentLines, SegmentLines] - ] - self._row_render_cache = LRUCache(1000) - self._cell_render_cache: LRUCache[ - tuple[int, int, Style, bool, bool], SegmentLines - ] - self._cell_render_cache = LRUCache(10000) - self._line_cache: LRUCache[tuple[int, int, int, int, int, int, Style], Strip] - self._line_cache = LRUCache(1000) + RowCacheKey, tuple[SegmentLines, SegmentLines] + ] = LRUCache(1000) + """For each row (a row can have a height of multiple lines), we maintain a + cache of the fixed and scrollable lines within that row to minimise how often + we need to re-render it. """ + self._cell_render_cache: LRUCache[CellCacheKey, SegmentLines] = LRUCache(10000) + """Cache for individual cells.""" + self._line_cache: LRUCache[LineCacheKey, Strip] = LRUCache(1000) + """Cache for lines within rows.""" + self._offset_cache: LRUCache[int, list[tuple[RowKey, int]]] = LRUCache(1) + """Cached y_offset - key is update_count - see y_offsets property for more + information """ + self._ordered_row_cache: LRUCache[tuple[int, int], list[Row]] = LRUCache(1) + """Caches row ordering - key is (num_rows, update_count).""" - self._line_no = 0 self._require_update_dimensions: bool = False - self._new_rows: set[int] = set() + """Set to re-calculate dimensions on idle.""" + self._new_rows: set[RowKey] = set() + """Tracking newly added rows to be used in calculation of dimensions on idle.""" + self._updated_cells: set[CellKey] = set() + """Track which cells were updated, so that we can refresh them once on idle.""" self.show_header = show_header - self.fixed_rows = fixed_rows - self.fixed_columns = fixed_columns - self.zebra_stripes = zebra_stripes + """Show/hide the header row (the row of column labels).""" self.header_height = header_height + """The height of the header row (the row of column labels).""" + self.fixed_rows = fixed_rows + """The number of rows to fix (prevented from scrolling).""" + self.fixed_columns = fixed_columns + """The number of columns to fix (prevented from scrolling).""" + self.zebra_stripes = zebra_stripes + """Apply zebra effect on row backgrounds (light, dark, light, dark, ...).""" self.show_cursor = show_cursor + """Show/hide both the keyboard and hover cursor.""" self._show_hover_cursor = False + """Used to hide the mouse hover cursor when the user uses the keyboard.""" + self._update_count = 0 + """Number of update (INCLUDING SORT) operations so far. Used for cache invalidation.""" + self._header_row_key = RowKey() + """The header is a special row - not part of the data. Retrieve via this key.""" @property def hover_row(self) -> int: - return self.hover_cell.row + """The index of the row that the mouse cursor is currently hovering above.""" + return self.hover_coordinate.row @property def hover_column(self) -> int: - return self.hover_cell.column + """The index of the column that the mouse cursor is currently hovering above.""" + return self.hover_coordinate.column @property def cursor_row(self) -> int: - return self.cursor_cell.row + """The index of the row that the DataTable cursor is currently on.""" + return self.cursor_coordinate.row @property def cursor_column(self) -> int: - return self.cursor_cell.column + """The index of the column that the DataTable cursor is currently on.""" + return self.cursor_coordinate.column - def get_cell_value(self, coordinate: Coordinate) -> CellType: - """Get the value from the cell at the given coordinate. + @property + def row_count(self) -> int: + """The number of rows currently present in the DataTable.""" + return len(self.rows) + + @property + def _y_offsets(self) -> list[tuple[RowKey, int]]: + """Contains a 2-tuple for each line (not row!) of the DataTable. Given a + y-coordinate, we can index into this list to find which row that y-coordinate + lands on, and the y-offset *within* that row. The length of the returned list + is therefore the total height of all rows within the DataTable.""" + y_offsets = [] + if self._update_count in self._offset_cache: + y_offsets = self._offset_cache[self._update_count] + else: + for row in self.ordered_rows: + y_offsets += [(row.key, y) for y in range(row.height)] + self._offset_cache = y_offsets + return y_offsets + + @property + def _total_row_height(self) -> int: + """The total height of all rows within the DataTable""" + return len(self._y_offsets) + + def update_cell( + self, + row_key: RowKey | str, + column_key: ColumnKey | str, + value: CellType, + *, + update_width: bool = False, + ) -> None: + """Update the cell identified by the specified row key and column key. + + Args: + row_key: The key identifying the row. + column_key: The key identifying the column. + value: The new value to put inside the cell. + update_width: Whether to resize the column width to accommodate + for the new cell content. + + Raises: + CellDoesNotExist: When the supplied `row_key` and `column_key` + cannot be found in the table. + """ + if isinstance(row_key, str): + row_key = RowKey(row_key) + if isinstance(column_key, str): + column_key = ColumnKey(column_key) + + try: + self._data[row_key][column_key] = value + except KeyError: + raise CellDoesNotExist( + f"No cell exists for row_key={row_key!r}, column_key={column_key!r}." + ) from None + self._update_count += 1 + + # Recalculate widths if necessary + if update_width: + self._updated_cells.add(CellKey(row_key, column_key)) + self._require_update_dimensions = True + + self.refresh() + + def update_cell_at( + self, coordinate: Coordinate, value: CellType, *, update_width: bool = False + ) -> None: + """Update the content inside the cell currently occupying the given coordinate. + + Args: + coordinate: The coordinate to update the cell at. + value: The new value to place inside the cell. + update_width: Whether to resize the column width to accommodate + for the new cell content. + """ + if not self.is_valid_coordinate(coordinate): + raise CellDoesNotExist(f"Coordinate {coordinate!r} is invalid.") + + row_key, column_key = self.coordinate_to_cell_key(coordinate) + self.update_cell(row_key, column_key, value, update_width=update_width) + + def get_cell(self, row_key: RowKey, column_key: ColumnKey) -> CellType: + """Given a row key and column key, return the value of the corresponding cell. + + Args: + row_key: The row key of the cell. + column_key: The column key of the cell. + + Returns: + The value of the cell identified by the row and column keys. + """ + try: + cell_value = self._data[row_key][column_key] + except KeyError: + raise CellDoesNotExist( + f"No cell exists for row_key={row_key!r}, column_key={column_key!r}." + ) + return cell_value + + def get_cell_at(self, coordinate: Coordinate) -> CellType: + """Get the value from the cell occupying the given coordinate. Args: coordinate: The coordinate to retrieve the value from. Returns: - The value of the cell. + The value of the cell at the coordinate. Raises: CellDoesNotExist: If there is no cell with the given coordinate. """ - row, column = coordinate - try: - cell_value = self.data[row][column] - except KeyError: - raise CellDoesNotExist(f"No cell exists at {coordinate!r}") from None - return cell_value + row_key, column_key = self.coordinate_to_cell_key(coordinate) + return self.get_cell(row_key, column_key) + + def get_row(self, row_key: RowKey | str) -> list[CellType]: + """Get the values from the row identified by the given row key. + + Args: + row_key: The key of the row. + + Returns: + A list of the values contained within the row. + + Raises: + RowDoesNotExist: When there is no row corresponding to the key. + """ + if row_key not in self._row_locations: + raise RowDoesNotExist(f"Row key {row_key!r} is not valid.") + cell_mapping: dict[ColumnKey, CellType] = self._data.get(row_key, {}) + ordered_row: list[CellType] = [ + cell_mapping[column.key] for column in self.ordered_columns + ] + return ordered_row + + def get_row_at(self, row_index: int) -> list[CellType]: + """Get the values from the cells in a row at a given index. This will + return the values from a row based on the rows _current position_ in + the table. + + Args: + row_index: The index of the row. + + Returns: + A list of the values contained in the row. + + Raises: + RowDoesNotExist: If there is no row with the given index. + """ + if not self.is_valid_row_index(row_index): + raise RowDoesNotExist(f"Row index {row_index!r} is not valid.") + row_key = self._row_locations.get_key(row_index) + return self.get_row(row_key) + + def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]: + """Get the values from the column identified by the given column key. + + Args: + column_key: The key of the column. + + Returns: + A generator which yields the cells in the column. + + Raises: + ColumnDoesNotExist: If there is no column corresponding to the key. + """ + if column_key not in self._column_locations: + raise ColumnDoesNotExist(f"Column key {column_key!r} is not valid.") + + data = self._data + for row_metadata in self.ordered_rows: + row_key = row_metadata.key + yield data[row_key][column_key] + + def get_column_at(self, column_index: int) -> Iterable[CellType]: + """Get the values from the column at a given index. + + Args: + column_index: The index of the column. + + Returns: + A generator which yields the cells in the column. + + Raises: + ColumnDoesNotExist: If there is no column with the given index. + """ + if not self.is_valid_column_index(column_index): + raise ColumnDoesNotExist(f"Column index {column_index!r} is not valid.") + + column_key = self._column_locations.get_key(column_index) + yield from self.get_column(column_key) def _clear_caches(self) -> None: self._row_render_cache.clear() @@ -392,12 +756,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._line_cache.clear() self._styles_cache.clear() - def get_row_height(self, row_index: int) -> int: - if row_index == -1: - return self.header_height - return self.rows[row_index].height + def get_row_height(self, row_key: RowKey) -> int: + """Given a row key, return the height of that row in terminal cells. - async def on_styles_updated(self, message: messages.StylesUpdated) -> None: + Args: + row_key: The key of the row. + + Returns: + The height of the row, measured in terminal character cells. + """ + if row_key is self._header_row_key: + return self.header_height + return self.rows[row_key].height + + async def on_styles_updated(self) -> None: self._clear_caches() self.refresh() @@ -408,34 +780,34 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # post the appropriate [Row|Column|Cell]Highlighted event. self._scroll_cursor_into_view(animate=False) if self.cursor_type == "cell": - self._highlight_cell(self.cursor_cell) + self._highlight_coordinate(self.cursor_coordinate) elif self.cursor_type == "row": self._highlight_row(self.cursor_row) elif self.cursor_type == "column": self._highlight_column(self.cursor_column) - def watch_show_header(self, show_header: bool) -> None: + def watch_show_header(self) -> None: self._clear_caches() - def watch_fixed_rows(self, fixed_rows: int) -> None: + def watch_fixed_rows(self) -> None: self._clear_caches() - def watch_zebra_stripes(self, zebra_stripes: bool) -> None: + def watch_zebra_stripes(self) -> None: self._clear_caches() - def watch_hover_cell(self, old: Coordinate, value: Coordinate) -> None: - self.refresh_cell(*old) - self.refresh_cell(*value) + def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None: + self.refresh_coordinate(old) + self.refresh_coordinate(value) - def watch_cursor_cell( + def watch_cursor_coordinate( self, old_coordinate: Coordinate, new_coordinate: Coordinate ) -> None: if old_coordinate != new_coordinate: # Refresh the old and the new cell, and post the appropriate # message to tell users of the newly highlighted row/cell/column. if self.cursor_type == "cell": - self.refresh_cell(*old_coordinate) - self._highlight_cell(new_coordinate) + self.refresh_coordinate(old_coordinate) + self._highlight_coordinate(new_coordinate) elif self.cursor_type == "row": self.refresh_row(old_coordinate.row) self._highlight_row(new_coordinate.row) @@ -443,37 +815,67 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self.refresh_column(old_coordinate.column) self._highlight_column(new_coordinate.column) - def _highlight_cell(self, coordinate: Coordinate) -> None: + def _highlight_coordinate(self, coordinate: Coordinate) -> None: """Apply highlighting to the cell at the coordinate, and post event.""" - self.refresh_cell(*coordinate) + self.refresh_coordinate(coordinate) try: - cell_value = self.get_cell_value(coordinate) + cell_value = self.get_cell_at(coordinate) except CellDoesNotExist: # The cell may not exist e.g. when the table is cleared. # In that case, there's nothing for us to do here. return else: + cell_key = self.coordinate_to_cell_key(coordinate) self.post_message_no_wait( - DataTable.CellHighlighted(self, cell_value, coordinate) + DataTable.CellHighlighted( + self, cell_value, coordinate=coordinate, cell_key=cell_key + ) ) + def coordinate_to_cell_key(self, coordinate: Coordinate) -> CellKey: + """Return the key for the cell currently occupying this coordinate. + + Args: + coordinate: The coordinate to exam the current cell key of. + + Returns: + The key of the cell currently occupying this coordinate. + + Raises: + CellDoesNotExist: If the coordinate is not valid. + """ + if not self.is_valid_coordinate(coordinate): + raise CellDoesNotExist(f"No cell exists at {coordinate!r}.") + row_index, column_index = coordinate + row_key = self._row_locations.get_key(row_index) + column_key = self._column_locations.get_key(column_index) + return CellKey(row_key, column_key) + def _highlight_row(self, row_index: int) -> None: """Apply highlighting to the row at the given index, and post event.""" self.refresh_row(row_index) - if row_index in self.data: - self.post_message_no_wait(DataTable.RowHighlighted(self, row_index)) + is_valid_row = row_index < len(self._data) + if is_valid_row: + row_key = self._row_locations.get_key(row_index) + self.post_message_no_wait( + DataTable.RowHighlighted(self, row_index, row_key) + ) def _highlight_column(self, column_index: int) -> None: """Apply highlighting to the column at the given index, and post event.""" self.refresh_column(column_index) if column_index < len(self.columns): - self.post_message_no_wait(DataTable.ColumnHighlighted(self, column_index)) + column_key = self._column_locations.get_key(column_index) + self.post_message_no_wait( + DataTable.ColumnHighlighted(self, column_index, column_key) + ) - def validate_cursor_cell(self, value: Coordinate) -> Coordinate: - return self._clamp_cursor_cell(value) + def validate_cursor_coordinate(self, value: Coordinate) -> Coordinate: + return self._clamp_cursor_coordinate(value) - def _clamp_cursor_cell(self, cursor_cell: Coordinate) -> Coordinate: - row, column = cursor_cell + def _clamp_cursor_coordinate(self, coordinate: Coordinate) -> Coordinate: + """Clamp a coordinate such that it falls within the boundaries of the table.""" + row, column = coordinate row = clamp(row, 0, self.row_count - 1) column = clamp(column, self.fixed_columns, len(self.columns) - 1) return Coordinate(row, column) @@ -485,53 +887,89 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): # Refresh cells that were previously impacted by the cursor # but may no longer be. - row_index, column_index = self.cursor_cell if old == "cell": - self.refresh_cell(row_index, column_index) + self.refresh_coordinate(self.cursor_coordinate) elif old == "row": + row_index, _ = self.cursor_coordinate self.refresh_row(row_index) elif old == "column": + _, column_index = self.cursor_coordinate self.refresh_column(column_index) self._scroll_cursor_into_view() def _highlight_cursor(self) -> None: - row_index, column_index = self.cursor_cell + """Applies the appropriate highlighting and raises the appropriate + [Row|Column|Cell]Highlighted event for the given cursor coordinate + and cursor type.""" + row_index, column_index = self.cursor_coordinate cursor_type = self.cursor_type # Apply the highlighting to the newly relevant cells if cursor_type == "cell": - self._highlight_cell(self.cursor_cell) + self._highlight_coordinate(self.cursor_coordinate) elif cursor_type == "row": self._highlight_row(row_index) elif cursor_type == "column": self._highlight_column(column_index) - def _update_dimensions(self, new_rows: Iterable[int]) -> None: + def _update_column_widths(self, updated_cells: set[CellKey]) -> None: + """Update the widths of the columns based on the newly updated cell widths.""" + for row_key, column_key in updated_cells: + column = self.columns.get(column_key) + if column is None: + continue + console = self.app.console + label_width = measure(console, column.label, 1) + content_width = column.content_width + cell_value = self._data[row_key][column_key] + + new_content_width = measure(console, default_cell_formatter(cell_value), 1) + + if new_content_width < content_width: + cells_in_column = self.get_column(column_key) + cell_widths = [ + measure(console, default_cell_formatter(cell), 1) + for cell in cells_in_column + ] + column.content_width = max([*cell_widths, label_width]) + else: + column.content_width = max(new_content_width, label_width) + + def _update_dimensions(self, new_rows: Iterable[RowKey]) -> None: """Called to recalculate the virtual (scrollable) size.""" - for row_index in new_rows: + for row_key in new_rows: + row_index = self._row_locations.get(row_key) + if row_index is None: + continue for column, renderable in zip( - self.columns, self._get_row_renderables(row_index) + self.ordered_columns, self._get_row_renderables(row_index) ): content_width = measure(self.app.console, renderable, 1) column.content_width = max(column.content_width, content_width) self._clear_caches() - total_width = sum(column.render_width for column in self.columns) + total_width = sum(column.render_width for column in self.columns.values()) header_height = self.header_height if self.show_header else 0 self.virtual_size = Size( total_width, - len(self._y_offsets) + header_height, + self._total_row_height + header_height, ) - def _get_cell_region(self, row_index: int, column_index: int) -> Region: - """Get the region of the cell at the given coordinate (row_index, column_index)""" - if row_index not in self.rows: + def _get_cell_region(self, coordinate: Coordinate) -> Region: + """Get the region of the cell at the given spatial coordinate.""" + if not self.is_valid_coordinate(coordinate): return Region(0, 0, 0, 0) - row = self.rows[row_index] - x = sum(column.render_width for column in self.columns[:column_index]) - width = self.columns[column_index].render_width + + row_index, column_index = coordinate + row_key = self._row_locations.get_key(row_index) + row = self.rows[row_key] + + # The x-coordinate of a cell is the sum of widths of cells to the left. + x = sum(column.render_width for column in self.ordered_columns[:column_index]) + column_key = self._column_locations.get_key(column_index) + width = self.columns[column_key].render_width height = row.height - y = row.y + y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index]) if self.show_header: y += self.header_height cell_region = Region(x, y, width, height) @@ -539,12 +977,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_row_region(self, row_index: int) -> Region: """Get the region of the row at the given index.""" - rows = self.rows - if row_index < 0 or row_index >= len(rows): + if not self.is_valid_row_index(row_index): return Region(0, 0, 0, 0) - row = rows[row_index] - row_width = sum(column.render_width for column in self.columns) - y = row.y + + rows = self.rows + row_key = self._row_locations.get_key(row_index) + row = rows[row_key] + row_width = sum(column.render_width for column in self.columns.values()) + y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index]) if self.show_header: y += self.header_height row_region = Region(0, y, row_width, row.height) @@ -552,14 +992,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _get_column_region(self, column_index: int) -> Region: """Get the region of the column at the given index.""" - columns = self.columns - if column_index < 0 or column_index >= len(columns): + if not self.is_valid_column_index(column_index): return Region(0, 0, 0, 0) - x = sum(column.render_width for column in self.columns[:column_index]) - width = columns[column_index].render_width + columns = self.columns + x = sum(column.render_width for column in self.ordered_columns[:column_index]) + column_key = self._column_locations.get_key(column_index) + width = columns[column_key].render_width header_height = self.header_height if self.show_header else 0 - height = len(self._y_offsets) + header_height + height = self._total_row_height + header_height full_column_region = Region(x, 0, width, height) return full_column_region @@ -569,77 +1010,97 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Args: columns: Also clear the columns. Defaults to False. """ - self.row_count = 0 self._clear_caches() self._y_offsets.clear() - self.data.clear() + self._data.clear() self.rows.clear() if columns: self.columns.clear() - self._line_no = 0 self._require_update_dimensions = True - self.cursor_cell = Coordinate(0, 0) - self.hover_cell = Coordinate(0, 0) + self.cursor_coordinate = Coordinate(0, 0) + self.hover_coordinate = Coordinate(0, 0) self.refresh() - def add_columns(self, *labels: TextType) -> None: - """Add a number of columns. - - Args: - *labels: Column headers. - """ - for label in labels: - self.add_column(label, width=None) - - def add_column(self, label: TextType, *, width: int | None = None) -> None: + def add_column( + self, label: TextType, *, width: int | None = None, key: str | None = None + ) -> ColumnKey: """Add a column to the table. Args: label: A str or Text object containing the label (shown top of column). - width: Width of the column in cells or None to fit content. Defaults to None. - """ - text_label = Text.from_markup(label) if isinstance(label, str) else label + width: Width of the column in cells or None to fit content. + key: A key which uniquely identifies this column. + If None, it will be generated for you. - content_width = measure(self.app.console, text_label, 1) + Returns: + Uniquely identifies this column. Can be used to retrieve this column + regardless of its current location in the DataTable (it could have moved + after being added due to sorting/insertion/deletion of other columns). + """ + column_key = ColumnKey(key) + if column_key in self._column_locations: + raise DuplicateKey(f"The column key {key!r} already exists.") + column_index = len(self.columns) + label = Text.from_markup(label) if isinstance(label, str) else label + content_width = measure(self.app.console, label, 1) if width is None: column = Column( - text_label, + column_key, + label, content_width, - index=len(self.columns), content_width=content_width, auto_width=True, ) else: column = Column( - text_label, width, content_width=content_width, index=len(self.columns) + column_key, + label, + width, + content_width=content_width, ) - self.columns.append(column) + self.columns[column_key] = column + self._column_locations[column_key] = column_index self._require_update_dimensions = True self.check_idle() - def add_row(self, *cells: CellType, height: int = 1) -> None: - """Add a row. + return column_key + + def add_row( + self, *cells: CellType, height: int = 1, key: str | None = None + ) -> RowKey: + """Add a row at the bottom of the DataTable. Args: *cells: Positional arguments should contain cell data. - height: The height of a row (in lines). Defaults to 1. + height: The height of a row (in lines). + key: A key which uniquely identifies this row. If None, it will be generated + for you and returned. + + Returns: + Uniquely identifies this row. Can be used to retrieve this row regardless + of its current location in the DataTable (it could have moved after + being added due to sorting or insertion/deletion of other rows). """ + row_key = RowKey(key) + if row_key in self._row_locations: + raise DuplicateKey(f"The row key {row_key!r} already exists.") + + # TODO: If there are no columns: do we generate them here? + # If we don't do this, users will be required to call add_column(s) + # Before they call add_row. + row_index = self.row_count - - self.data[row_index] = list(cells) - self.rows[row_index] = Row(row_index, height, self._line_no) - - for line_no in range(height): - self._y_offsets.append((row_index, line_no)) - - self.row_count += 1 - self._line_no += height - - self._new_rows.add(row_index) + # Map the key of this row to its current index + self._row_locations[row_key] = row_index + self._data[row_key] = { + column.key: cell + for column, cell in zip_longest(self.ordered_columns, cells) + } + self.rows[row_key] = Row(row_key, height) + self._new_rows.add(row_key) self._require_update_dimensions = True - self.cursor_cell = self.cursor_cell - self.check_idle() + self.cursor_coordinate = self.cursor_coordinate # If a position has opened for the cursor to appear, where it previously # could not (e.g. when there's no data in the table), then a highlighted @@ -650,33 +1111,74 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): if cell_now_available and visible_cursor: self._highlight_cursor() - def add_rows(self, rows: Iterable[Iterable[CellType]]) -> None: - """Add a number of rows. + self.check_idle() + return row_key + + def add_columns(self, *labels: TextType) -> list[ColumnKey]: + """Add a number of columns. + + Args: + *labels: Column headers. + + Returns: + A list of the keys for the columns that were added. See + the `add_column` method docstring for more information on how + these keys are used. + """ + column_keys = [] + for label in labels: + column_key = self.add_column(label, width=None) + column_keys.append(column_key) + return column_keys + + def add_rows(self, rows: Iterable[Iterable[CellType]]) -> list[RowKey]: + """Add a number of rows at the bottom of the DataTable. Args: rows: Iterable of rows. A row is an iterable of cells. + + Returns: + A list of the keys for the rows that were added. See + the `add_row` method docstring for more information on how + these keys are used. """ + row_keys = [] for row in rows: - self.add_row(*row) + row_key = self.add_row(*row) + row_keys.append(row_key) + return row_keys def on_idle(self) -> None: + """Runs when the message pump is empty. + + We use this for some expensive calculations like re-computing dimensions of the + whole DataTable and re-computing column widths after some cells + have been updated. This is more efficient in the case of high + frequency updates, ensuring we only do expensive computations once.""" if self._require_update_dimensions: + # Add the new rows *before* updating the column widths, since + # cells in a new row may influence the final width of a column self._require_update_dimensions = False new_rows = self._new_rows.copy() self._new_rows.clear() self._update_dimensions(new_rows) - self.refresh() - def refresh_cell(self, row_index: int, column_index: int) -> None: - """Refresh a cell. + if self._updated_cells: + # Cell contents have already been updated at this point. + # Now we only need to worry about measuring column widths. + updated_columns = self._updated_cells.copy() + self._updated_cells.clear() + self._update_column_widths(updated_columns) + + def refresh_coordinate(self, coordinate: Coordinate) -> None: + """Refresh the cell at a coordinate. Args: - row_index: Row index. - column_index: Column index. + coordinate: The coordinate to refresh. """ - if row_index < 0 or column_index < 0: + if not self.is_valid_coordinate(coordinate): return - region = self._get_cell_region(row_index, column_index) + region = self._get_cell_region(coordinate) self._refresh_region(region) def refresh_row(self, row_index: int) -> None: @@ -685,7 +1187,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Args: row_index: The index of the row to refresh. """ - if row_index < 0 or row_index >= len(self.rows): + if not self.is_valid_row_index(row_index): return region = self._get_row_region(row_index) @@ -697,7 +1199,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Args: column_index: The index of the column to refresh. """ - if column_index < 0 or column_index >= len(self.columns): + if not self.is_valid_column_index(column_index): return region = self._get_column_region(column_index) @@ -712,8 +1214,72 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): region = region.translate(-self.scroll_offset) self.refresh(region) + def is_valid_row_index(self, row_index: int) -> bool: + """Return a boolean indicating whether the row_index is within table bounds. + + Args: + row_index: The row index to check. + + Returns: + True if the row index is within the bounds of the table. + """ + return 0 <= row_index < len(self.rows) + + def is_valid_column_index(self, column_index: int) -> bool: + """Return a boolean indicating whether the column_index is within table bounds. + + Args: + column_index: The column index to check. + + Returns: + True if the column index is within the bounds of the table. + """ + return 0 <= column_index < len(self.columns) + + def is_valid_coordinate(self, coordinate: Coordinate) -> bool: + """Return a boolean indicating whether the given coordinate is valid. + + Args: + coordinate: The coordinate to validate. + + Returns: + True if the coordinate is within the bounds of the table. + """ + row_index, column_index = coordinate + return self.is_valid_row_index(row_index) and self.is_valid_column_index( + column_index + ) + + @property + def ordered_columns(self) -> list[Column]: + """The list of Columns in the DataTable, ordered as they appear on screen.""" + column_indices = range(len(self.columns)) + column_keys = [ + self._column_locations.get_key(index) for index in column_indices + ] + ordered_columns = [self.columns[key] for key in column_keys] + return ordered_columns + + @property + def ordered_rows(self) -> list[Row]: + """The list of Rows in the DataTable, ordered as they appear on screen.""" + num_rows = self.row_count + update_count = self._update_count + cache_key = (num_rows, update_count) + if cache_key in self._ordered_row_cache: + ordered_rows = self._ordered_row_cache[cache_key] + else: + row_indices = range(num_rows) + ordered_rows = [] + for row_index in row_indices: + row_key = self._row_locations.get_key(row_index) + row = self.rows[row_key] + ordered_rows.append(row) + self._ordered_row_cache[cache_key] = ordered_rows + return ordered_rows + def _get_row_renderables(self, row_index: int) -> list[RenderableType]: - """Get renderables for the given row. + """Get renderables for the row currently at the given row index. Args: row_index: Index of the row. @@ -721,20 +1287,17 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): Returns: List of renderables """ - + ordered_columns = self.ordered_columns if row_index == -1: - row = [column.label for column in self.columns] + row: list[RenderableType] = [column.label for column in ordered_columns] return row - data = self.data.get(row_index) + ordered_row = self.get_row_at(row_index) empty = Text() - if data is None: - return [empty for _ in self.columns] - else: - return [ - Text() if datum is None else default_cell_formatter(datum) or empty - for datum, _ in zip_longest(data, range(len(self.columns))) - ] + return [ + Text() if datum is None else default_cell_formatter(datum) or empty + for datum, _ in zip_longest(ordered_row, range(len(self.columns))) + ] def _render_cell( self, @@ -767,9 +1330,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): if hover and show_cursor and self._show_hover_cursor: style += self.get_component_styles("datatable--highlight").rich_style if is_fixed_style: - # Apply subtle variation in style for the fixed (blue background by default) - # rows and columns affected by the cursor, to ensure we can still differentiate - # between the labels and the data. + # Apply subtle variation in style for the fixed (blue background by + # default) rows and columns affected by the cursor, to ensure we can + # still differentiate between the labels and the data. style += self.get_component_styles( "datatable--highlight-fixed" ).rich_style @@ -779,34 +1342,39 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): if is_fixed_style: style += self.get_component_styles("datatable--cursor-fixed").rich_style - cell_key = (row_index, column_index, style, cursor, hover) - if cell_key not in self._cell_render_cache: + if is_header_row: + row_key = self._header_row_key + else: + row_key = self._row_locations.get_key(row_index) + + column_key = self._column_locations.get_key(column_index) + cell_cache_key = (row_key, column_key, style, cursor, hover, self._update_count) + if cell_cache_key not in self._cell_render_cache: style += Style.from_meta({"row": row_index, "column": column_index}) - height = ( - self.header_height if is_header_row else self.rows[row_index].height - ) + height = self.header_height if is_header_row else self.rows[row_key].height cell = self._get_row_renderables(row_index)[column_index] lines = self.app.console.render_lines( Padding(cell, (0, 1)), self.app.console.options.update_dimensions(width, height), style=style, ) - self._cell_render_cache[cell_key] = lines - return self._cell_render_cache[cell_key] + self._cell_render_cache[cell_cache_key] = lines + return self._cell_render_cache[cell_cache_key] - def _render_row( + def _render_line_in_row( self, - row_index: int, + row_key: RowKey, line_no: int, base_style: Style, cursor_location: Coordinate, hover_location: Coordinate, ) -> tuple[SegmentLines, SegmentLines]: - """Render a row in to lines for each cell. + """Render a single line from a row in the DataTable. Args: - row_index: Index of the row. - line_no: Line number (on screen, 0 is top) + row_key: The identifying key for this row. + line_no: Line number (y-coordinate) within row. 0 is the first strip of + cells in the row, line_no=1 is the next line in the row, and so on... base_style: Base style of row. cursor_location: The location of the cursor in the DataTable. hover_location: The location of the hover cursor in the DataTable. @@ -816,8 +1384,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): """ cursor_type = self.cursor_type show_cursor = self.show_cursor + cache_key = ( - row_index, + row_key, line_no, base_style, cursor_location, @@ -825,43 +1394,50 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): cursor_type, show_cursor, self._show_hover_cursor, + self._update_count, ) if cache_key in self._row_render_cache: return self._row_render_cache[cache_key] - render_cell = self._render_cell - def _should_highlight( - cursor_location: Coordinate, - cell_location: Coordinate, - cursor_type: CursorType, + cursor: Coordinate, + target_cell: Coordinate, + type_of_cursor: CursorType, ) -> bool: """Determine whether we should highlight a cell given the location of the cursor, the location of the cell, and the type of cursor that is currently active.""" - if cursor_type == "cell": - return cursor_location == cell_location - elif cursor_type == "row": - cursor_row, _ = cursor_location - cell_row, _ = cell_location + if type_of_cursor == "cell": + return cursor == target_cell + elif type_of_cursor == "row": + cursor_row, _ = cursor + cell_row, _ = target_cell return cursor_row == cell_row - elif cursor_type == "column": - _, cursor_column = cursor_location - _, cell_column = cell_location + elif type_of_cursor == "column": + _, cursor_column = cursor + _, cell_column = target_cell return cursor_column == cell_column else: return False + if row_key in self._row_locations: + row_index = self._row_locations.get(row_key) + else: + row_index = -1 + + render_cell = self._render_cell if self.fixed_columns: fixed_style = self.get_component_styles("datatable--fixed").rich_style fixed_style += Style.from_meta({"fixed": True}) fixed_row = [] - for column in self.columns[: self.fixed_columns]: - cell_location = Coordinate(row_index, column.index) + for column_index, column in enumerate( + self.ordered_columns[: self.fixed_columns] + ): + cell_location = Coordinate(row_index, column_index) fixed_cell_lines = render_cell( row_index, - column.index, + column_index, fixed_style, column.render_width, cursor=_should_highlight( @@ -873,7 +1449,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): else: fixed_row = [] - if row_index == -1: + is_header_row = row_key is self._header_row_key + if is_header_row: row_style = self.get_component_styles("datatable--header").rich_style else: if self.zebra_stripes: @@ -885,11 +1462,11 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): row_style = base_style scrollable_row = [] - for column in self.columns: - cell_location = Coordinate(row_index, column.index) + for column_index, column in enumerate(self.ordered_columns): + cell_location = Coordinate(row_index, column_index) cell_lines = render_cell( row_index, - column.index, + column_index, row_style, column.render_width, cursor=_should_highlight(cursor_location, cell_location, cursor_type), @@ -901,25 +1478,29 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._row_render_cache[cache_key] = row_pair return row_pair - def _get_offsets(self, y: int) -> tuple[int, int]: - """Get row number and line offset for a given line. + def _get_offsets(self, y: int) -> tuple[RowKey, int]: + """Get row key and line offset for a given line. Args: - y: Y coordinate relative to screen top. + y: Y coordinate relative to DataTable top. Returns: - Line number and line offset within cell. + Row key and line (y) offset within cell. """ + header_height = self.header_height + y_offsets = self._y_offsets if self.show_header: - if y < self.header_height: - return (-1, y) - y -= self.header_height - if y > len(self._y_offsets): + if y < header_height: + return self._header_row_key, y + y -= header_height + if y > len(y_offsets): raise LookupError("Y coord {y!r} is greater than total height") - return self._y_offsets[y] + + return y_offsets[y] def _render_line(self, y: int, x1: int, x2: int, base_style: Style) -> Strip: - """Render a line in to a list of segments. + """Render a (possibly cropped) line in to a Strip (a list of segments + representing a horizontal line). Args: y: Y coordinate of line @@ -928,13 +1509,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): base_style: Style to apply to line. Returns: - List of segments for rendering. + The Strip which represents this cropped line. """ width = self.size.width try: - row_index, line_no = self._get_offsets(y) + row_key, y_offset_in_row = self._get_offsets(y) except LookupError: return Strip.blank(width, base_style) @@ -943,24 +1524,25 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): x1, x2, width, - self.cursor_cell, - self.hover_cell, + self.cursor_coordinate, + self.hover_coordinate, base_style, self.cursor_type, self._show_hover_cursor, + self._update_count, ) if cache_key in self._line_cache: return self._line_cache[cache_key] - fixed, scrollable = self._render_row( - row_index, - line_no, + fixed, scrollable = self._render_line_in_row( + row_key, + y_offset_in_row, base_style, - cursor_location=self.cursor_cell, - hover_location=self.hover_cell, + cursor_location=self.cursor_coordinate, + hover_location=self.hover_coordinate, ) fixed_width = sum( - column.render_width for column in self.columns[: self.fixed_columns] + column.render_width for column in self.ordered_columns[: self.fixed_columns] ) fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else [] @@ -975,37 +1557,80 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def render_line(self, y: int) -> Strip: width, height = self.size scroll_x, scroll_y = self.scroll_offset - fixed_top_row_count = sum( - self.get_row_height(row_index) for row_index in range(self.fixed_rows) + + fixed_row_keys: list[RowKey] = [ + self._row_locations.get_key(row_index) + for row_index in range(self.fixed_rows) + ] + + fixed_rows_height = sum( + self.get_row_height(row_key) for row_key in fixed_row_keys ) if self.show_header: - fixed_top_row_count += self.get_row_height(-1) + fixed_rows_height += self.get_row_height(self._header_row_key) - if y >= fixed_top_row_count: + if y >= fixed_rows_height: y += scroll_y return self._render_line(y, scroll_x, scroll_x + width, self.rich_style) def on_mouse_move(self, event: events.MouseMove): + """If the hover cursor is visible, display it by extracting the row + and column metadata from the segments present in the cells.""" self._set_hover_cursor(True) meta = event.style.meta if meta and self.show_cursor and self.cursor_type != "none": try: - self.hover_cell = Coordinate(meta["row"], meta["column"]) + self.hover_coordinate = Coordinate(meta["row"], meta["column"]) except KeyError: pass def _get_fixed_offset(self) -> Spacing: + """Calculate the "fixed offset", that is the space to the top and left + that is occupied by fixed rows and columns respectively. Fixed rows and columns + are rows and columns that do not participate in scrolling.""" top = self.header_height if self.show_header else 0 top += sum( - self.rows[row_index].height + self.rows[self._row_locations.get_key(row_index)].height for row_index in range(self.fixed_rows) if row_index in self.rows ) - left = sum(column.render_width for column in self.columns[: self.fixed_columns]) + left = sum( + column.render_width for column in self.ordered_columns[: self.fixed_columns] + ) return Spacing(top, 0, 0, left) + def sort( + self, + *columns: ColumnKey | str, + reverse: bool = False, + ) -> None: + """Sort the rows in the DataTable by one or more column keys. + + Args: + columns: One or more columns to sort by the values in. + reverse: If True, the sort order will be reversed. + """ + + def sort_by_column_keys( + row: tuple[RowKey, dict[ColumnKey | str, CellType]] + ) -> Any: + _, row_data = row + result = itemgetter(*columns)(row_data) + return result + + ordered_rows = sorted( + self._data.items(), key=sort_by_column_keys, reverse=reverse + ) + self._row_locations = TwoWayDict( + {key: new_index for new_index, (key, _) in enumerate(ordered_rows)} + ) + self._update_count += 1 + self.refresh() + def _scroll_cursor_into_view(self, animate: bool = False) -> None: + """When the cursor is at a boundary of the DataTable and moves out + of view, this method handles scrolling to ensure it remains visible.""" fixed_offset = self._get_fixed_offset() top, _, _, left = fixed_offset @@ -1016,7 +1641,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): x, y, width, height = self._get_column_region(self.cursor_column) region = Region(x, int(self.scroll_y) + top, width, height - top) else: - region = self._get_cell_region(self.cursor_row, self.cursor_column) + region = self._get_cell_region(self.cursor_coordinate) self.scroll_to_region(region, animate=animate, spacing=fixed_offset) @@ -1036,24 +1661,36 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): elif cursor_type == "row": self.refresh_row(self.hover_row) elif cursor_type == "cell": - self.refresh_cell(*self.hover_cell) + self.refresh_coordinate(self.hover_coordinate) def on_click(self, event: events.Click) -> None: self._set_hover_cursor(True) - if self.show_cursor and self.cursor_type != "none": + meta = self.get_style_at(event.x, event.y).meta + if not meta: + return + + row_index = meta["row"] + column_index = meta["column"] + is_header_click = self.show_header and row_index == -1 + if is_header_click: + # Header clicks work even if cursor is off, and doesn't move the cursor. + column = self.ordered_columns[column_index] + message = DataTable.HeaderSelected( + self, column.key, column_index, label=column.label + ) + self.post_message_no_wait(message) + elif self.show_cursor and self.cursor_type != "none": # Only post selection events if there is a visible row/col/cell cursor. + self.cursor_coordinate = Coordinate(row_index, column_index) self._post_selected_message() - meta = self.get_style_at(event.x, event.y).meta - if meta: - self.cursor_cell = Coordinate(meta["row"], meta["column"]) - self._scroll_cursor_into_view(animate=True) - event.stop() + self._scroll_cursor_into_view(animate=True) + event.stop() def action_cursor_up(self) -> None: self._set_hover_cursor(False) cursor_type = self.cursor_type if self.show_cursor and (cursor_type == "cell" or cursor_type == "row"): - self.cursor_cell = self.cursor_cell.up() + self.cursor_coordinate = self.cursor_coordinate.up() self._scroll_cursor_into_view() else: # If the cursor doesn't move up (e.g. column cursor can't go up), @@ -1064,7 +1701,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._set_hover_cursor(False) cursor_type = self.cursor_type if self.show_cursor and (cursor_type == "cell" or cursor_type == "row"): - self.cursor_cell = self.cursor_cell.down() + self.cursor_coordinate = self.cursor_coordinate.down() self._scroll_cursor_into_view() else: super().action_scroll_down() @@ -1073,7 +1710,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._set_hover_cursor(False) cursor_type = self.cursor_type if self.show_cursor and (cursor_type == "cell" or cursor_type == "column"): - self.cursor_cell = self.cursor_cell.right() + self.cursor_coordinate = self.cursor_coordinate.right() self._scroll_cursor_into_view(animate=True) else: super().action_scroll_right() @@ -1082,7 +1719,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): self._set_hover_cursor(False) cursor_type = self.cursor_type if self.show_cursor and (cursor_type == "cell" or cursor_type == "column"): - self.cursor_cell = self.cursor_cell.left() + self.cursor_coordinate = self.cursor_coordinate.left() self._scroll_cursor_into_view(animate=True) else: super().action_scroll_left() @@ -1094,19 +1731,25 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True): def _post_selected_message(self): """Post the appropriate message for a selection based on the `cursor_type`.""" - cursor_cell = self.cursor_cell + cursor_coordinate = self.cursor_coordinate cursor_type = self.cursor_type + cell_key = self.coordinate_to_cell_key(cursor_coordinate) if cursor_type == "cell": self.post_message_no_wait( DataTable.CellSelected( self, - self.get_cell_value(cursor_cell), - cursor_cell, + self.get_cell_at(cursor_coordinate), + coordinate=cursor_coordinate, + cell_key=cell_key, ) ) elif cursor_type == "row": - row, _ = cursor_cell - self.post_message_no_wait(DataTable.RowSelected(self, row)) + row_index, _ = cursor_coordinate + row_key, _ = cell_key + self.post_message_no_wait(DataTable.RowSelected(self, row_index, row_key)) elif cursor_type == "column": - _, column = cursor_cell - self.post_message_no_wait(DataTable.ColumnSelected(self, column)) + _, column_index = cursor_coordinate + _, column_key = cell_key + self.post_message_no_wait( + DataTable.ColumnSelected(self, column_index, column_key) + ) diff --git a/src/textual/widgets/data_table.py b/src/textual/widgets/data_table.py index d0316f387..0bb18f87f 100644 --- a/src/textual/widgets/data_table.py +++ b/src/textual/widgets/data_table.py @@ -1,5 +1,29 @@ """Make non-widget DataTable support classes available.""" -from ._data_table import Column, Row +from ._data_table import ( + CellDoesNotExist, + CellKey, + CellType, + Column, + ColumnDoesNotExist, + ColumnKey, + CursorType, + DuplicateKey, + Row, + RowDoesNotExist, + RowKey, +) -__all__ = ["Column", "Row"] +__all__ = [ + "CellDoesNotExist", + "CellKey", + "CellType", + "Column", + "ColumnDoesNotExist", + "ColumnKey", + "CursorType", + "DuplicateKey", + "Row", + "RowDoesNotExist", + "RowKey", +] diff --git a/tests/snapshot_tests/__snapshots__/test_snapshots.ambr b/tests/snapshot_tests/__snapshots__/test_snapshots.ambr index 3d9516f10..6fd96bfe6 100644 --- a/tests/snapshot_tests/__snapshots__/test_snapshots.ambr +++ b/tests/snapshot_tests/__snapshots__/test_snapshots.ambr @@ -10173,133 +10173,133 @@ font-weight: 700; } - .terminal-121683423-matrix { + .terminal-1288566407-matrix { font-family: Fira Code, monospace; font-size: 20px; line-height: 24.4px; font-variant-east-asian: full-width; } - .terminal-121683423-title { + .terminal-1288566407-title { font-size: 18px; font-weight: bold; font-family: arial; } - .terminal-121683423-r1 { fill: #dde6ed;font-weight: bold } - .terminal-121683423-r2 { fill: #e1e1e1 } - .terminal-121683423-r3 { fill: #c5c8c6 } - .terminal-121683423-r4 { fill: #211505 } + .terminal-1288566407-r1 { fill: #dde6ed;font-weight: bold } + .terminal-1288566407-r2 { fill: #e1e1e1 } + .terminal-1288566407-r3 { fill: #c5c8c6 } + .terminal-1288566407-r4 { fill: #211505 } - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - TableApp + TableApp - - - -  lane  swimmer               country        time   -  4     Joseph Schooling      Singapore      50.39  -  2     Michael Phelps        United States  51.14  -  5     Chad le Clos          South Africa   51.14  -  6     László Cseh           Hungary        51.14  -  3     Li Zhuhao             China          51.26  -  8     Mehdy Metella         France         51.58  -  7     Tom Shields           United States  51.73  -  1     Aleksandr Sadovnikov  Russia         51.84  - - - - - - - - - - - - - - + + + +  lane  swimmer               country        time   +  4     Joseph Schooling      Singapore      50.39  +  2     Michael Phelps        United States  51.14  +  5     Chad le Clos          South Africa   51.14  +  6     László Cseh           Hungary        51.14  +  3     Li Zhuhao             China          51.26  +  8     Mehdy Metella         France         51.58  +  7     Tom Shields           United States  51.73  +  1     Aleksandr Sadovnikov  Russia         51.84  +  10    Darren Burns          Scotland       51.84  + + + + + + + + + + + + + @@ -10464,6 +10464,163 @@ ''' # --- +# name: test_datatable_sort_multikey + ''' + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + TableApp + + + + + + + + + +  lane  swimmer               country        time   +  4     Joseph Schooling      Singapore      50.39  +  2     Michael Phelps        United States  51.14  +  5     Chad le Clos          South Africa   51.14  +  6     László Cseh           Hungary        51.14  +  3     Li Zhuhao             China          51.26  +  8     Mehdy Metella         France         51.58  +  7     Tom Shields           United States  51.73  +  1     Aleksandr Sadovnikov  Russia         51.84  +  10    Darren Burns          Scotland       51.84  + + + + + + + + + + + + + + + + + + + ''' +# --- # name: test_demo ''' diff --git a/tests/snapshot_tests/snapshot_apps/data_table_sort.py b/tests/snapshot_tests/snapshot_apps/data_table_sort.py new file mode 100644 index 000000000..b4866e0ca --- /dev/null +++ b/tests/snapshot_tests/snapshot_apps/data_table_sort.py @@ -0,0 +1,44 @@ +from textual.app import App, ComposeResult +from textual.binding import Binding +from textual.widgets import DataTable + +# Shuffled around a bit to exercise sorting. +ROWS = [ + ("lane", "swimmer", "country", "time"), + (5, "Chad le Clos", "South Africa", 51.14), + (4, "Joseph Schooling", "Singapore", 50.39), + (2, "Michael Phelps", "United States", 51.14), + (6, "László Cseh", "Hungary", 51.14), + (3, "Li Zhuhao", "China", 51.26), + (8, "Mehdy Metella", "France", 51.58), + (7, "Tom Shields", "United States", 51.73), + (10, "Darren Burns", "Scotland", 51.84), + (1, "Aleksandr Sadovnikov", "Russia", 51.84), +] + + +class TableApp(App): + BINDINGS = [ + Binding("s", "sort", "Sort"), + ] + + def compose(self) -> ComposeResult: + yield DataTable() + + def on_mount(self) -> None: + table = self.query_one(DataTable) + table.focus() + rows = iter(ROWS) + column_labels = next(rows) + for column in column_labels: + table.add_column(column, key=column) + table.add_rows(rows) + + def action_sort(self): + table = self.query_one(DataTable) + table.sort("time", "lane") + + +app = TableApp() +if __name__ == "__main__": + app.run() diff --git a/tests/snapshot_tests/test_snapshots.py b/tests/snapshot_tests/test_snapshots.py index 1eed7ab77..b7a715b77 100644 --- a/tests/snapshot_tests/test_snapshots.py +++ b/tests/snapshot_tests/test_snapshots.py @@ -103,6 +103,11 @@ def test_datatable_column_cursor_render(snap_compare): assert snap_compare(SNAPSHOT_APPS_DIR / "data_table_column_cursor.py", press=press) +def test_datatable_sort_multikey(snap_compare): + press = ["down", "right", "s"] # Also checks that sort doesn't move cursor. + assert snap_compare(SNAPSHOT_APPS_DIR / "data_table_sort.py", press=press) + + def test_footer_render(snap_compare): assert snap_compare(WIDGET_EXAMPLES_DIR / "footer.py") diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 6ff0192b9..2a000332d 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -1,11 +1,32 @@ +from __future__ import annotations + +import pytest +from rich.style import Style +from rich.text import Text + +from textual._wait import wait_for_idle +from textual.actions import SkipAction from textual.app import App from textual.coordinate import Coordinate +from textual.events import Click, MouseMove from textual.message import Message +from textual.message_pump import MessagePump from textual.widgets import DataTable +from textual.widgets.data_table import ( + CellDoesNotExist, + CellKey, + ColumnDoesNotExist, + ColumnKey, + DuplicateKey, + Row, + RowDoesNotExist, + RowKey, +) + +ROWS = [["0/0", "0/1"], ["1/0", "1/1"], ["2/0", "2/1"]] class DataTableApp(App): - messages = [] messages_to_record = { "CellHighlighted", "CellSelected", @@ -13,8 +34,13 @@ class DataTableApp(App): "RowSelected", "ColumnHighlighted", "ColumnSelected", + "HeaderSelected", } + def __init__(self): + super().__init__() + self.messages = [] + def compose(self): table = DataTable() table.focus() @@ -23,7 +49,11 @@ class DataTableApp(App): def record_data_table_event(self, message: Message) -> None: name = message.__class__.__name__ if name in self.messages_to_record: - self.messages.append(name) + self.messages.append(message) + + @property + def message_names(self) -> list[str]: + return [message.__class__.__name__ for message in self.messages] async def _on_message(self, message: Message) -> None: await super()._on_message(message) @@ -32,68 +62,66 @@ class DataTableApp(App): async def test_datatable_message_emission(): app = DataTableApp() - messages = app.messages expected_messages = [] async with app.run_test() as pilot: table = app.query_one(DataTable) - assert messages == expected_messages + assert app.message_names == expected_messages table.add_columns("Column0", "Column1") - table.add_rows([["0/0", "0/1"], ["1/0", "1/1"], ["2/0", "2/1"]]) + table.add_rows(ROWS) # A CellHighlighted is emitted because there were no rows (and # therefore no highlighted cells), but then a row was added, and # so the cell at (0, 0) became highlighted. expected_messages.append("CellHighlighted") - await pilot.pause(2 / 100) - assert messages == expected_messages + await wait_for_idle(0) + assert app.message_names == expected_messages # Pressing Enter when the cursor is on a cell emits a CellSelected await pilot.press("enter") + await wait_for_idle(0) expected_messages.append("CellSelected") - await pilot.pause(2 / 100) - assert messages == expected_messages + assert app.message_names == expected_messages # Moving the cursor left and up when the cursor is at origin # emits no events, since the cursor doesn't move at all. await pilot.press("left", "up") - await pilot.pause(2 / 100) - assert messages == expected_messages + assert app.message_names == expected_messages # ROW CURSOR # Switch over to the row cursor... should emit a `RowHighlighted` table.cursor_type = "row" expected_messages.append("RowHighlighted") - await pilot.pause(2 / 100) - assert messages == expected_messages + await wait_for_idle(0) + assert app.message_names == expected_messages # Select the row... await pilot.press("enter") + await wait_for_idle(0) expected_messages.append("RowSelected") - await pilot.pause(2 / 100) - assert messages == expected_messages + assert app.message_names == expected_messages # COLUMN CURSOR # Switching to the column cursor emits a `ColumnHighlighted` table.cursor_type = "column" expected_messages.append("ColumnHighlighted") - await pilot.pause(2 / 100) - assert messages == expected_messages + await wait_for_idle(0) + assert app.message_names == expected_messages # Select the column... await pilot.press("enter") expected_messages.append("ColumnSelected") - await pilot.pause(2 / 100) - assert messages == expected_messages + await wait_for_idle(0) + assert app.message_names == expected_messages # NONE CURSOR # No messages get emitted at all... table.cursor_type = "none" await pilot.press("up", "down", "left", "right", "enter") - await pilot.pause(2 / 100) + await wait_for_idle(0) # No new messages since cursor not visible - assert messages == expected_messages + assert app.message_names == expected_messages # Edge case - if show_cursor is False, and the cursor type # is changed back to a visible type, then no messages should @@ -101,49 +129,798 @@ async def test_datatable_message_emission(): table.show_cursor = False table.cursor_type = "cell" await pilot.press("up", "down", "left", "right", "enter") - await pilot.pause(2 / 100) + await wait_for_idle(0) # No new messages since show_cursor = False - assert messages == expected_messages + assert app.message_names == expected_messages # Now when show_cursor is set back to True, the appropriate # message should be emitted for highlighting the cell. table.show_cursor = True expected_messages.append("CellHighlighted") - await pilot.pause(2 / 100) - assert messages == expected_messages + await wait_for_idle(0) + assert app.message_names == expected_messages + + # Similarly for showing the cursor again when row or column + # cursor was active before the cursor was hidden. + table.show_cursor = False + table.cursor_type = "row" + table.show_cursor = True + expected_messages.append("RowHighlighted") + await wait_for_idle(0) + assert app.message_names == expected_messages + + table.show_cursor = False + table.cursor_type = "column" + table.show_cursor = True + expected_messages.append("ColumnHighlighted") + await wait_for_idle(0) + assert app.message_names == expected_messages # Likewise, if the cursor_type is "none", and we change the # show_cursor to True, then no events should be raised since # the cursor is still not visible to the user. table.cursor_type = "none" await pilot.press("up", "down", "left", "right", "enter") - await pilot.pause(2 / 100) - assert messages == expected_messages + await wait_for_idle(0) + assert app.message_names == expected_messages + + +async def test_add_rows(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_columns("A", "B") + row_keys = table.add_rows(ROWS) + # We're given a key for each row + assert len(row_keys) == len(ROWS) + assert len(row_keys) == len(table._data) + assert table.row_count == len(ROWS) + # Each key can be used to fetch a row from the DataTable + assert all(key in table._data for key in row_keys) + + +async def test_add_rows_user_defined_keys(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + key_a, key_b = table.add_columns("A", "B") + algernon_key = table.add_row(*ROWS[0], key="algernon") + table.add_row(*ROWS[1], key="charlie") + auto_key = table.add_row(*ROWS[2]) + + assert algernon_key == "algernon" + # We get a RowKey object back, but we can use our own string *or* this object + # to find the row we're looking for, they're considered equivalent for lookups. + assert isinstance(algernon_key, RowKey) + + # Ensure the data in the table is mapped as expected + first_row = {key_a: ROWS[0][0], key_b: ROWS[0][1]} + assert table._data[algernon_key] == first_row + assert table._data["algernon"] == first_row + + second_row = {key_a: ROWS[1][0], key_b: ROWS[1][1]} + assert table._data["charlie"] == second_row + + third_row = {key_a: ROWS[2][0], key_b: ROWS[2][1]} + assert table._data[auto_key] == third_row + + first_row = Row(algernon_key, height=1) + assert table.rows[algernon_key] == first_row + assert table.rows["algernon"] == first_row + + +async def test_add_row_duplicate_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("A") + table.add_row("1", key="1") + with pytest.raises(DuplicateKey): + table.add_row("2", key="1") # Duplicate row key + + +async def test_add_column_duplicate_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("A", key="A") + with pytest.raises(DuplicateKey): + table.add_column("B", key="A") # Duplicate column key + + +async def test_add_column_with_width(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column = table.add_column("ABC", width=10, key="ABC") + row = table.add_row("123") + assert table.get_cell(row, column) == "123" + assert table.columns[column].width == 10 + assert table.columns[column].render_width == 12 # 10 + (2 padding) + + +async def test_add_columns(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column_keys = table.add_columns("1", "2", "3") + assert len(column_keys) == 3 + assert len(table.columns) == 3 + + +async def test_add_columns_user_defined_keys(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + key = table.add_column("Column", key="donut") + assert key == "donut" + assert key == key async def test_clear(): app = DataTableApp() async with app.run_test(): table = app.query_one(DataTable) - assert table.cursor_cell == Coordinate(0, 0) - assert table.hover_cell == Coordinate(0, 0) + assert table.cursor_coordinate == Coordinate(0, 0) + assert table.hover_coordinate == Coordinate(0, 0) # Add some data and update cursor positions table.add_column("Column0") table.add_rows([["Row0"], ["Row1"], ["Row2"]]) - table.cursor_cell = Coordinate(1, 0) - table.hover_cell = Coordinate(2, 0) + table.cursor_coordinate = Coordinate(1, 0) + table.hover_coordinate = Coordinate(2, 0) # Ensure the cursor positions are reset to origin on clear() table.clear() - assert table.cursor_cell == Coordinate(0, 0) - assert table.hover_cell == Coordinate(0, 0) + assert table.cursor_coordinate == Coordinate(0, 0) + assert table.hover_coordinate == Coordinate(0, 0) # Ensure that the table has been cleared - assert table.data == {} + assert table._data == {} assert table.rows == {} + assert table.row_count == 0 assert len(table.columns) == 1 # Clearing the columns too table.clear(columns=True) assert len(table.columns) == 0 + + +async def test_column_labels() -> None: + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_columns("1", "2", "3") + actual_labels = [col.label.plain for col in table.columns.values()] + expected_labels = ["1", "2", "3"] + assert actual_labels == expected_labels + + +async def test_initial_column_widths() -> None: + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + foo, bar = table.add_columns("foo", "bar") + + assert table.columns[foo].width == 3 + assert table.columns[bar].width == 3 + table.add_row("Hello", "World!") + await wait_for_idle() + assert table.columns[foo].content_width == 5 + assert table.columns[bar].content_width == 6 + + table.add_row("Hello World!!!", "fo") + await wait_for_idle() + assert table.columns[foo].content_width == 14 + assert table.columns[bar].content_width == 6 + + +async def test_get_cell_returns_value_at_cell(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_row("TargetValue", key="R1") + assert table.get_cell("R1", "C1") == "TargetValue" + + +async def test_get_cell_invalid_row_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_row("TargetValue", key="R1") + with pytest.raises(CellDoesNotExist): + table.get_cell("INVALID_ROW", "C1") + + +async def test_get_cell_invalid_column_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_row("TargetValue", key="R1") + with pytest.raises(CellDoesNotExist): + table.get_cell("R1", "INVALID_COLUMN") + + +async def test_get_cell_at_returns_value_at_cell(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_columns("A", "B") + table.add_rows(ROWS) + assert table.get_cell_at(Coordinate(0, 0)) == "0/0" + + +async def test_get_cell_at_exception(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_columns("A", "B") + table.add_rows(ROWS) + with pytest.raises(CellDoesNotExist): + table.get_cell_at(Coordinate(9999, 0)) + + +async def test_get_row(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + a, b, c = table.add_columns("A", "B", "C") + first_row = table.add_row(2, 4, 1) + second_row = table.add_row(3, 2, 1) + assert table.get_row(first_row) == [2, 4, 1] + assert table.get_row(second_row) == [3, 2, 1] + + # Even if row positions change, keys should always refer to same rows. + table.sort(b) + assert table.get_row(first_row) == [2, 4, 1] + assert table.get_row(second_row) == [3, 2, 1] + + +async def test_get_row_invalid_row_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + with pytest.raises(RowDoesNotExist): + table.get_row("INVALID") + + +async def test_get_row_at(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + a, b, c = table.add_columns("A", "B", "C") + table.add_row(2, 4, 1) + table.add_row(3, 2, 1) + assert table.get_row_at(0) == [2, 4, 1] + assert table.get_row_at(1) == [3, 2, 1] + + # If we sort, then the rows present at the indices *do* change! + table.sort(b) + + # Since we sorted on column "B", the rows at indices 0 and 1 are swapped. + assert table.get_row_at(0) == [3, 2, 1] + assert table.get_row_at(1) == [2, 4, 1] + + +@pytest.mark.parametrize("index", (-1, 2)) +async def test_get_row_at_invalid_index(index): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_columns("A", "B", "C") + table.add_row(2, 4, 1) + table.add_row(3, 2, 1) + with pytest.raises(RowDoesNotExist): + table.get_row_at(index) + + +async def test_get_column(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + a, b = table.add_columns("A", "B") + table.add_rows(ROWS) + cells = table.get_column(a) + assert next(cells) == ROWS[0][0] + assert next(cells) == ROWS[1][0] + assert next(cells) == ROWS[2][0] + with pytest.raises(StopIteration): + next(cells) + + +async def test_get_column_invalid_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + with pytest.raises(ColumnDoesNotExist): + list(table.get_column("INVALID")) + + +async def test_get_column_at(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_columns("A", "B") + table.add_rows(ROWS) + + first_column = list(table.get_column_at(0)) + assert first_column == [ROWS[0][0], ROWS[1][0], ROWS[2][0]] + + second_column = list(table.get_column_at(1)) + assert second_column == [ROWS[0][1], ROWS[1][1], ROWS[2][1]] + + +@pytest.mark.parametrize("index", [-1, 5]) +async def test_get_column_at_invalid_index(index): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + with pytest.raises(ColumnDoesNotExist): + list(table.get_column_at(index)) + + +async def test_update_cell_cell_exists(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("A", key="A") + table.add_row("1", key="1") + table.update_cell("1", "A", "NEW_VALUE") + assert table.get_cell("1", "A") == "NEW_VALUE" + + +async def test_update_cell_cell_doesnt_exist(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("A", key="A") + table.add_row("1", key="1") + with pytest.raises(CellDoesNotExist): + table.update_cell("INVALID", "CELL", "Value") + + +async def test_update_cell_at_coordinate_exists(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column_0, column_1 = table.add_columns("A", "B") + row_0, *_ = table.add_rows(ROWS) + + table.update_cell_at(Coordinate(0, 1), "newvalue") + assert table.get_cell(row_0, column_1) == "newvalue" + + +async def test_update_cell_at_coordinate_doesnt_exist(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_columns("A", "B") + table.add_rows(ROWS) + with pytest.raises(CellDoesNotExist): + table.update_cell_at(Coordinate(999, 999), "newvalue") + + +@pytest.mark.parametrize( + "label,new_value,new_content_width", + [ + # Shorter than initial cell value, larger than label => width remains same + ("A", "BB", 3), + # Larger than cell value, shorter than label => width remains that of label + ("1234567", "1234", 7), + # Shorter than cell value, shorter than label => width remains same + ("12345", "123", 5), + # Larger than cell value, larger than label => width updates to new cell value + ("12345", "123456789", 9), + ], +) +async def test_update_cell_at_column_width(label, new_value, new_content_width): + # Initial cell values are length 3. Let's update cell content and ensure + # that the width of the column is correct given the new cell content widths + # and the label of the column the cell is in. + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + key, _ = table.add_columns(label, "Column2") + table.add_rows(ROWS) + first_column = table.columns.get(key) + + table.update_cell_at(Coordinate(0, 0), new_value, update_width=True) + await wait_for_idle() + assert first_column.content_width == new_content_width + assert first_column.render_width == new_content_width + 2 + + +async def test_coordinate_to_cell_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column_key, _ = table.add_columns("Column0", "Column1") + row_key = table.add_row("A", "B") + + cell_key = table.coordinate_to_cell_key(Coordinate(0, 0)) + assert cell_key == CellKey(row_key, column_key) + + +async def test_coordinate_to_cell_key_invalid_coordinate(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + with pytest.raises(CellDoesNotExist): + table.coordinate_to_cell_key(Coordinate(9999, 9999)) + + +def make_click_event(sender: MessagePump): + return Click( + sender=sender, + x=1, + y=2, + delta_x=0, + delta_y=0, + button=0, + shift=False, + meta=False, + ctrl=False, + ) + + +async def test_datatable_on_click_cell_cursor(): + """When the cell cursor is used, and we click, we emit a CellHighlighted + *and* a CellSelected message for the cell that was clicked. + Regression test for https://github.com/Textualize/textual/issues/1723""" + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + click = make_click_event(app) + column_key = table.add_column("ABC") + table.add_row("123") + row_key = table.add_row("456") + table.on_click(event=click) + await wait_for_idle(0) + # There's two CellHighlighted events since a cell is highlighted on initial load, + # then when we click, another cell is highlighted (and selected). + assert app.message_names == [ + "CellHighlighted", + "CellHighlighted", + "CellSelected", + ] + cell_highlighted_event: DataTable.CellHighlighted = app.messages[1] + assert cell_highlighted_event.sender is table + assert cell_highlighted_event.value == "456" + assert cell_highlighted_event.cell_key == CellKey(row_key, column_key) + assert cell_highlighted_event.coordinate == Coordinate(1, 0) + + cell_selected_event: DataTable.CellSelected = app.messages[2] + assert cell_selected_event.sender is table + assert cell_selected_event.value == "456" + assert cell_selected_event.cell_key == CellKey(row_key, column_key) + assert cell_selected_event.coordinate == Coordinate(1, 0) + + +async def test_on_click_row_cursor(): + """When the row cursor is used, and we click, we emit a RowHighlighted + *and* a RowSelected message for the row that was clicked.""" + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.cursor_type = "row" + click = make_click_event(app) + table.add_column("ABC") + table.add_row("123") + row_key = table.add_row("456") + table.on_click(event=click) + await wait_for_idle(0) + assert app.message_names == ["RowHighlighted", "RowHighlighted", "RowSelected"] + + row_highlighted: DataTable.RowHighlighted = app.messages[1] + assert row_highlighted.sender is table + assert row_highlighted.row_key == row_key + assert row_highlighted.cursor_row == 1 + + row_selected: DataTable.RowSelected = app.messages[2] + assert row_selected.sender is table + assert row_selected.row_key == row_key + assert row_highlighted.cursor_row == 1 + + +async def test_on_click_column_cursor(): + """When the column cursor is used, and we click, we emit a ColumnHighlighted + *and* a ColumnSelected message for the column that was clicked.""" + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.cursor_type = "column" + column_key = table.add_column("ABC") + table.add_row("123") + table.add_row("456") + click = make_click_event(app) + table.on_click(event=click) + await wait_for_idle(0) + assert app.message_names == [ + "ColumnHighlighted", + "ColumnHighlighted", + "ColumnSelected", + ] + column_highlighted: DataTable.ColumnHighlighted = app.messages[1] + assert column_highlighted.sender is table + assert column_highlighted.column_key == column_key + assert column_highlighted.cursor_column == 0 + + column_selected: DataTable.ColumnSelected = app.messages[2] + assert column_selected.sender is table + assert column_selected.column_key == column_key + assert column_highlighted.cursor_column == 0 + + +async def test_hover_coordinate(): + """Ensure that the hover_coordinate reactive is updated as expected.""" + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("ABC") + table.add_row("123") + table.add_row("456") + assert table.hover_coordinate == Coordinate(0, 0) + + mouse_move = MouseMove( + sender=app, + x=1, + y=2, + delta_x=0, + delta_y=0, + button=0, + shift=False, + meta=False, + ctrl=False, + style=Style(meta={"row": 1, "column": 2}), + ) + table.on_mouse_move(mouse_move) + await wait_for_idle(0) + assert table.hover_coordinate == Coordinate(1, 2) + + +async def test_header_selected(): + """Ensure that a HeaderSelected event gets posted when we click + on the header in the DataTable.""" + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column = table.add_column("number") + table.add_row(3) + click_event = Click( + sender=table, + x=3, + y=0, + delta_x=0, + delta_y=0, + button=0, + shift=False, + meta=False, + ctrl=False, + ) + table.on_click(click_event) + await wait_for_idle(0) + message: DataTable.HeaderSelected = app.messages[-1] + assert message.sender is table + assert message.label == Text("number") + assert message.column_index == 0 + assert message.column_key == column + + # Now hide the header and click in the exact same place - no additional message emitted. + table.show_header = False + table.on_click(click_event) + await wait_for_idle(0) + assert app.message_names.count("HeaderSelected") == 1 + + +async def test_sort_coordinate_and_key_access(): + """Ensure that, after sorting, that coordinates and cell keys + can still be used to retrieve the correct cell.""" + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column = table.add_column("number") + row_three = table.add_row(3) + row_one = table.add_row(1) + row_two = table.add_row(2) + + # Items inserted in correct initial positions (before sort) + assert table.get_cell_at(Coordinate(0, 0)) == 3 + assert table.get_cell_at(Coordinate(1, 0)) == 1 + assert table.get_cell_at(Coordinate(2, 0)) == 2 + + table.sort(column) + + # The keys still refer to the same cells... + assert table.get_cell(row_one, column) == 1 + assert table.get_cell(row_two, column) == 2 + assert table.get_cell(row_three, column) == 3 + + # ...even though the values under the coordinates have changed... + assert table.get_cell_at(Coordinate(0, 0)) == 1 + assert table.get_cell_at(Coordinate(1, 0)) == 2 + assert table.get_cell_at(Coordinate(2, 0)) == 3 + + assert table.ordered_rows[0].key == row_one + assert table.ordered_rows[1].key == row_two + assert table.ordered_rows[2].key == row_three + + +async def test_sort_reverse_coordinate_and_key_access(): + """Ensure that, after sorting, that coordinates and cell keys + can still be used to retrieve the correct cell.""" + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column = table.add_column("number") + row_three = table.add_row(3) + row_one = table.add_row(1) + row_two = table.add_row(2) + + # Items inserted in correct initial positions (before sort) + assert table.get_cell_at(Coordinate(0, 0)) == 3 + assert table.get_cell_at(Coordinate(1, 0)) == 1 + assert table.get_cell_at(Coordinate(2, 0)) == 2 + + table.sort(column, reverse=True) + + # The keys still refer to the same cells... + assert table.get_cell(row_one, column) == 1 + assert table.get_cell(row_two, column) == 2 + assert table.get_cell(row_three, column) == 3 + + # ...even though the values under the coordinates have changed... + assert table.get_cell_at(Coordinate(0, 0)) == 3 + assert table.get_cell_at(Coordinate(1, 0)) == 2 + assert table.get_cell_at(Coordinate(2, 0)) == 1 + + assert table.ordered_rows[0].key == row_three + assert table.ordered_rows[1].key == row_two + assert table.ordered_rows[2].key == row_one + + +async def test_cell_cursor_highlight_events(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + column_one_key, column_two_key = table.add_columns("A", "B") + _ = table.add_row(0, 1) + row_two_key = table.add_row(2, 3) + + # Since initial position is (0, 0), cursor doesn't move so no event posted + table.action_cursor_up() + table.action_cursor_left() + + await wait_for_idle(0) + assert table.app.message_names == [ + "CellHighlighted" + ] # Initial highlight on load + + # Move the cursor one cell down, and check the highlighted event posted + table.action_cursor_down() + await wait_for_idle(0) + assert len(table.app.messages) == 2 + latest_message: DataTable.CellHighlighted = table.app.messages[-1] + assert isinstance(latest_message, DataTable.CellHighlighted) + assert latest_message.value == 2 + assert latest_message.coordinate == Coordinate(1, 0) + assert latest_message.cell_key == CellKey(row_two_key, column_one_key) + + # Now move the cursor to the right, and check highlighted event posted + table.action_cursor_right() + await wait_for_idle(0) + assert len(table.app.messages) == 3 + latest_message = table.app.messages[-1] + assert latest_message.coordinate == Coordinate(1, 1) + assert latest_message.cell_key == CellKey(row_two_key, column_two_key) + + +async def test_row_cursor_highlight_events(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.cursor_type = "row" + table.add_columns("A", "B") + row_one_key = table.add_row(0, 1) + row_two_key = table.add_row(2, 3) + + # Since initial position is row_index=0, the following actions do nothing. + with pytest.raises(SkipAction): + table.action_cursor_up() + table.action_cursor_left() + table.action_cursor_right() + + await wait_for_idle(0) + assert table.app.message_names == ["RowHighlighted"] # Initial highlight + + # Move the row cursor from row 0 to row 1, check the highlighted event posted + table.action_cursor_down() + await wait_for_idle(0) + assert len(table.app.messages) == 2 + latest_message: DataTable.RowHighlighted = table.app.messages[-1] + assert isinstance(latest_message, DataTable.RowHighlighted) + assert latest_message.row_key == row_two_key + assert latest_message.cursor_row == 1 + + # Move the row cursor back up to row 0, check the highlighted event posted + table.action_cursor_up() + await wait_for_idle(0) + assert len(table.app.messages) == 3 + latest_message = table.app.messages[-1] + assert latest_message.row_key == row_one_key + assert latest_message.cursor_row == 0 + + +async def test_column_cursor_highlight_events(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.cursor_type = "column" + column_one_key, column_two_key = table.add_columns("A", "B") + table.add_row(0, 1) + table.add_row(2, 3) + + # Since initial position is column_index=0, the following actions do nothing. + with pytest.raises(SkipAction): + table.action_cursor_left() + table.action_cursor_up() + table.action_cursor_down() + + await wait_for_idle(0) + assert table.app.message_names == ["ColumnHighlighted"] # Initial highlight + + # Move the column cursor from column 0 to column 1, + # check the highlighted event posted + table.action_cursor_right() + await wait_for_idle(0) + assert len(table.app.messages) == 2 + latest_message: DataTable.ColumnHighlighted = table.app.messages[-1] + assert isinstance(latest_message, DataTable.ColumnHighlighted) + assert latest_message.column_key == column_two_key + assert latest_message.cursor_column == 1 + + # Move the column cursor left, back to column 0, + # check the highlighted event posted again. + table.action_cursor_left() + await wait_for_idle(0) + assert len(table.app.messages) == 3 + latest_message = table.app.messages[-1] + assert latest_message.column_key == column_one_key + assert latest_message.cursor_column == 0 + + +def test_key_equals_equivalent_string(): + text = "Hello" + key = RowKey(text) + assert key == text + assert hash(key) == hash(text) + + +def test_key_doesnt_match_non_equal_string(): + key = ColumnKey("123") + text = "laksjdlaskjd" + assert key != text + assert hash(key) != hash(text) + + +def test_key_equals_self(): + row_key = RowKey() + column_key = ColumnKey() + assert row_key == row_key + assert column_key == column_key + assert row_key != column_key + + +def test_key_string_lookup(): + # Indirectly covered by other tests, but let's explicitly document + # in tests how we intend for the keys to work for cache lookups. + dictionary = { + "foo": "bar", + RowKey("hello"): "world", + } + assert dictionary["foo"] == "bar" + assert dictionary[RowKey("foo")] == "bar" + assert dictionary["hello"] == "world" + assert dictionary[RowKey("hello")] == "world" diff --git a/tests/test_strip.py b/tests/test_strip.py index 82d7ec680..299a152d0 100644 --- a/tests/test_strip.py +++ b/tests/test_strip.py @@ -2,8 +2,8 @@ import pytest from rich.segment import Segment from rich.style import Style -from textual._filter import Monochrome from textual._segment_tools import NoCellPositionForIndex +from textual.filter import Monochrome from textual.strip import Strip diff --git a/tests/test_table.py b/tests/test_table.py deleted file mode 100644 index 1290933d4..000000000 --- a/tests/test_table.py +++ /dev/null @@ -1,66 +0,0 @@ -import asyncio - -from rich.text import Text - -from textual.app import App, ComposeResult -from textual.widgets import DataTable - - -class TableApp(App): - def compose(self) -> ComposeResult: - yield DataTable() - - -async def test_table_clear() -> None: - """Check DataTable.clear""" - - app = TableApp() - async with app.run_test() as pilot: - table = app.query_one(DataTable) - table.add_columns("foo", "bar") - assert table.row_count == 0 - table.add_row("Hello", "World!") - assert [col.label for col in table.columns] == [Text("foo"), Text("bar")] - assert table.data == {0: ["Hello", "World!"]} - assert table.row_count == 1 - table.clear() - assert [col.label for col in table.columns] == [Text("foo"), Text("bar")] - assert table.data == {} - assert table.row_count == 0 - - -async def test_table_clear_with_columns() -> None: - """Check DataTable.clear(columns=True)""" - - app = TableApp() - async with app.run_test() as pilot: - table = app.query_one(DataTable) - table.add_columns("foo", "bar") - assert table.row_count == 0 - table.add_row("Hello", "World!") - assert [col.label for col in table.columns] == [Text("foo"), Text("bar")] - assert table.data == {0: ["Hello", "World!"]} - assert table.row_count == 1 - table.clear(columns=True) - assert [col.label for col in table.columns] == [] - assert table.data == {} - assert table.row_count == 0 - - -async def test_table_add_row() -> None: - app = TableApp() - async with app.run_test(): - table = app.query_one(DataTable) - table.add_columns("foo", "bar") - - assert table.columns[0].width == 3 - assert table.columns[1].width == 3 - table.add_row("Hello", "World!") - await asyncio.sleep(0) - assert table.columns[0].content_width == 5 - assert table.columns[1].content_width == 6 - - table.add_row("Hello World!!!", "fo") - await asyncio.sleep(0) - assert table.columns[0].content_width == 14 - assert table.columns[1].content_width == 6 diff --git a/tests/test_two_way_dict.py b/tests/test_two_way_dict.py new file mode 100644 index 000000000..9178f6fdb --- /dev/null +++ b/tests/test_two_way_dict.py @@ -0,0 +1,45 @@ +import pytest + +from textual._two_way_dict import TwoWayDict + + +@pytest.fixture +def two_way_dict(): + return TwoWayDict( + { + 1: 10, + 2: 20, + 3: 30, + } + ) + + +def test_get(two_way_dict): + assert two_way_dict.get(1) == 10 + + +def test_get_key(two_way_dict): + assert two_way_dict.get_key(30) == 3 + + +def test_set_item(two_way_dict): + two_way_dict[40] = 400 + assert two_way_dict.get(40) == 400 + assert two_way_dict.get_key(400) == 40 + + +def test_len(two_way_dict): + assert len(two_way_dict) == 3 + + +def test_delitem(two_way_dict): + assert two_way_dict.get(3) == 30 + assert two_way_dict.get_key(30) == 3 + del two_way_dict[3] + assert two_way_dict.get(3) is None + assert two_way_dict.get_key(30) is None + + +def test_contains(two_way_dict): + assert 1 in two_way_dict + assert 10 not in two_way_dict