DataTable sort by function (or other callable) (#3090)

* DataTable sort by function (or other callable)

The `DataTable` widget now takes the `by` argument instead of `columns`, allowing the table to also be sorted using a custom function (or other callable). This is a breaking change since it requires all calls to the `sort` method to include an iterable of key(s) (or a singular function/callable). Covers #2261 using [suggested function signature](https://github.com/Textualize/textual/pull/2512#issuecomment-1580277771) from @darrenburns on PR #2512.

* argument change and functionaloty update

Changed back to orinal `columns` argument and added a new `key` argument
which takes a function (or other callable). This allows the PR to NOT BE
a breaking change.

* better example for docs

- Updated the example file for the docs to better show the functionality
of the change (especially when using `columns` and `key` together).
- Added one new tests to cover a similar situation to the example
  changes

* removed unecessary code from example

- the sort by clicked column function was bloat in my opinion

* requested changes

* simplify method and terminology

* combine key_wrapper and default sort

* Removing some tests from DataTable.sort as duplicates. Ensure there is test coverage of the case where a key, but no columns, is passed to DataTable.sort.

* Remove unused import

* Fix merge issues in CHANGELOG, update DataTable sort-by-key changelog PR link

---------

Co-authored-by: Darren Burns <darrenburns@users.noreply.github.com>
Co-authored-by: Darren Burns <darrenb900@gmail.com>
This commit is contained in:
Josh Duncan
2023-10-31 09:14:47 -04:00
committed by GitHub
parent 665dca9cb8
commit 4f95d30619
5 changed files with 225 additions and 23 deletions

View File

@@ -27,6 +27,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Add Document `get_index_from_location` / `get_location_from_index` https://github.com/Textualize/textual/pull/3410
- Add setter for `TextArea.text` https://github.com/Textualize/textual/discussions/3525
- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable) https://github.com/Textualize/textual/pull/3090
- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566
- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571
- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498
### Changed
@@ -49,15 +54,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Improved startup time by caching CSS parsing https://github.com/Textualize/textual/pull/3575
- Workers are now created/run in a thread-safe way https://github.com/Textualize/textual/pull/3586
### Added
- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566
- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571
### Added
- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498
## [0.40.0] - 2023-10-11
### Added
@@ -251,7 +247,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- DescendantBlur and DescendantFocus can now be used with @on decorator
## [0.32.0] - 2023-08-03
### Added

View File

@@ -0,0 +1,92 @@
from rich.text import Text
from textual.app import App, ComposeResult
from textual.widgets import DataTable, Footer
ROWS = [
("lane", "swimmer", "country", "time 1", "time 2"),
(4, "Joseph Schooling", Text("Singapore", style="italic"), 50.39, 51.84),
(2, "Michael Phelps", Text("United States", style="italic"), 50.39, 51.84),
(5, "Chad le Clos", Text("South Africa", style="italic"), 51.14, 51.73),
(6, "László Cseh", Text("Hungary", style="italic"), 51.14, 51.58),
(3, "Li Zhuhao", Text("China", style="italic"), 51.26, 51.26),
(8, "Mehdy Metella", Text("France", style="italic"), 51.58, 52.15),
(7, "Tom Shields", Text("United States", style="italic"), 51.73, 51.12),
(1, "Aleksandr Sadovnikov", Text("Russia", style="italic"), 51.84, 50.85),
(10, "Darren Burns", Text("Scotland", style="italic"), 51.84, 51.55),
]
class TableApp(App):
BINDINGS = [
("a", "sort_by_average_time", "Sort By Average Time"),
("n", "sort_by_last_name", "Sort By Last Name"),
("c", "sort_by_country", "Sort By Country"),
("d", "sort_by_columns", "Sort By Columns (Only)"),
]
current_sorts: set = set()
def compose(self) -> ComposeResult:
yield DataTable()
yield Footer()
def on_mount(self) -> None:
table = self.query_one(DataTable)
for col in ROWS[0]:
table.add_column(col, key=col)
table.add_rows(ROWS[1:])
def sort_reverse(self, sort_type: str):
"""Determine if `sort_type` is ascending or descending."""
reverse = sort_type in self.current_sorts
if reverse:
self.current_sorts.remove(sort_type)
else:
self.current_sorts.add(sort_type)
return reverse
def action_sort_by_average_time(self) -> None:
"""Sort DataTable by average of times (via a function) and
passing of column data through positional arguments."""
def sort_by_average_time_then_last_name(row_data):
name, *scores = row_data
return (sum(scores) / len(scores), name.split()[-1])
table = self.query_one(DataTable)
table.sort(
"swimmer",
"time 1",
"time 2",
key=sort_by_average_time_then_last_name,
reverse=self.sort_reverse("time"),
)
def action_sort_by_last_name(self) -> None:
"""Sort DataTable by last name of swimmer (via a lambda)."""
table = self.query_one(DataTable)
table.sort(
"swimmer",
key=lambda swimmer: swimmer.split()[-1],
reverse=self.sort_reverse("swimmer"),
)
def action_sort_by_country(self) -> None:
"""Sort DataTable by country which is a `Rich.Text` object."""
table = self.query_one(DataTable)
table.sort(
"country",
key=lambda country: country.plain,
reverse=self.sort_reverse("country"),
)
def action_sort_by_columns(self) -> None:
"""Sort DataTable without a key."""
table = self.query_one(DataTable)
table.sort("swimmer", "lane", reverse=self.sort_reverse("columns"))
app = TableApp()
if __name__ == "__main__":
app.run()

