diff --git a/.coveragerc b/.coveragerc
index d16dd221a..087a1674f 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -7,3 +7,4 @@ exclude_lines =
if TYPE_CHECKING:
if __name__ == "__main__":
@overload
+ __rich_repr__
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6df3da39e..e08383776 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,23 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Added Shift+scroll wheel and ctrl+scroll wheel to scroll horizontally
- Added `Tree.action_toggle_node` to toggle a node without selecting, and bound it to Space https://github.com/Textualize/textual/issues/1433
- Added `Tree.reset` to fully reset a `Tree` https://github.com/Textualize/textual/issues/1437
+- Added `DataTable.sort` to sort rows https://github.com/Textualize/textual/pull/1638
+- Added `DataTable.get_cell` to retrieve a cell by column/row keys https://github.com/Textualize/textual/pull/1638
+- Added `DataTable.get_cell_at` to retrieve a cell by coordinate https://github.com/Textualize/textual/pull/1638
+- Added `DataTable.update_cell` to update a cell by column/row keys https://github.com/Textualize/textual/pull/1638
+- Added `DataTable.update_cell_at` to update a cell at a coordinate https://github.com/Textualize/textual/pull/1638
+- Added `DataTable.ordered_rows` property to retrieve `Row`s as they're currently ordered https://github.com/Textualize/textual/pull/1638
+- Added `DataTable.ordered_columns` property to retrieve `Column`s as they're currently ordered https://github.com/Textualize/textual/pull/1638
+- Added `DataTable.coordinate_to_cell_key` to find the key for the cell at a coordinate https://github.com/Textualize/textual/pull/1638
+- Added `DataTable.is_valid_coordinate` https://github.com/Textualize/textual/pull/1638
+- Added `DataTable.is_valid_row_index` https://github.com/Textualize/textual/pull/1638
+- Added `DataTable.is_valid_column_index` https://github.com/Textualize/textual/pull/1638
+- Added attributes to events emitted from `DataTable` indicating row/column/cell keys https://github.com/Textualize/textual/pull/1638
+- Added `DataTable.get_row` to retrieve the values from a row by key https://github.com/Textualize/textual/pull/1786
+- Added `DataTable.get_row_at` to retrieve the values from a row by index https://github.com/Textualize/textual/pull/1786
+- Added `DataTable.get_column` to retrieve the values from a column by key https://github.com/Textualize/textual/pull/1786
+- Added `DataTable.get_column_at` to retrieve the values from a column by index https://github.com/Textualize/textual/pull/1786
+- Added `DataTable.HeaderSelected` which is posted when header label clicked https://github.com/Textualize/textual/pull/1788
- Added `DOMNode.watch` and `DOMNode.is_attached` methods https://github.com/Textualize/textual/pull/1750
- Added `DOMNode.css_tree` which is a renderable that shows the DOM and CSS https://github.com/Textualize/textual/pull/1778
- Added `DOMNode.children_view` which is a view on to a nodes children list, use for querying https://github.com/Textualize/textual/pull/1778
@@ -27,6 +44,21 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Breaking change: `TreeNode` can no longer be imported from `textual.widgets`; it is now available via `from textual.widgets.tree import TreeNode`. https://github.com/Textualize/textual/pull/1637
- `Tree` now shows a (subdued) cursor for a highlighted node when focus has moved elsewhere https://github.com/Textualize/textual/issues/1471
+- `DataTable.add_row` now accepts `key` argument to uniquely identify the row https://github.com/Textualize/textual/pull/1638
+- `DataTable.add_column` now accepts `key` argument to uniquely identify the column https://github.com/Textualize/textual/pull/1638
+- `DataTable.add_row` and `DataTable.add_column` now return lists of keys identifying the added rows/columns https://github.com/Textualize/textual/pull/1638
+- Breaking change: `DataTable.get_cell_value` renamed to `DataTable.get_value_at` https://github.com/Textualize/textual/pull/1638
+- `DataTable.row_count` is now a property https://github.com/Textualize/textual/pull/1638
+- Breaking change: `DataTable.cursor_cell` renamed to `DataTable.cursor_coordinate` https://github.com/Textualize/textual/pull/1638
+ - The method `validate_cursor_cell` was renamed to `validate_cursor_coordinate`.
+ - The method `watch_cursor_cell` was renamed to `watch_cursor_coordinate`.
+- Breaking change: `DataTable.hover_cell` renamed to `DataTable.hover_coordinate` https://github.com/Textualize/textual/pull/1638
+ - The method `validate_hover_cell` was renamed to `validate_hover_coordinate`.
+- Breaking change: `DataTable.data` structure changed, and will be made private in upcoming release https://github.com/Textualize/textual/pull/1638
+- Breaking change: `DataTable.refresh_cell` was renamed to `DataTable.refresh_coordinate` https://github.com/Textualize/textual/pull/1638
+- Breaking change: `DataTable.get_row_height` now takes a `RowKey` argument instead of a row index https://github.com/Textualize/textual/pull/1638
+- Breaking change: `DataTable.data` renamed to `DataTable._data` (it's now private) https://github.com/Textualize/textual/pull/1786
+- The `_filter` module was made public (now called `filter`) https://github.com/Textualize/textual/pull/1638
- Breaking change: renamed `Checkbox` to `Switch` https://github.com/Textualize/textual/issues/1746
- `App.install_screen` name is no longer optional https://github.com/Textualize/textual/pull/1778
- `App.query` now only includes the current screen https://github.com/Textualize/textual/pull/1778
@@ -47,6 +79,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Fixed issue with renderable width calculation https://github.com/Textualize/textual/issues/1685
- Fixed issue with app not processing Paste event https://github.com/Textualize/textual/issues/1666
- Fixed glitch with view position with auto width inputs https://github.com/Textualize/textual/issues/1693
+- Fixed `DataTable` "selected" events containing wrong coordinates when mouse was used https://github.com/Textualize/textual/issues/1723
### Removed
diff --git a/docs/examples/widgets/data_table.py b/docs/examples/widgets/data_table.py
index 74d9b76e4..f409bb45c 100644
--- a/docs/examples/widgets/data_table.py
+++ b/docs/examples/widgets/data_table.py
@@ -1,18 +1,18 @@
-import csv
-import io
-
from textual.app import App, ComposeResult
from textual.widgets import DataTable
-CSV = """lane,swimmer,country,time
-4,Joseph Schooling,Singapore,50.39
-2,Michael Phelps,United States,51.14
-5,Chad le Clos,South Africa,51.14
-6,László Cseh,Hungary,51.14
-3,Li Zhuhao,China,51.26
-8,Mehdy Metella,France,51.58
-7,Tom Shields,United States,51.73
-1,Aleksandr Sadovnikov,Russia,51.84"""
+ROWS = [
+ ("lane", "swimmer", "country", "time"),
+ (4, "Joseph Schooling", "Singapore", 50.39),
+ (2, "Michael Phelps", "United States", 51.14),
+ (5, "Chad le Clos", "South Africa", 51.14),
+ (6, "László Cseh", "Hungary", 51.14),
+ (3, "Li Zhuhao", "China", 51.26),
+ (8, "Mehdy Metella", "France", 51.58),
+ (7, "Tom Shields", "United States", 51.73),
+ (1, "Aleksandr Sadovnikov", "Russia", 51.84),
+ (10, "Darren Burns", "Scotland", 51.84),
+]
class TableApp(App):
@@ -21,11 +21,11 @@ class TableApp(App):
def on_mount(self) -> None:
table = self.query_one(DataTable)
- rows = csv.reader(io.StringIO(CSV))
+ rows = iter(ROWS)
table.add_columns(*next(rows))
table.add_rows(rows)
+app = TableApp()
if __name__ == "__main__":
- app = TableApp()
app.run()
diff --git a/docs/widgets/data_table.md b/docs/widgets/data_table.md
index 9dd918d9d..7cc0a00d6 100644
--- a/docs/widgets/data_table.md
+++ b/docs/widgets/data_table.md
@@ -22,17 +22,17 @@ The example below populates a table with CSV data.
## Reactive Attributes
-| Name | Type | Default | Description |
-|-----------------|---------------------------------------------|--------------------|---------------------------------------------------------|
-| `show_header` | `bool` | `True` | Show the table header |
-| `fixed_rows` | `int` | `0` | Number of fixed rows (rows which do not scroll) |
-| `fixed_columns` | `int` | `0` | Number of fixed columns (columns which do not scroll) |
-| `zebra_stripes` | `bool` | `False` | Display alternating colors on rows |
-| `header_height` | `int` | `1` | Height of header row |
-| `show_cursor` | `bool` | `True` | Show the cursor |
-| `cursor_type` | `str` | `"cell"` | One of `"cell"`, `"row"`, `"column"`, or `"none"` |
-| `cursor_cell` | [Coordinate][textual.coordinate.Coordinate] | `Coordinate(0, 0)` | The coordinates of the cell the cursor is currently on |
-| `hover_cell` | [Coordinate][textual.coordinate.Coordinate] | `Coordinate(0, 0)` | The coordinates of the cell the _mouse_ cursor is above |
+| Name | Type | Default | Description |
+|---------------------|---------------------------------------------|--------------------|-------------------------------------------------------|
+| `show_header` | `bool` | `True` | Show the table header |
+| `fixed_rows` | `int` | `0` | Number of fixed rows (rows which do not scroll) |
+| `fixed_columns` | `int` | `0` | Number of fixed columns (columns which do not scroll) |
+| `zebra_stripes` | `bool` | `False` | Display alternating colors on rows |
+| `header_height` | `int` | `1` | Height of header row |
+| `show_cursor` | `bool` | `True` | Show the cursor |
+| `cursor_type` | `str` | `"cell"` | One of `"cell"`, `"row"`, `"column"`, or `"none"` |
+| `cursor_coordinate` | [Coordinate][textual.coordinate.Coordinate] | `Coordinate(0, 0)` | The current coordinate of the cursor |
+| `hover_coordinate` | [Coordinate][textual.coordinate.Coordinate] | `Coordinate(0, 0)` | The coordinate the _mouse_ cursor is above |
## Messages
@@ -48,6 +48,8 @@ The example below populates a table with CSV data.
### ::: textual.widgets.DataTable.ColumnSelected
+### ::: textual.widgets.DataTable.HeaderSelected
+
## Bindings
The data table widget defines directly the following bindings:
diff --git a/src/textual/_styles_cache.py b/src/textual/_styles_cache.py
index 77cbfb4de..c760fcfe3 100644
--- a/src/textual/_styles_cache.py
+++ b/src/textual/_styles_cache.py
@@ -8,7 +8,7 @@ from rich.segment import Segment
from rich.style import Style
from ._border import get_box, render_row
-from ._filter import LineFilter
+from .filter import LineFilter
from ._opacity import _apply_opacity
from ._segment_tools import line_pad, line_trim
from .color import Color
diff --git a/src/textual/_two_way_dict.py b/src/textual/_two_way_dict.py
new file mode 100644
index 000000000..d733edcdc
--- /dev/null
+++ b/src/textual/_two_way_dict.py
@@ -0,0 +1,69 @@
+from __future__ import annotations
+
+from typing import Generic, TypeVar
+
+Key = TypeVar("Key")
+Value = TypeVar("Value")
+
+
+class TwoWayDict(Generic[Key, Value]):
+ """
+ A two-way mapping offering O(1) access in both directions.
+
+ Wraps two dictionaries and uses them to provide efficient access to
+ both values (given keys) and keys (given values).
+ """
+
+ def __init__(self, initial: dict[Key, Value]) -> None:
+ self._forward: dict[Key, Value] = initial
+ self._reverse: dict[Value, Key] = {value: key for key, value in initial.items()}
+
+ def __setitem__(self, key: Key, value: Value) -> None:
+ # TODO: Duplicate values need to be managed to ensure consistency,
+ # decide on best approach.
+ self._forward.__setitem__(key, value)
+ self._reverse.__setitem__(value, key)
+
+ def __delitem__(self, key: Key) -> None:
+ value = self._forward[key]
+ self._forward.__delitem__(key)
+ self._reverse.__delitem__(value)
+
+ def get(self, key: Key) -> Value:
+ """Given a key, efficiently lookup and return the associated value.
+
+ Args:
+ key: The key
+
+ Returns:
+ The value
+ """
+ return self._forward.get(key)
+
+ def get_key(self, value: Value) -> Key:
+ """Given a value, efficiently lookup and return the associated key.
+
+ Args:
+ value: The value
+
+ Returns:
+ The key
+ """
+ return self._reverse.get(value)
+
+ def contains_value(self, value: Value) -> bool:
+ """Check if `value` is a value within this TwoWayDict.
+
+ Args:
+ value: The value to check.
+
+ Returns:
+ True if the value is within the values of this dict.
+ """
+ return value in self._reverse
+
+ def __len__(self):
+ return len(self._forward)
+
+ def __contains__(self, item: Key) -> bool:
+ return item in self._forward
diff --git a/src/textual/app.py b/src/textual/app.py
index 0a425c839..6bca62824 100644
--- a/src/textual/app.py
+++ b/src/textual/app.py
@@ -48,7 +48,6 @@ from ._asyncio import create_task
from ._callback import invoke
from ._context import active_app
from ._event_broker import NoHandler, extract_handler_actions
-from ._filter import LineFilter, Monochrome
from ._path import _make_path_object_relative
from ._wait import wait_for_idle
from .actions import SkipAction
@@ -62,6 +61,7 @@ from .driver import Driver
from .drivers.headless_driver import HeadlessDriver
from .features import FeatureFlag, parse_features
from .file_monitor import FileMonitor
+from .filter import LineFilter, Monochrome
from .geometry import Offset, Region, Size
from .keys import REPLACED_KEYS, _get_key_display
from .messages import CallbackType
diff --git a/src/textual/_filter.py b/src/textual/filter.py
similarity index 100%
rename from src/textual/_filter.py
rename to src/textual/filter.py
diff --git a/src/textual/message.py b/src/textual/message.py
index 5f46cc97a..7ed72789b 100644
--- a/src/textual/message.py
+++ b/src/textual/message.py
@@ -10,7 +10,6 @@ from .case import camel_to_snake
if TYPE_CHECKING:
from .message_pump import MessagePump
- from .widget import Widget
@rich.repr.auto
diff --git a/src/textual/render.py b/src/textual/render.py
index 8911c4263..c1003b062 100644
--- a/src/textual/render.py
+++ b/src/textual/render.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+from rich.cells import cell_len
from rich.console import Console, RenderableType
from rich.protocol import rich_cast
@@ -22,6 +23,9 @@ def measure(
Returns:
Width in cells
"""
+ if isinstance(renderable, str):
+ return cell_len(renderable)
+
width = default
renderable = rich_cast(renderable)
get_console_width = getattr(renderable, "__rich_measure__", None)
diff --git a/src/textual/strip.py b/src/textual/strip.py
index c10a649a9..53d8a8e9f 100644
--- a/src/textual/strip.py
+++ b/src/textual/strip.py
@@ -9,7 +9,7 @@ from rich.segment import Segment
from rich.style import Style, StyleType
from ._cache import FIFOCache
-from ._filter import LineFilter
+from .filter import LineFilter
from ._segment_tools import index_to_cell_position
diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py
index 4cc540c94..b9e48d326 100644
--- a/src/textual/widgets/_data_table.py
+++ b/src/textual/widgets/_data_table.py
@@ -1,8 +1,10 @@
from __future__ import annotations
-from dataclasses import dataclass, field
+import functools
+from dataclasses import dataclass
from itertools import chain, zip_longest
-from typing import Generic, Iterable, cast
+from operator import itemgetter
+from typing import Any, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast
import rich.repr
from rich.console import RenderableType
@@ -11,11 +13,12 @@ from rich.protocol import is_renderable
from rich.segment import Segment
from rich.style import Style
from rich.text import Text, TextType
-from typing_extensions import ClassVar, Literal, TypeVar
+from typing_extensions import Literal, TypeAlias
-from .. import events, messages
+from .. import events
from .._cache import LRUCache
from .._segment_tools import line_crop
+from .._two_way_dict import TwoWayDict
from .._types import SegmentLines
from ..binding import Binding, BindingType
from ..coordinate import Coordinate
@@ -26,40 +29,140 @@ from ..render import measure
from ..scroll_view import ScrollView
from ..strip import Strip
+CellCacheKey: TypeAlias = "tuple[RowKey, ColumnKey, Style, bool, bool, int]"
+LineCacheKey: TypeAlias = (
+ "tuple[int, int, int, int, Coordinate, Coordinate, Style, CursorType, bool, int]"
+)
+RowCacheKey: TypeAlias = (
+ "tuple[RowKey, int, Style, Coordinate, Coordinate, CursorType, bool, bool, int]"
+)
CursorType = Literal["cell", "row", "column", "none"]
-CELL: CursorType = "cell"
CellType = TypeVar("CellType")
class CellDoesNotExist(Exception):
- pass
+ """The cell key/index was invalid.
+
+ Raised when the user supplies coordinates or cell keys which
+ do not exist in the DataTable."""
-def default_cell_formatter(obj: object) -> RenderableType | None:
- """Format a cell in to a renderable.
+class RowDoesNotExist(Exception):
+ """Raised when the user supplies a row index or row key which does
+ not exist in the DataTable (e.g. out of bounds index, invalid key)"""
+
+
+class ColumnDoesNotExist(Exception):
+ """Raised when the user supplies a column index or column key which does
+ not exist in the DataTable (e.g. out of bounds index, invalid key)"""
+
+
+class DuplicateKey(Exception):
+ """The key supplied already exists.
+
+ Raised when the RowKey or ColumnKey provided already refers to
+ an existing row or column in the DataTable. Keys must be unique."""
+
+
+@functools.total_ordering
+class StringKey:
+ """An object used as a key in a mapping.
+
+ It can optionally wrap a string,
+ and lookups into a map using the object behave the same as lookups using
+ the string itself."""
+
+ value: str | None
+
+ def __init__(self, value: str | None = None):
+ self.value = value
+
+ def __hash__(self):
+ # If a string is supplied, we use the hash of the string. If no string was
+ # supplied, we use the default hash to ensure uniqueness amongst instances.
+ return hash(self.value) if self.value is not None else id(self)
+
+ def __eq__(self, other: object) -> bool:
+ # Strings will match Keys containing the same string value.
+ # Otherwise, you'll need to supply the exact same key object.
+ if isinstance(other, str):
+ return self.value == other
+ elif isinstance(other, StringKey):
+ if self.value is not None and other.value is not None:
+ return self.value == other.value
+ else:
+ return hash(self) == hash(other)
+ else:
+ raise NotImplemented
+
+ def __lt__(self, other):
+ if isinstance(other, str):
+ return self.value < other
+ elif isinstance(other, StringKey):
+ return self.value < other.value
+ else:
+ raise NotImplemented
+
+ def __rich_repr__(self):
+ yield "value", self.value
+
+
+class RowKey(StringKey):
+ """Uniquely identifies a row in the DataTable.
+
+ Even if the visual location
+ of the row changes due to sorting or other modifications, a key will always
+ refer to the same row."""
+
+
+class ColumnKey(StringKey):
+ """Uniquely identifies a column in the DataTable.
+
+ Even if the visual location
+ of the column changes due to sorting or other modifications, a key will always
+ refer to the same column."""
+
+
+class CellKey(NamedTuple):
+ """A unique identifier for a cell in the DataTable.
+
+ Even if the cell changes
+ visual location (i.e. moves to a different coordinate in the table), this key
+ can still be used to retrieve it, regardless of where it currently is."""
+
+ row_key: RowKey
+ column_key: ColumnKey
+
+ def __rich_repr__(self):
+ yield "row_key", self.row_key
+ yield "column_key", self.column_key
+
+
+def default_cell_formatter(obj: object) -> RenderableType:
+ """Convert a cell into a Rich renderable for display.
Args:
obj: Data for a cell.
Returns:
- A renderable or None if the object could not be rendered.
+ A renderable to be displayed which represents the data.
"""
if isinstance(obj, str):
return Text.from_markup(obj)
+ if isinstance(obj, float):
+ return f"{obj:.2f}"
if not is_renderable(obj):
- return None
+ return str(obj)
return cast(RenderableType, obj)
@dataclass
class Column:
- """Table column."""
+ """Metadata for a column in the DataTable."""
+ key: ColumnKey
label: Text
width: int = 0
- visible: bool = False
- index: int = 0
-
content_width: int = 0
auto_width: bool = False
@@ -75,12 +178,10 @@ class Column:
@dataclass
class Row:
- """Table row."""
+ """Metadata for a row in the DataTable."""
- index: int
+ key: RowKey
height: int
- y: int
- cell_renderables: list[RenderableType] = field(default_factory=list)
class DataTable(ScrollView, Generic[CellType], can_focus=True):
@@ -182,131 +283,183 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
zebra_stripes = Reactive(False)
header_height = Reactive(1)
show_cursor = Reactive(True)
- cursor_type = Reactive(CELL)
+ cursor_type = Reactive("cell")
- cursor_cell: Reactive[Coordinate] = Reactive(
+ cursor_coordinate: Reactive[Coordinate] = Reactive(
Coordinate(0, 0), repaint=False, always_update=True
)
- hover_cell: Reactive[Coordinate] = Reactive(Coordinate(0, 0), repaint=False)
+ hover_coordinate: Reactive[Coordinate] = Reactive(Coordinate(0, 0), repaint=False)
class CellHighlighted(Message, bubble=True):
"""Posted when the cursor moves to highlight a new cell.
- It's only relevant when the `cursor_type` is `"cell"`.
- It's also posted when the cell cursor is re-enabled (by setting `show_cursor=True`),
- and when the cursor type is changed to `"cell"`. Can be handled using
- `on_data_table_cell_highlighted` in a subclass of `DataTable` or in a parent
- widget in the DOM.
- Attributes:
- value: The value in the highlighted cell.
- coordinate: The coordinate of the highlighted cell.
+ This is only relevant when the `cursor_type` is `"cell"`.
+ It's also posted when the cell cursor is
+ re-enabled (by setting `show_cursor=True`), and when the cursor type is
+ changed to `"cell"`. Can be handled using `on_data_table_cell_highlighted` in
+ a subclass of `DataTable` or in a parent widget in the DOM.
"""
def __init__(
- self, sender: DataTable, value: CellType, coordinate: Coordinate
+ self,
+ sender: DataTable,
+ value: CellType,
+ coordinate: Coordinate,
+ cell_key: CellKey,
) -> None:
self.value: CellType = value
+ """The value in the highlighted cell."""
self.coordinate: Coordinate = coordinate
+ """The coordinate of the highlighted cell."""
+ self.cell_key: CellKey = cell_key
+ """The key for the highlighted cell."""
super().__init__(sender)
def __rich_repr__(self) -> rich.repr.Result:
yield "sender", self.sender
yield "value", self.value
yield "coordinate", self.coordinate
+ yield "cell_key", self.cell_key
class CellSelected(Message, bubble=True):
"""Posted by the `DataTable` widget when a cell is selected.
- It's only relevant when the `cursor_type` is `"cell"`. Can be handled using
+
+ This is only relevant when the `cursor_type` is `"cell"`. Can be handled using
`on_data_table_cell_selected` in a subclass of `DataTable` or in a parent
widget in the DOM.
-
- Attributes:
- value: The value in the cell that was selected.
- coordinate: The coordinate of the cell that was selected.
"""
def __init__(
- self, sender: DataTable, value: CellType, coordinate: Coordinate
+ self,
+ sender: DataTable,
+ value: CellType,
+ coordinate: Coordinate,
+ cell_key: CellKey,
) -> None:
self.value: CellType = value
+ """The value in the cell that was selected."""
self.coordinate: Coordinate = coordinate
+ """The coordinate of the cell that was selected."""
+ self.cell_key: CellKey = cell_key
+ """The key for the selected cell."""
super().__init__(sender)
def __rich_repr__(self) -> rich.repr.Result:
yield "sender", self.sender
yield "value", self.value
yield "coordinate", self.coordinate
+ yield "cell_key", self.cell_key
class RowHighlighted(Message, bubble=True):
- """Posted when a row is highlighted. This message is only posted when the
- `cursor_type` is set to `"row"`. Can be handled using `on_data_table_row_highlighted`
- in a subclass of `DataTable` or in a parent widget in the DOM.
+ """Posted when a row is highlighted.
- Attributes:
- cursor_row: The y-coordinate of the cursor that highlighted the row.
+ This message is only posted when the
+ `cursor_type` is set to `"row"`. Can be handled using
+ `on_data_table_row_highlighted` in a subclass of `DataTable` or in a parent
+ widget in the DOM.
"""
- def __init__(self, sender: DataTable, cursor_row: int) -> None:
+ def __init__(self, sender: DataTable, cursor_row: int, row_key: RowKey) -> None:
self.cursor_row: int = cursor_row
+ """The y-coordinate of the cursor that highlighted the row."""
+ self.row_key: RowKey = row_key
+ """The key of the row that was highlighted."""
super().__init__(sender)
def __rich_repr__(self) -> rich.repr.Result:
yield "sender", self.sender
yield "cursor_row", self.cursor_row
+ yield "row_key", self.row_key
class RowSelected(Message, bubble=True):
- """Posted when a row is selected. This message is only posted when the
+ """Posted when a row is selected.
+
+ This message is only posted when the
`cursor_type` is set to `"row"`. Can be handled using
`on_data_table_row_selected` in a subclass of `DataTable` or in a parent
widget in the DOM.
-
- Attributes:
- cursor_row: The y-coordinate of the cursor that made the selection.
"""
- def __init__(self, sender: DataTable, cursor_row: int) -> None:
+ def __init__(self, sender: DataTable, cursor_row: int, row_key: RowKey) -> None:
self.cursor_row: int = cursor_row
+ """The y-coordinate of the cursor that made the selection."""
+ self.row_key: RowKey = row_key
+ """The key of the row that was selected."""
super().__init__(sender)
def __rich_repr__(self) -> rich.repr.Result:
yield "sender", self.sender
yield "cursor_row", self.cursor_row
+ yield "row_key", self.row_key
class ColumnHighlighted(Message, bubble=True):
- """Posted when a column is highlighted. This message is only posted when the
+ """Posted when a column is highlighted.
+
+ This message is only posted when the
`cursor_type` is set to `"column"`. Can be handled using
`on_data_table_column_highlighted` in a subclass of `DataTable` or in a parent
widget in the DOM.
-
- Attributes:
- cursor_column: The x-coordinate of the column that was highlighted.
"""
- def __init__(self, sender: DataTable, cursor_column: int) -> None:
+ def __init__(
+ self, sender: DataTable, cursor_column: int, column_key: ColumnKey
+ ) -> None:
self.cursor_column: int = cursor_column
+ """The x-coordinate of the column that was highlighted."""
+ self.column_key = column_key
+ """The key of the column that was highlighted."""
super().__init__(sender)
def __rich_repr__(self) -> rich.repr.Result:
yield "sender", self.sender
yield "cursor_column", self.cursor_column
+ yield "column_key", self.column_key
class ColumnSelected(Message, bubble=True):
- """Posted when a column is selected. This message is only posted when the
+ """Posted when a column is selected.
+
+ This message is only posted when the
`cursor_type` is set to `"column"`. Can be handled using
`on_data_table_column_selected` in a subclass of `DataTable` or in a parent
widget in the DOM.
-
- Attributes:
- cursor_column: The x-coordinate of the column that was selected.
"""
- def __init__(self, sender: DataTable, cursor_column: int) -> None:
+ def __init__(
+ self, sender: DataTable, cursor_column: int, column_key: ColumnKey
+ ) -> None:
self.cursor_column: int = cursor_column
+ """The x-coordinate of the column that was selected."""
+ self.column_key = column_key
+ """The key of the column that was selected."""
super().__init__(sender)
def __rich_repr__(self) -> rich.repr.Result:
yield "sender", self.sender
yield "cursor_column", self.cursor_column
+ yield "column_key", self.column_key
+
+ class HeaderSelected(Message, bubble=True):
+ """Posted when a column header/label is clicked."""
+
+ def __init__(
+ self,
+ sender: DataTable,
+ column_key: ColumnKey,
+ column_index: int,
+ label: Text,
+ ):
+ self.column_key = column_key
+ """The key for the column."""
+ self.column_index = column_index
+ """The index for the column."""
+ self.label = label
+ """The text of the label."""
+ super().__init__(sender)
+
+ def __rich_repr__(self) -> rich.repr.Result:
+ yield "sender", self.sender
+ yield "column_key", self.column_key
+ yield "label", self.label.plain
def __init__(
self,
@@ -322,69 +475,280 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
classes: str | None = None,
) -> None:
super().__init__(name=name, id=id, classes=classes)
+ self._data: dict[RowKey, dict[ColumnKey, CellType]] = {}
+ """Contains the cells of the table, indexed by row key and column key.
+ The final positioning of a cell on screen cannot be determined solely by this
+ structure. Instead, we must check _row_locations and _column_locations to find
+ where each cell currently resides in space."""
+
+ self.columns: dict[ColumnKey, Column] = {}
+ """Metadata about the columns of the table, indexed by their key."""
+ self.rows: dict[RowKey, Row] = {}
+ """Metadata about the rows of the table, indexed by their key."""
+
+ # Keep tracking of key -> index for rows/cols. These allow us to retrieve,
+ # given a row or column key, the index that row or column is currently
+ # present at, and mean that rows and columns are location independent - they
+ # can move around without requiring us to modify the underlying data.
+ self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict({})
+ """Maps row keys to row indices which represent row order."""
+ self._column_locations: TwoWayDict[ColumnKey, int] = TwoWayDict({})
+ """Maps column keys to column indices which represent column order."""
- self.columns: list[Column] = []
- self.rows: dict[int, Row] = {}
- self.data: dict[int, list[CellType]] = {}
- self.row_count = 0
- self._y_offsets: list[tuple[int, int]] = []
self._row_render_cache: LRUCache[
- tuple[int, int, Style, int, int], tuple[SegmentLines, SegmentLines]
- ]
- self._row_render_cache = LRUCache(1000)
- self._cell_render_cache: LRUCache[
- tuple[int, int, Style, bool, bool], SegmentLines
- ]
- self._cell_render_cache = LRUCache(10000)
- self._line_cache: LRUCache[tuple[int, int, int, int, int, int, Style], Strip]
- self._line_cache = LRUCache(1000)
+ RowCacheKey, tuple[SegmentLines, SegmentLines]
+ ] = LRUCache(1000)
+ """For each row (a row can have a height of multiple lines), we maintain a
+ cache of the fixed and scrollable lines within that row to minimise how often
+ we need to re-render it. """
+ self._cell_render_cache: LRUCache[CellCacheKey, SegmentLines] = LRUCache(10000)
+ """Cache for individual cells."""
+ self._line_cache: LRUCache[LineCacheKey, Strip] = LRUCache(1000)
+ """Cache for lines within rows."""
+ self._offset_cache: LRUCache[int, list[tuple[RowKey, int]]] = LRUCache(1)
+ """Cached y_offset - key is update_count - see y_offsets property for more
+ information """
+ self._ordered_row_cache: LRUCache[tuple[int, int], list[Row]] = LRUCache(1)
+ """Caches row ordering - key is (num_rows, update_count)."""
- self._line_no = 0
self._require_update_dimensions: bool = False
- self._new_rows: set[int] = set()
+ """Set to re-calculate dimensions on idle."""
+ self._new_rows: set[RowKey] = set()
+ """Tracking newly added rows to be used in calculation of dimensions on idle."""
+ self._updated_cells: set[CellKey] = set()
+ """Track which cells were updated, so that we can refresh them once on idle."""
self.show_header = show_header
- self.fixed_rows = fixed_rows
- self.fixed_columns = fixed_columns
- self.zebra_stripes = zebra_stripes
+ """Show/hide the header row (the row of column labels)."""
self.header_height = header_height
+ """The height of the header row (the row of column labels)."""
+ self.fixed_rows = fixed_rows
+ """The number of rows to fix (prevented from scrolling)."""
+ self.fixed_columns = fixed_columns
+ """The number of columns to fix (prevented from scrolling)."""
+ self.zebra_stripes = zebra_stripes
+ """Apply zebra effect on row backgrounds (light, dark, light, dark, ...)."""
self.show_cursor = show_cursor
+ """Show/hide both the keyboard and hover cursor."""
self._show_hover_cursor = False
+ """Used to hide the mouse hover cursor when the user uses the keyboard."""
+ self._update_count = 0
+ """Number of update (INCLUDING SORT) operations so far. Used for cache invalidation."""
+ self._header_row_key = RowKey()
+ """The header is a special row - not part of the data. Retrieve via this key."""
@property
def hover_row(self) -> int:
- return self.hover_cell.row
+ """The index of the row that the mouse cursor is currently hovering above."""
+ return self.hover_coordinate.row
@property
def hover_column(self) -> int:
- return self.hover_cell.column
+ """The index of the column that the mouse cursor is currently hovering above."""
+ return self.hover_coordinate.column
@property
def cursor_row(self) -> int:
- return self.cursor_cell.row
+ """The index of the row that the DataTable cursor is currently on."""
+ return self.cursor_coordinate.row
@property
def cursor_column(self) -> int:
- return self.cursor_cell.column
+ """The index of the column that the DataTable cursor is currently on."""
+ return self.cursor_coordinate.column
- def get_cell_value(self, coordinate: Coordinate) -> CellType:
- """Get the value from the cell at the given coordinate.
+ @property
+ def row_count(self) -> int:
+ """The number of rows currently present in the DataTable."""
+ return len(self.rows)
+
+ @property
+ def _y_offsets(self) -> list[tuple[RowKey, int]]:
+ """Contains a 2-tuple for each line (not row!) of the DataTable. Given a
+ y-coordinate, we can index into this list to find which row that y-coordinate
+ lands on, and the y-offset *within* that row. The length of the returned list
+ is therefore the total height of all rows within the DataTable."""
+ y_offsets = []
+ if self._update_count in self._offset_cache:
+ y_offsets = self._offset_cache[self._update_count]
+ else:
+ for row in self.ordered_rows:
+ y_offsets += [(row.key, y) for y in range(row.height)]
+ self._offset_cache = y_offsets
+ return y_offsets
+
+ @property
+ def _total_row_height(self) -> int:
+ """The total height of all rows within the DataTable"""
+ return len(self._y_offsets)
+
+ def update_cell(
+ self,
+ row_key: RowKey | str,
+ column_key: ColumnKey | str,
+ value: CellType,
+ *,
+ update_width: bool = False,
+ ) -> None:
+ """Update the cell identified by the specified row key and column key.
+
+ Args:
+ row_key: The key identifying the row.
+ column_key: The key identifying the column.
+ value: The new value to put inside the cell.
+ update_width: Whether to resize the column width to accommodate
+ for the new cell content.
+
+ Raises:
+ CellDoesNotExist: When the supplied `row_key` and `column_key`
+ cannot be found in the table.
+ """
+ if isinstance(row_key, str):
+ row_key = RowKey(row_key)
+ if isinstance(column_key, str):
+ column_key = ColumnKey(column_key)
+
+ try:
+ self._data[row_key][column_key] = value
+ except KeyError:
+ raise CellDoesNotExist(
+ f"No cell exists for row_key={row_key!r}, column_key={column_key!r}."
+ ) from None
+ self._update_count += 1
+
+ # Recalculate widths if necessary
+ if update_width:
+ self._updated_cells.add(CellKey(row_key, column_key))
+ self._require_update_dimensions = True
+
+ self.refresh()
+
+ def update_cell_at(
+ self, coordinate: Coordinate, value: CellType, *, update_width: bool = False
+ ) -> None:
+ """Update the content inside the cell currently occupying the given coordinate.
+
+ Args:
+ coordinate: The coordinate to update the cell at.
+ value: The new value to place inside the cell.
+ update_width: Whether to resize the column width to accommodate
+ for the new cell content.
+ """
+ if not self.is_valid_coordinate(coordinate):
+ raise CellDoesNotExist(f"Coordinate {coordinate!r} is invalid.")
+
+ row_key, column_key = self.coordinate_to_cell_key(coordinate)
+ self.update_cell(row_key, column_key, value, update_width=update_width)
+
+ def get_cell(self, row_key: RowKey, column_key: ColumnKey) -> CellType:
+ """Given a row key and column key, return the value of the corresponding cell.
+
+ Args:
+ row_key: The row key of the cell.
+ column_key: The column key of the cell.
+
+ Returns:
+ The value of the cell identified by the row and column keys.
+ """
+ try:
+ cell_value = self._data[row_key][column_key]
+ except KeyError:
+ raise CellDoesNotExist(
+ f"No cell exists for row_key={row_key!r}, column_key={column_key!r}."
+ )
+ return cell_value
+
+ def get_cell_at(self, coordinate: Coordinate) -> CellType:
+ """Get the value from the cell occupying the given coordinate.
Args:
coordinate: The coordinate to retrieve the value from.
Returns:
- The value of the cell.
+ The value of the cell at the coordinate.
Raises:
CellDoesNotExist: If there is no cell with the given coordinate.
"""
- row, column = coordinate
- try:
- cell_value = self.data[row][column]
- except KeyError:
- raise CellDoesNotExist(f"No cell exists at {coordinate!r}") from None
- return cell_value
+ row_key, column_key = self.coordinate_to_cell_key(coordinate)
+ return self.get_cell(row_key, column_key)
+
+ def get_row(self, row_key: RowKey | str) -> list[CellType]:
+ """Get the values from the row identified by the given row key.
+
+ Args:
+ row_key: The key of the row.
+
+ Returns:
+ A list of the values contained within the row.
+
+ Raises:
+ RowDoesNotExist: When there is no row corresponding to the key.
+ """
+ if row_key not in self._row_locations:
+ raise RowDoesNotExist(f"Row key {row_key!r} is not valid.")
+ cell_mapping: dict[ColumnKey, CellType] = self._data.get(row_key, {})
+ ordered_row: list[CellType] = [
+ cell_mapping[column.key] for column in self.ordered_columns
+ ]
+ return ordered_row
+
+ def get_row_at(self, row_index: int) -> list[CellType]:
+ """Get the values from the cells in a row at a given index. This will
+ return the values from a row based on the rows _current position_ in
+ the table.
+
+ Args:
+ row_index: The index of the row.
+
+ Returns:
+ A list of the values contained in the row.
+
+ Raises:
+ RowDoesNotExist: If there is no row with the given index.
+ """
+ if not self.is_valid_row_index(row_index):
+ raise RowDoesNotExist(f"Row index {row_index!r} is not valid.")
+ row_key = self._row_locations.get_key(row_index)
+ return self.get_row(row_key)
+
+ def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]:
+ """Get the values from the column identified by the given column key.
+
+ Args:
+ column_key: The key of the column.
+
+ Returns:
+ A generator which yields the cells in the column.
+
+ Raises:
+ ColumnDoesNotExist: If there is no column corresponding to the key.
+ """
+ if column_key not in self._column_locations:
+ raise ColumnDoesNotExist(f"Column key {column_key!r} is not valid.")
+
+ data = self._data
+ for row_metadata in self.ordered_rows:
+ row_key = row_metadata.key
+ yield data[row_key][column_key]
+
+ def get_column_at(self, column_index: int) -> Iterable[CellType]:
+ """Get the values from the column at a given index.
+
+ Args:
+ column_index: The index of the column.
+
+ Returns:
+ A generator which yields the cells in the column.
+
+ Raises:
+ ColumnDoesNotExist: If there is no column with the given index.
+ """
+ if not self.is_valid_column_index(column_index):
+ raise ColumnDoesNotExist(f"Column index {column_index!r} is not valid.")
+
+ column_key = self._column_locations.get_key(column_index)
+ yield from self.get_column(column_key)
def _clear_caches(self) -> None:
self._row_render_cache.clear()
@@ -392,12 +756,20 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._line_cache.clear()
self._styles_cache.clear()
- def get_row_height(self, row_index: int) -> int:
- if row_index == -1:
- return self.header_height
- return self.rows[row_index].height
+ def get_row_height(self, row_key: RowKey) -> int:
+ """Given a row key, return the height of that row in terminal cells.
- async def on_styles_updated(self, message: messages.StylesUpdated) -> None:
+ Args:
+ row_key: The key of the row.
+
+ Returns:
+ The height of the row, measured in terminal character cells.
+ """
+ if row_key is self._header_row_key:
+ return self.header_height
+ return self.rows[row_key].height
+
+ async def on_styles_updated(self) -> None:
self._clear_caches()
self.refresh()
@@ -408,34 +780,34 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
# post the appropriate [Row|Column|Cell]Highlighted event.
self._scroll_cursor_into_view(animate=False)
if self.cursor_type == "cell":
- self._highlight_cell(self.cursor_cell)
+ self._highlight_coordinate(self.cursor_coordinate)
elif self.cursor_type == "row":
self._highlight_row(self.cursor_row)
elif self.cursor_type == "column":
self._highlight_column(self.cursor_column)
- def watch_show_header(self, show_header: bool) -> None:
+ def watch_show_header(self) -> None:
self._clear_caches()
- def watch_fixed_rows(self, fixed_rows: int) -> None:
+ def watch_fixed_rows(self) -> None:
self._clear_caches()
- def watch_zebra_stripes(self, zebra_stripes: bool) -> None:
+ def watch_zebra_stripes(self) -> None:
self._clear_caches()
- def watch_hover_cell(self, old: Coordinate, value: Coordinate) -> None:
- self.refresh_cell(*old)
- self.refresh_cell(*value)
+ def watch_hover_coordinate(self, old: Coordinate, value: Coordinate) -> None:
+ self.refresh_coordinate(old)
+ self.refresh_coordinate(value)
- def watch_cursor_cell(
+ def watch_cursor_coordinate(
self, old_coordinate: Coordinate, new_coordinate: Coordinate
) -> None:
if old_coordinate != new_coordinate:
# Refresh the old and the new cell, and post the appropriate
# message to tell users of the newly highlighted row/cell/column.
if self.cursor_type == "cell":
- self.refresh_cell(*old_coordinate)
- self._highlight_cell(new_coordinate)
+ self.refresh_coordinate(old_coordinate)
+ self._highlight_coordinate(new_coordinate)
elif self.cursor_type == "row":
self.refresh_row(old_coordinate.row)
self._highlight_row(new_coordinate.row)
@@ -443,37 +815,67 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self.refresh_column(old_coordinate.column)
self._highlight_column(new_coordinate.column)
- def _highlight_cell(self, coordinate: Coordinate) -> None:
+ def _highlight_coordinate(self, coordinate: Coordinate) -> None:
"""Apply highlighting to the cell at the coordinate, and post event."""
- self.refresh_cell(*coordinate)
+ self.refresh_coordinate(coordinate)
try:
- cell_value = self.get_cell_value(coordinate)
+ cell_value = self.get_cell_at(coordinate)
except CellDoesNotExist:
# The cell may not exist e.g. when the table is cleared.
# In that case, there's nothing for us to do here.
return
else:
+ cell_key = self.coordinate_to_cell_key(coordinate)
self.post_message_no_wait(
- DataTable.CellHighlighted(self, cell_value, coordinate)
+ DataTable.CellHighlighted(
+ self, cell_value, coordinate=coordinate, cell_key=cell_key
+ )
)
+ def coordinate_to_cell_key(self, coordinate: Coordinate) -> CellKey:
+ """Return the key for the cell currently occupying this coordinate.
+
+ Args:
+ coordinate: The coordinate to exam the current cell key of.
+
+ Returns:
+ The key of the cell currently occupying this coordinate.
+
+ Raises:
+ CellDoesNotExist: If the coordinate is not valid.
+ """
+ if not self.is_valid_coordinate(coordinate):
+ raise CellDoesNotExist(f"No cell exists at {coordinate!r}.")
+ row_index, column_index = coordinate
+ row_key = self._row_locations.get_key(row_index)
+ column_key = self._column_locations.get_key(column_index)
+ return CellKey(row_key, column_key)
+
def _highlight_row(self, row_index: int) -> None:
"""Apply highlighting to the row at the given index, and post event."""
self.refresh_row(row_index)
- if row_index in self.data:
- self.post_message_no_wait(DataTable.RowHighlighted(self, row_index))
+ is_valid_row = row_index < len(self._data)
+ if is_valid_row:
+ row_key = self._row_locations.get_key(row_index)
+ self.post_message_no_wait(
+ DataTable.RowHighlighted(self, row_index, row_key)
+ )
def _highlight_column(self, column_index: int) -> None:
"""Apply highlighting to the column at the given index, and post event."""
self.refresh_column(column_index)
if column_index < len(self.columns):
- self.post_message_no_wait(DataTable.ColumnHighlighted(self, column_index))
+ column_key = self._column_locations.get_key(column_index)
+ self.post_message_no_wait(
+ DataTable.ColumnHighlighted(self, column_index, column_key)
+ )
- def validate_cursor_cell(self, value: Coordinate) -> Coordinate:
- return self._clamp_cursor_cell(value)
+ def validate_cursor_coordinate(self, value: Coordinate) -> Coordinate:
+ return self._clamp_cursor_coordinate(value)
- def _clamp_cursor_cell(self, cursor_cell: Coordinate) -> Coordinate:
- row, column = cursor_cell
+ def _clamp_cursor_coordinate(self, coordinate: Coordinate) -> Coordinate:
+ """Clamp a coordinate such that it falls within the boundaries of the table."""
+ row, column = coordinate
row = clamp(row, 0, self.row_count - 1)
column = clamp(column, self.fixed_columns, len(self.columns) - 1)
return Coordinate(row, column)
@@ -485,53 +887,89 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
# Refresh cells that were previously impacted by the cursor
# but may no longer be.
- row_index, column_index = self.cursor_cell
if old == "cell":
- self.refresh_cell(row_index, column_index)
+ self.refresh_coordinate(self.cursor_coordinate)
elif old == "row":
+ row_index, _ = self.cursor_coordinate
self.refresh_row(row_index)
elif old == "column":
+ _, column_index = self.cursor_coordinate
self.refresh_column(column_index)
self._scroll_cursor_into_view()
def _highlight_cursor(self) -> None:
- row_index, column_index = self.cursor_cell
+ """Applies the appropriate highlighting and raises the appropriate
+ [Row|Column|Cell]Highlighted event for the given cursor coordinate
+ and cursor type."""
+ row_index, column_index = self.cursor_coordinate
cursor_type = self.cursor_type
# Apply the highlighting to the newly relevant cells
if cursor_type == "cell":
- self._highlight_cell(self.cursor_cell)
+ self._highlight_coordinate(self.cursor_coordinate)
elif cursor_type == "row":
self._highlight_row(row_index)
elif cursor_type == "column":
self._highlight_column(column_index)
- def _update_dimensions(self, new_rows: Iterable[int]) -> None:
+ def _update_column_widths(self, updated_cells: set[CellKey]) -> None:
+ """Update the widths of the columns based on the newly updated cell widths."""
+ for row_key, column_key in updated_cells:
+ column = self.columns.get(column_key)
+ if column is None:
+ continue
+ console = self.app.console
+ label_width = measure(console, column.label, 1)
+ content_width = column.content_width
+ cell_value = self._data[row_key][column_key]
+
+ new_content_width = measure(console, default_cell_formatter(cell_value), 1)
+
+ if new_content_width < content_width:
+ cells_in_column = self.get_column(column_key)
+ cell_widths = [
+ measure(console, default_cell_formatter(cell), 1)
+ for cell in cells_in_column
+ ]
+ column.content_width = max([*cell_widths, label_width])
+ else:
+ column.content_width = max(new_content_width, label_width)
+
+ def _update_dimensions(self, new_rows: Iterable[RowKey]) -> None:
"""Called to recalculate the virtual (scrollable) size."""
- for row_index in new_rows:
+ for row_key in new_rows:
+ row_index = self._row_locations.get(row_key)
+ if row_index is None:
+ continue
for column, renderable in zip(
- self.columns, self._get_row_renderables(row_index)
+ self.ordered_columns, self._get_row_renderables(row_index)
):
content_width = measure(self.app.console, renderable, 1)
column.content_width = max(column.content_width, content_width)
self._clear_caches()
- total_width = sum(column.render_width for column in self.columns)
+ total_width = sum(column.render_width for column in self.columns.values())
header_height = self.header_height if self.show_header else 0
self.virtual_size = Size(
total_width,
- len(self._y_offsets) + header_height,
+ self._total_row_height + header_height,
)
- def _get_cell_region(self, row_index: int, column_index: int) -> Region:
- """Get the region of the cell at the given coordinate (row_index, column_index)"""
- if row_index not in self.rows:
+ def _get_cell_region(self, coordinate: Coordinate) -> Region:
+ """Get the region of the cell at the given spatial coordinate."""
+ if not self.is_valid_coordinate(coordinate):
return Region(0, 0, 0, 0)
- row = self.rows[row_index]
- x = sum(column.render_width for column in self.columns[:column_index])
- width = self.columns[column_index].render_width
+
+ row_index, column_index = coordinate
+ row_key = self._row_locations.get_key(row_index)
+ row = self.rows[row_key]
+
+ # The x-coordinate of a cell is the sum of widths of cells to the left.
+ x = sum(column.render_width for column in self.ordered_columns[:column_index])
+ column_key = self._column_locations.get_key(column_index)
+ width = self.columns[column_key].render_width
height = row.height
- y = row.y
+ y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index])
if self.show_header:
y += self.header_height
cell_region = Region(x, y, width, height)
@@ -539,12 +977,14 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _get_row_region(self, row_index: int) -> Region:
"""Get the region of the row at the given index."""
- rows = self.rows
- if row_index < 0 or row_index >= len(rows):
+ if not self.is_valid_row_index(row_index):
return Region(0, 0, 0, 0)
- row = rows[row_index]
- row_width = sum(column.render_width for column in self.columns)
- y = row.y
+
+ rows = self.rows
+ row_key = self._row_locations.get_key(row_index)
+ row = rows[row_key]
+ row_width = sum(column.render_width for column in self.columns.values())
+ y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index])
if self.show_header:
y += self.header_height
row_region = Region(0, y, row_width, row.height)
@@ -552,14 +992,15 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _get_column_region(self, column_index: int) -> Region:
"""Get the region of the column at the given index."""
- columns = self.columns
- if column_index < 0 or column_index >= len(columns):
+ if not self.is_valid_column_index(column_index):
return Region(0, 0, 0, 0)
- x = sum(column.render_width for column in self.columns[:column_index])
- width = columns[column_index].render_width
+ columns = self.columns
+ x = sum(column.render_width for column in self.ordered_columns[:column_index])
+ column_key = self._column_locations.get_key(column_index)
+ width = columns[column_key].render_width
header_height = self.header_height if self.show_header else 0
- height = len(self._y_offsets) + header_height
+ height = self._total_row_height + header_height
full_column_region = Region(x, 0, width, height)
return full_column_region
@@ -569,77 +1010,97 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
Args:
columns: Also clear the columns. Defaults to False.
"""
- self.row_count = 0
self._clear_caches()
self._y_offsets.clear()
- self.data.clear()
+ self._data.clear()
self.rows.clear()
if columns:
self.columns.clear()
- self._line_no = 0
self._require_update_dimensions = True
- self.cursor_cell = Coordinate(0, 0)
- self.hover_cell = Coordinate(0, 0)
+ self.cursor_coordinate = Coordinate(0, 0)
+ self.hover_coordinate = Coordinate(0, 0)
self.refresh()
- def add_columns(self, *labels: TextType) -> None:
- """Add a number of columns.
-
- Args:
- *labels: Column headers.
- """
- for label in labels:
- self.add_column(label, width=None)
-
- def add_column(self, label: TextType, *, width: int | None = None) -> None:
+ def add_column(
+ self, label: TextType, *, width: int | None = None, key: str | None = None
+ ) -> ColumnKey:
"""Add a column to the table.
Args:
label: A str or Text object containing the label (shown top of column).
- width: Width of the column in cells or None to fit content. Defaults to None.
- """
- text_label = Text.from_markup(label) if isinstance(label, str) else label
+ width: Width of the column in cells or None to fit content.
+ key: A key which uniquely identifies this column.
+ If None, it will be generated for you.
- content_width = measure(self.app.console, text_label, 1)
+ Returns:
+ Uniquely identifies this column. Can be used to retrieve this column
+ regardless of its current location in the DataTable (it could have moved
+ after being added due to sorting/insertion/deletion of other columns).
+ """
+ column_key = ColumnKey(key)
+ if column_key in self._column_locations:
+ raise DuplicateKey(f"The column key {key!r} already exists.")
+ column_index = len(self.columns)
+ label = Text.from_markup(label) if isinstance(label, str) else label
+ content_width = measure(self.app.console, label, 1)
if width is None:
column = Column(
- text_label,
+ column_key,
+ label,
content_width,
- index=len(self.columns),
content_width=content_width,
auto_width=True,
)
else:
column = Column(
- text_label, width, content_width=content_width, index=len(self.columns)
+ column_key,
+ label,
+ width,
+ content_width=content_width,
)
- self.columns.append(column)
+ self.columns[column_key] = column
+ self._column_locations[column_key] = column_index
self._require_update_dimensions = True
self.check_idle()
- def add_row(self, *cells: CellType, height: int = 1) -> None:
- """Add a row.
+ return column_key
+
+ def add_row(
+ self, *cells: CellType, height: int = 1, key: str | None = None
+ ) -> RowKey:
+ """Add a row at the bottom of the DataTable.
Args:
*cells: Positional arguments should contain cell data.
- height: The height of a row (in lines). Defaults to 1.
+ height: The height of a row (in lines).
+ key: A key which uniquely identifies this row. If None, it will be generated
+ for you and returned.
+
+ Returns:
+ Uniquely identifies this row. Can be used to retrieve this row regardless
+ of its current location in the DataTable (it could have moved after
+ being added due to sorting or insertion/deletion of other rows).
"""
+ row_key = RowKey(key)
+ if row_key in self._row_locations:
+ raise DuplicateKey(f"The row key {row_key!r} already exists.")
+
+ # TODO: If there are no columns: do we generate them here?
+ # If we don't do this, users will be required to call add_column(s)
+ # Before they call add_row.
+
row_index = self.row_count
-
- self.data[row_index] = list(cells)
- self.rows[row_index] = Row(row_index, height, self._line_no)
-
- for line_no in range(height):
- self._y_offsets.append((row_index, line_no))
-
- self.row_count += 1
- self._line_no += height
-
- self._new_rows.add(row_index)
+ # Map the key of this row to its current index
+ self._row_locations[row_key] = row_index
+ self._data[row_key] = {
+ column.key: cell
+ for column, cell in zip_longest(self.ordered_columns, cells)
+ }
+ self.rows[row_key] = Row(row_key, height)
+ self._new_rows.add(row_key)
self._require_update_dimensions = True
- self.cursor_cell = self.cursor_cell
- self.check_idle()
+ self.cursor_coordinate = self.cursor_coordinate
# If a position has opened for the cursor to appear, where it previously
# could not (e.g. when there's no data in the table), then a highlighted
@@ -650,33 +1111,74 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
if cell_now_available and visible_cursor:
self._highlight_cursor()
- def add_rows(self, rows: Iterable[Iterable[CellType]]) -> None:
- """Add a number of rows.
+ self.check_idle()
+ return row_key
+
+ def add_columns(self, *labels: TextType) -> list[ColumnKey]:
+ """Add a number of columns.
+
+ Args:
+ *labels: Column headers.
+
+ Returns:
+ A list of the keys for the columns that were added. See
+ the `add_column` method docstring for more information on how
+ these keys are used.
+ """
+ column_keys = []
+ for label in labels:
+ column_key = self.add_column(label, width=None)
+ column_keys.append(column_key)
+ return column_keys
+
+ def add_rows(self, rows: Iterable[Iterable[CellType]]) -> list[RowKey]:
+ """Add a number of rows at the bottom of the DataTable.
Args:
rows: Iterable of rows. A row is an iterable of cells.
+
+ Returns:
+ A list of the keys for the rows that were added. See
+ the `add_row` method docstring for more information on how
+ these keys are used.
"""
+ row_keys = []
for row in rows:
- self.add_row(*row)
+ row_key = self.add_row(*row)
+ row_keys.append(row_key)
+ return row_keys
def on_idle(self) -> None:
+ """Runs when the message pump is empty.
+
+ We use this for some expensive calculations like re-computing dimensions of the
+ whole DataTable and re-computing column widths after some cells
+ have been updated. This is more efficient in the case of high
+ frequency updates, ensuring we only do expensive computations once."""
if self._require_update_dimensions:
+ # Add the new rows *before* updating the column widths, since
+ # cells in a new row may influence the final width of a column
self._require_update_dimensions = False
new_rows = self._new_rows.copy()
self._new_rows.clear()
self._update_dimensions(new_rows)
- self.refresh()
- def refresh_cell(self, row_index: int, column_index: int) -> None:
- """Refresh a cell.
+ if self._updated_cells:
+ # Cell contents have already been updated at this point.
+ # Now we only need to worry about measuring column widths.
+ updated_columns = self._updated_cells.copy()
+ self._updated_cells.clear()
+ self._update_column_widths(updated_columns)
+
+ def refresh_coordinate(self, coordinate: Coordinate) -> None:
+ """Refresh the cell at a coordinate.
Args:
- row_index: Row index.
- column_index: Column index.
+ coordinate: The coordinate to refresh.
"""
- if row_index < 0 or column_index < 0:
+ if not self.is_valid_coordinate(coordinate):
return
- region = self._get_cell_region(row_index, column_index)
+ region = self._get_cell_region(coordinate)
self._refresh_region(region)
def refresh_row(self, row_index: int) -> None:
@@ -685,7 +1187,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
Args:
row_index: The index of the row to refresh.
"""
- if row_index < 0 or row_index >= len(self.rows):
+ if not self.is_valid_row_index(row_index):
return
region = self._get_row_region(row_index)
@@ -697,7 +1199,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
Args:
column_index: The index of the column to refresh.
"""
- if column_index < 0 or column_index >= len(self.columns):
+ if not self.is_valid_column_index(column_index):
return
region = self._get_column_region(column_index)
@@ -712,8 +1214,72 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
region = region.translate(-self.scroll_offset)
self.refresh(region)
+ def is_valid_row_index(self, row_index: int) -> bool:
+ """Return a boolean indicating whether the row_index is within table bounds.
+
+ Args:
+ row_index: The row index to check.
+
+ Returns:
+ True if the row index is within the bounds of the table.
+ """
+ return 0 <= row_index < len(self.rows)
+
+ def is_valid_column_index(self, column_index: int) -> bool:
+ """Return a boolean indicating whether the column_index is within table bounds.
+
+ Args:
+ column_index: The column index to check.
+
+ Returns:
+ True if the column index is within the bounds of the table.
+ """
+ return 0 <= column_index < len(self.columns)
+
+ def is_valid_coordinate(self, coordinate: Coordinate) -> bool:
+ """Return a boolean indicating whether the given coordinate is valid.
+
+ Args:
+ coordinate: The coordinate to validate.
+
+ Returns:
+ True if the coordinate is within the bounds of the table.
+ """
+ row_index, column_index = coordinate
+ return self.is_valid_row_index(row_index) and self.is_valid_column_index(
+ column_index
+ )
+
+ @property
+ def ordered_columns(self) -> list[Column]:
+ """The list of Columns in the DataTable, ordered as they appear on screen."""
+ column_indices = range(len(self.columns))
+ column_keys = [
+ self._column_locations.get_key(index) for index in column_indices
+ ]
+ ordered_columns = [self.columns[key] for key in column_keys]
+ return ordered_columns
+
+ @property
+ def ordered_rows(self) -> list[Row]:
+ """The list of Rows in the DataTable, ordered as they appear on screen."""
+ num_rows = self.row_count
+ update_count = self._update_count
+ cache_key = (num_rows, update_count)
+ if cache_key in self._ordered_row_cache:
+ ordered_rows = self._ordered_row_cache[cache_key]
+ else:
+ row_indices = range(num_rows)
+ ordered_rows = []
+ for row_index in row_indices:
+ row_key = self._row_locations.get_key(row_index)
+ row = self.rows[row_key]
+ ordered_rows.append(row)
+ self._ordered_row_cache[cache_key] = ordered_rows
+ return ordered_rows
+
def _get_row_renderables(self, row_index: int) -> list[RenderableType]:
- """Get renderables for the given row.
+ """Get renderables for the row currently at the given row index.
Args:
row_index: Index of the row.
@@ -721,20 +1287,17 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
Returns:
List of renderables
"""
-
+ ordered_columns = self.ordered_columns
if row_index == -1:
- row = [column.label for column in self.columns]
+ row: list[RenderableType] = [column.label for column in ordered_columns]
return row
- data = self.data.get(row_index)
+ ordered_row = self.get_row_at(row_index)
empty = Text()
- if data is None:
- return [empty for _ in self.columns]
- else:
- return [
- Text() if datum is None else default_cell_formatter(datum) or empty
- for datum, _ in zip_longest(data, range(len(self.columns)))
- ]
+ return [
+ Text() if datum is None else default_cell_formatter(datum) or empty
+ for datum, _ in zip_longest(ordered_row, range(len(self.columns)))
+ ]
def _render_cell(
self,
@@ -767,9 +1330,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
if hover and show_cursor and self._show_hover_cursor:
style += self.get_component_styles("datatable--highlight").rich_style
if is_fixed_style:
- # Apply subtle variation in style for the fixed (blue background by default)
- # rows and columns affected by the cursor, to ensure we can still differentiate
- # between the labels and the data.
+ # Apply subtle variation in style for the fixed (blue background by
+ # default) rows and columns affected by the cursor, to ensure we can
+ # still differentiate between the labels and the data.
style += self.get_component_styles(
"datatable--highlight-fixed"
).rich_style
@@ -779,34 +1342,39 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
if is_fixed_style:
style += self.get_component_styles("datatable--cursor-fixed").rich_style
- cell_key = (row_index, column_index, style, cursor, hover)
- if cell_key not in self._cell_render_cache:
+ if is_header_row:
+ row_key = self._header_row_key
+ else:
+ row_key = self._row_locations.get_key(row_index)
+
+ column_key = self._column_locations.get_key(column_index)
+ cell_cache_key = (row_key, column_key, style, cursor, hover, self._update_count)
+ if cell_cache_key not in self._cell_render_cache:
style += Style.from_meta({"row": row_index, "column": column_index})
- height = (
- self.header_height if is_header_row else self.rows[row_index].height
- )
+ height = self.header_height if is_header_row else self.rows[row_key].height
cell = self._get_row_renderables(row_index)[column_index]
lines = self.app.console.render_lines(
Padding(cell, (0, 1)),
self.app.console.options.update_dimensions(width, height),
style=style,
)
- self._cell_render_cache[cell_key] = lines
- return self._cell_render_cache[cell_key]
+ self._cell_render_cache[cell_cache_key] = lines
+ return self._cell_render_cache[cell_cache_key]
- def _render_row(
+ def _render_line_in_row(
self,
- row_index: int,
+ row_key: RowKey,
line_no: int,
base_style: Style,
cursor_location: Coordinate,
hover_location: Coordinate,
) -> tuple[SegmentLines, SegmentLines]:
- """Render a row in to lines for each cell.
+ """Render a single line from a row in the DataTable.
Args:
- row_index: Index of the row.
- line_no: Line number (on screen, 0 is top)
+ row_key: The identifying key for this row.
+ line_no: Line number (y-coordinate) within row. 0 is the first strip of
+ cells in the row, line_no=1 is the next line in the row, and so on...
base_style: Base style of row.
cursor_location: The location of the cursor in the DataTable.
hover_location: The location of the hover cursor in the DataTable.
@@ -816,8 +1384,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
"""
cursor_type = self.cursor_type
show_cursor = self.show_cursor
+
cache_key = (
- row_index,
+ row_key,
line_no,
base_style,
cursor_location,
@@ -825,43 +1394,50 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
cursor_type,
show_cursor,
self._show_hover_cursor,
+ self._update_count,
)
if cache_key in self._row_render_cache:
return self._row_render_cache[cache_key]
- render_cell = self._render_cell
-
def _should_highlight(
- cursor_location: Coordinate,
- cell_location: Coordinate,
- cursor_type: CursorType,
+ cursor: Coordinate,
+ target_cell: Coordinate,
+ type_of_cursor: CursorType,
) -> bool:
"""Determine whether we should highlight a cell given the location
of the cursor, the location of the cell, and the type of cursor that
is currently active."""
- if cursor_type == "cell":
- return cursor_location == cell_location
- elif cursor_type == "row":
- cursor_row, _ = cursor_location
- cell_row, _ = cell_location
+ if type_of_cursor == "cell":
+ return cursor == target_cell
+ elif type_of_cursor == "row":
+ cursor_row, _ = cursor
+ cell_row, _ = target_cell
return cursor_row == cell_row
- elif cursor_type == "column":
- _, cursor_column = cursor_location
- _, cell_column = cell_location
+ elif type_of_cursor == "column":
+ _, cursor_column = cursor
+ _, cell_column = target_cell
return cursor_column == cell_column
else:
return False
+ if row_key in self._row_locations:
+ row_index = self._row_locations.get(row_key)
+ else:
+ row_index = -1
+
+ render_cell = self._render_cell
if self.fixed_columns:
fixed_style = self.get_component_styles("datatable--fixed").rich_style
fixed_style += Style.from_meta({"fixed": True})
fixed_row = []
- for column in self.columns[: self.fixed_columns]:
- cell_location = Coordinate(row_index, column.index)
+ for column_index, column in enumerate(
+ self.ordered_columns[: self.fixed_columns]
+ ):
+ cell_location = Coordinate(row_index, column_index)
fixed_cell_lines = render_cell(
row_index,
- column.index,
+ column_index,
fixed_style,
column.render_width,
cursor=_should_highlight(
@@ -873,7 +1449,8 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
else:
fixed_row = []
- if row_index == -1:
+ is_header_row = row_key is self._header_row_key
+ if is_header_row:
row_style = self.get_component_styles("datatable--header").rich_style
else:
if self.zebra_stripes:
@@ -885,11 +1462,11 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
row_style = base_style
scrollable_row = []
- for column in self.columns:
- cell_location = Coordinate(row_index, column.index)
+ for column_index, column in enumerate(self.ordered_columns):
+ cell_location = Coordinate(row_index, column_index)
cell_lines = render_cell(
row_index,
- column.index,
+ column_index,
row_style,
column.render_width,
cursor=_should_highlight(cursor_location, cell_location, cursor_type),
@@ -901,25 +1478,29 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._row_render_cache[cache_key] = row_pair
return row_pair
- def _get_offsets(self, y: int) -> tuple[int, int]:
- """Get row number and line offset for a given line.
+ def _get_offsets(self, y: int) -> tuple[RowKey, int]:
+ """Get row key and line offset for a given line.
Args:
- y: Y coordinate relative to screen top.
+ y: Y coordinate relative to DataTable top.
Returns:
- Line number and line offset within cell.
+ Row key and line (y) offset within cell.
"""
+ header_height = self.header_height
+ y_offsets = self._y_offsets
if self.show_header:
- if y < self.header_height:
- return (-1, y)
- y -= self.header_height
- if y > len(self._y_offsets):
+ if y < header_height:
+ return self._header_row_key, y
+ y -= header_height
+ if y > len(y_offsets):
raise LookupError("Y coord {y!r} is greater than total height")
- return self._y_offsets[y]
+
+ return y_offsets[y]
def _render_line(self, y: int, x1: int, x2: int, base_style: Style) -> Strip:
- """Render a line in to a list of segments.
+ """Render a (possibly cropped) line in to a Strip (a list of segments
+ representing a horizontal line).
Args:
y: Y coordinate of line
@@ -928,13 +1509,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
base_style: Style to apply to line.
Returns:
- List of segments for rendering.
+ The Strip which represents this cropped line.
"""
width = self.size.width
try:
- row_index, line_no = self._get_offsets(y)
+ row_key, y_offset_in_row = self._get_offsets(y)
except LookupError:
return Strip.blank(width, base_style)
@@ -943,24 +1524,25 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
x1,
x2,
width,
- self.cursor_cell,
- self.hover_cell,
+ self.cursor_coordinate,
+ self.hover_coordinate,
base_style,
self.cursor_type,
self._show_hover_cursor,
+ self._update_count,
)
if cache_key in self._line_cache:
return self._line_cache[cache_key]
- fixed, scrollable = self._render_row(
- row_index,
- line_no,
+ fixed, scrollable = self._render_line_in_row(
+ row_key,
+ y_offset_in_row,
base_style,
- cursor_location=self.cursor_cell,
- hover_location=self.hover_cell,
+ cursor_location=self.cursor_coordinate,
+ hover_location=self.hover_coordinate,
)
fixed_width = sum(
- column.render_width for column in self.columns[: self.fixed_columns]
+ column.render_width for column in self.ordered_columns[: self.fixed_columns]
)
fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else []
@@ -975,37 +1557,80 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def render_line(self, y: int) -> Strip:
width, height = self.size
scroll_x, scroll_y = self.scroll_offset
- fixed_top_row_count = sum(
- self.get_row_height(row_index) for row_index in range(self.fixed_rows)
+
+ fixed_row_keys: list[RowKey] = [
+ self._row_locations.get_key(row_index)
+ for row_index in range(self.fixed_rows)
+ ]
+
+ fixed_rows_height = sum(
+ self.get_row_height(row_key) for row_key in fixed_row_keys
)
if self.show_header:
- fixed_top_row_count += self.get_row_height(-1)
+ fixed_rows_height += self.get_row_height(self._header_row_key)
- if y >= fixed_top_row_count:
+ if y >= fixed_rows_height:
y += scroll_y
return self._render_line(y, scroll_x, scroll_x + width, self.rich_style)
def on_mouse_move(self, event: events.MouseMove):
+ """If the hover cursor is visible, display it by extracting the row
+ and column metadata from the segments present in the cells."""
self._set_hover_cursor(True)
meta = event.style.meta
if meta and self.show_cursor and self.cursor_type != "none":
try:
- self.hover_cell = Coordinate(meta["row"], meta["column"])
+ self.hover_coordinate = Coordinate(meta["row"], meta["column"])
except KeyError:
pass
def _get_fixed_offset(self) -> Spacing:
+ """Calculate the "fixed offset", that is the space to the top and left
+ that is occupied by fixed rows and columns respectively. Fixed rows and columns
+ are rows and columns that do not participate in scrolling."""
top = self.header_height if self.show_header else 0
top += sum(
- self.rows[row_index].height
+ self.rows[self._row_locations.get_key(row_index)].height
for row_index in range(self.fixed_rows)
if row_index in self.rows
)
- left = sum(column.render_width for column in self.columns[: self.fixed_columns])
+ left = sum(
+ column.render_width for column in self.ordered_columns[: self.fixed_columns]
+ )
return Spacing(top, 0, 0, left)
+ def sort(
+ self,
+ *columns: ColumnKey | str,
+ reverse: bool = False,
+ ) -> None:
+ """Sort the rows in the DataTable by one or more column keys.
+
+ Args:
+ columns: One or more columns to sort by the values in.
+ reverse: If True, the sort order will be reversed.
+ """
+
+ def sort_by_column_keys(
+ row: tuple[RowKey, dict[ColumnKey | str, CellType]]
+ ) -> Any:
+ _, row_data = row
+ result = itemgetter(*columns)(row_data)
+ return result
+
+ ordered_rows = sorted(
+ self._data.items(), key=sort_by_column_keys, reverse=reverse
+ )
+ self._row_locations = TwoWayDict(
+ {key: new_index for new_index, (key, _) in enumerate(ordered_rows)}
+ )
+ self._update_count += 1
+ self.refresh()
+
def _scroll_cursor_into_view(self, animate: bool = False) -> None:
+ """When the cursor is at a boundary of the DataTable and moves out
+ of view, this method handles scrolling to ensure it remains visible."""
fixed_offset = self._get_fixed_offset()
top, _, _, left = fixed_offset
@@ -1016,7 +1641,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
x, y, width, height = self._get_column_region(self.cursor_column)
region = Region(x, int(self.scroll_y) + top, width, height - top)
else:
- region = self._get_cell_region(self.cursor_row, self.cursor_column)
+ region = self._get_cell_region(self.cursor_coordinate)
self.scroll_to_region(region, animate=animate, spacing=fixed_offset)
@@ -1036,24 +1661,36 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
elif cursor_type == "row":
self.refresh_row(self.hover_row)
elif cursor_type == "cell":
- self.refresh_cell(*self.hover_cell)
+ self.refresh_coordinate(self.hover_coordinate)
def on_click(self, event: events.Click) -> None:
self._set_hover_cursor(True)
- if self.show_cursor and self.cursor_type != "none":
+ meta = self.get_style_at(event.x, event.y).meta
+ if not meta:
+ return
+
+ row_index = meta["row"]
+ column_index = meta["column"]
+ is_header_click = self.show_header and row_index == -1
+ if is_header_click:
+ # Header clicks work even if cursor is off, and doesn't move the cursor.
+ column = self.ordered_columns[column_index]
+ message = DataTable.HeaderSelected(
+ self, column.key, column_index, label=column.label
+ )
+ self.post_message_no_wait(message)
+ elif self.show_cursor and self.cursor_type != "none":
# Only post selection events if there is a visible row/col/cell cursor.
+ self.cursor_coordinate = Coordinate(row_index, column_index)
self._post_selected_message()
- meta = self.get_style_at(event.x, event.y).meta
- if meta:
- self.cursor_cell = Coordinate(meta["row"], meta["column"])
- self._scroll_cursor_into_view(animate=True)
- event.stop()
+ self._scroll_cursor_into_view(animate=True)
+ event.stop()
def action_cursor_up(self) -> None:
self._set_hover_cursor(False)
cursor_type = self.cursor_type
if self.show_cursor and (cursor_type == "cell" or cursor_type == "row"):
- self.cursor_cell = self.cursor_cell.up()
+ self.cursor_coordinate = self.cursor_coordinate.up()
self._scroll_cursor_into_view()
else:
# If the cursor doesn't move up (e.g. column cursor can't go up),
@@ -1064,7 +1701,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._set_hover_cursor(False)
cursor_type = self.cursor_type
if self.show_cursor and (cursor_type == "cell" or cursor_type == "row"):
- self.cursor_cell = self.cursor_cell.down()
+ self.cursor_coordinate = self.cursor_coordinate.down()
self._scroll_cursor_into_view()
else:
super().action_scroll_down()
@@ -1073,7 +1710,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._set_hover_cursor(False)
cursor_type = self.cursor_type
if self.show_cursor and (cursor_type == "cell" or cursor_type == "column"):
- self.cursor_cell = self.cursor_cell.right()
+ self.cursor_coordinate = self.cursor_coordinate.right()
self._scroll_cursor_into_view(animate=True)
else:
super().action_scroll_right()
@@ -1082,7 +1719,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
self._set_hover_cursor(False)
cursor_type = self.cursor_type
if self.show_cursor and (cursor_type == "cell" or cursor_type == "column"):
- self.cursor_cell = self.cursor_cell.left()
+ self.cursor_coordinate = self.cursor_coordinate.left()
self._scroll_cursor_into_view(animate=True)
else:
super().action_scroll_left()
@@ -1094,19 +1731,25 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def _post_selected_message(self):
"""Post the appropriate message for a selection based on the `cursor_type`."""
- cursor_cell = self.cursor_cell
+ cursor_coordinate = self.cursor_coordinate
cursor_type = self.cursor_type
+ cell_key = self.coordinate_to_cell_key(cursor_coordinate)
if cursor_type == "cell":
self.post_message_no_wait(
DataTable.CellSelected(
self,
- self.get_cell_value(cursor_cell),
- cursor_cell,
+ self.get_cell_at(cursor_coordinate),
+ coordinate=cursor_coordinate,
+ cell_key=cell_key,
)
)
elif cursor_type == "row":
- row, _ = cursor_cell
- self.post_message_no_wait(DataTable.RowSelected(self, row))
+ row_index, _ = cursor_coordinate
+ row_key, _ = cell_key
+ self.post_message_no_wait(DataTable.RowSelected(self, row_index, row_key))
elif cursor_type == "column":
- _, column = cursor_cell
- self.post_message_no_wait(DataTable.ColumnSelected(self, column))
+ _, column_index = cursor_coordinate
+ _, column_key = cell_key
+ self.post_message_no_wait(
+ DataTable.ColumnSelected(self, column_index, column_key)
+ )
diff --git a/src/textual/widgets/data_table.py b/src/textual/widgets/data_table.py
index d0316f387..0bb18f87f 100644
--- a/src/textual/widgets/data_table.py
+++ b/src/textual/widgets/data_table.py
@@ -1,5 +1,29 @@
"""Make non-widget DataTable support classes available."""
-from ._data_table import Column, Row
+from ._data_table import (
+ CellDoesNotExist,
+ CellKey,
+ CellType,
+ Column,
+ ColumnDoesNotExist,
+ ColumnKey,
+ CursorType,
+ DuplicateKey,
+ Row,
+ RowDoesNotExist,
+ RowKey,
+)
-__all__ = ["Column", "Row"]
+__all__ = [
+ "CellDoesNotExist",
+ "CellKey",
+ "CellType",
+ "Column",
+ "ColumnDoesNotExist",
+ "ColumnKey",
+ "CursorType",
+ "DuplicateKey",
+ "Row",
+ "RowDoesNotExist",
+ "RowKey",
+]
diff --git a/tests/snapshot_tests/__snapshots__/test_snapshots.ambr b/tests/snapshot_tests/__snapshots__/test_snapshots.ambr
index 3d9516f10..6fd96bfe6 100644
--- a/tests/snapshot_tests/__snapshots__/test_snapshots.ambr
+++ b/tests/snapshot_tests/__snapshots__/test_snapshots.ambr
@@ -10173,133 +10173,133 @@
font-weight: 700;
}
- .terminal-121683423-matrix {
+ .terminal-1288566407-matrix {
font-family: Fira Code, monospace;
font-size: 20px;
line-height: 24.4px;
font-variant-east-asian: full-width;
}
- .terminal-121683423-title {
+ .terminal-1288566407-title {
font-size: 18px;
font-weight: bold;
font-family: arial;
}
- .terminal-121683423-r1 { fill: #dde6ed;font-weight: bold }
- .terminal-121683423-r2 { fill: #e1e1e1 }
- .terminal-121683423-r3 { fill: #c5c8c6 }
- .terminal-121683423-r4 { fill: #211505 }
+ .terminal-1288566407-r1 { fill: #dde6ed;font-weight: bold }
+ .terminal-1288566407-r2 { fill: #e1e1e1 }
+ .terminal-1288566407-r3 { fill: #c5c8c6 }
+ .terminal-1288566407-r4 { fill: #211505 }
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
- TableApp
+ TableApp
-
-
-
- lane swimmer country time
- 4 Joseph Schooling Singapore 50.39
- 2 Michael Phelps United States 51.14
- 5 Chad le Clos South Africa 51.14
- 6 László Cseh Hungary 51.14
- 3 Li Zhuhao China 51.26
- 8 Mehdy Metella France 51.58
- 7 Tom Shields United States 51.73
- 1 Aleksandr Sadovnikov Russia 51.84
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+ lane swimmer country time
+ 4 Joseph Schooling Singapore 50.39
+ 2 Michael Phelps United States 51.14
+ 5 Chad le Clos South Africa 51.14
+ 6 László Cseh Hungary 51.14
+ 3 Li Zhuhao China 51.26
+ 8 Mehdy Metella France 51.58
+ 7 Tom Shields United States 51.73
+ 1 Aleksandr Sadovnikov Russia 51.84
+ 10 Darren Burns Scotland 51.84
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -10464,6 +10464,163 @@
'''
# ---
+# name: test_datatable_sort_multikey
+ '''
+
+
+ '''
+# ---
# name: test_demo
'''