View File

@@ -143,11 +143,22 @@ visible as you scroll through the data table.
### Sorting
The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method.
In order to sort your data by a column, you must have supplied a `key` to the `add_column` method
when you added it.
You can then pass this key to the `sort` method to sort by that column.
Additionally, you can sort by multiple columns by passing multiple keys to `sort`.
The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass one more column keys to the `sort` method to sort by one or more columns.
Additionally, you can sort your `DataTable` with a custom function (or other callable) via the `key` argument. Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and return a key to use for sorting purposes.
Providing both `columns` and `key` will limit the row information sent to your `key` function (or other callable) to only the columns specified.
=== "Output"
```{.textual path="docs/examples/widgets/data_table_sort.py"}
```
=== "data_table_sort.py"
```python
--8<-- "docs/examples/widgets/data_table_sort.py"
```
### Labelled rows

View File

@@ -4,7 +4,7 @@ import functools
from dataclasses import dataclass
from itertools import chain, zip_longest
from operator import itemgetter
from typing import Any, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast
from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast
import rich.repr
from rich.console import RenderableType
@@ -2348,30 +2348,40 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
def sort(
self,
*columns: ColumnKey | str,
key: Callable[[Any], Any] | None = None,
reverse: bool = False,
) -> Self:
"""Sort the rows in the `DataTable` by one or more column keys.
"""Sort the rows in the `DataTable` by one or more column keys or a
key function (or other callable). If both columns and a key function
are specified, only data from those columns will sent to the key function.
Args:
columns: One or more columns to sort by the values in.
key: A function (or other callable) that returns a key to
use for sorting purposes.
reverse: If True, the sort order will be reversed.
Returns:
The `DataTable` instance.
"""
def sort_by_column_keys(
row: tuple[RowKey, dict[ColumnKey | str, CellType]]
) -> Any:
def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any:
_, row_data = row
result = itemgetter(*columns)(row_data)
if columns:
result = itemgetter(*columns)(row_data)
else:
result = tuple(row_data.values())
if key is not None:
return key(result)
return result
ordered_rows = sorted(
self._data.items(), key=sort_by_column_keys, reverse=reverse
self._data.items(),
key=key_wrapper,
reverse=reverse,
)
self._row_locations = TwoWayDict(
{key: new_index for new_index, (key, _) in enumerate(ordered_rows)}
{row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)}
)
self._update_count += 1
self.refresh()

View File

@@ -1197,6 +1197,100 @@ async def test_unset_hover_highlight_when_no_table_cell_under_mouse():
assert not table._show_hover_cursor
async def test_sort_by_all_columns_no_key():
"""Test sorting a `DataTable` by all columns."""
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
a, b, c = table.add_columns("A", "B", "C")
table.add_row(1, 3, 8)
table.add_row(2, 9, 5)
table.add_row(1, 1, 9)
assert table.get_row_at(0) == [1, 3, 8]
assert table.get_row_at(1) == [2, 9, 5]
assert table.get_row_at(2) == [1, 1, 9]
table.sort()
assert table.get_row_at(0) == [1, 1, 9]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [2, 9, 5]
table.sort(reverse=True)
assert table.get_row_at(0) == [2, 9, 5]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [1, 1, 9]
async def test_sort_by_multiple_columns_no_key():
"""Test sorting a `DataTable` by multiple columns."""
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
a, b, c = table.add_columns("A", "B", "C")
table.add_row(1, 3, 8)
table.add_row(2, 9, 5)
table.add_row(1, 1, 9)
table.sort(a, b, c)
assert table.get_row_at(0) == [1, 1, 9]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [2, 9, 5]
table.sort(a, c, b)
assert table.get_row_at(0) == [1, 3, 8]
assert table.get_row_at(1) == [1, 1, 9]
assert table.get_row_at(2) == [2, 9, 5]
table.sort(c, a, b, reverse=True)
assert table.get_row_at(0) == [1, 1, 9]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [2, 9, 5]
table.sort(a, c)
assert table.get_row_at(0) == [1, 3, 8]
assert table.get_row_at(1) == [1, 1, 9]
assert table.get_row_at(2) == [2, 9, 5]
async def test_sort_by_function_sum():
"""Test sorting a `DataTable` using a custom sort function."""
def custom_sort(row_data):
return sum(row_data)
row_data = (
[1, 3, 8], # SUM=12
[2, 9, 5], # SUM=16
[1, 1, 9], # SUM=11
)
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
a, b, c = table.add_columns("A", "B", "C")
for i, row in enumerate(row_data):
table.add_row(*row)
# Sorting by all columns
table.sort(a, b, c, key=custom_sort)
sorted_row_data = sorted(row_data, key=sum)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row
# Passing a sort function but no columns also sorts by all columns
table.sort(key=custom_sort)
sorted_row_data = sorted(row_data, key=sum)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row
table.sort(a, b, c, key=custom_sort, reverse=True)
sorted_row_data = sorted(row_data, key=sum, reverse=True)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row
@pytest.mark.parametrize(
["cell", "height"],
[