mirror of
https://github.com/Textualize/textual.git
synced 2025-10-17 02:38:12 +03:00
DataTable sort by function (or other callable) (#3090)
* DataTable sort by function (or other callable) The `DataTable` widget now takes the `by` argument instead of `columns`, allowing the table to also be sorted using a custom function (or other callable). This is a breaking change since it requires all calls to the `sort` method to include an iterable of key(s) (or a singular function/callable). Covers #2261 using [suggested function signature](https://github.com/Textualize/textual/pull/2512#issuecomment-1580277771) from @darrenburns on PR #2512. * argument change and functionaloty update Changed back to orinal `columns` argument and added a new `key` argument which takes a function (or other callable). This allows the PR to NOT BE a breaking change. * better example for docs - Updated the example file for the docs to better show the functionality of the change (especially when using `columns` and `key` together). - Added one new tests to cover a similar situation to the example changes * removed unecessary code from example - the sort by clicked column function was bloat in my opinion * requested changes * simplify method and terminology * combine key_wrapper and default sort * Removing some tests from DataTable.sort as duplicates. Ensure there is test coverage of the case where a key, but no columns, is passed to DataTable.sort. * Remove unused import * Fix merge issues in CHANGELOG, update DataTable sort-by-key changelog PR link --------- Co-authored-by: Darren Burns <darrenburns@users.noreply.github.com> Co-authored-by: Darren Burns <darrenb900@gmail.com>
This commit is contained in:
15
CHANGELOG.md
15
CHANGELOG.md
@@ -27,6 +27,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
|
|||||||
|
|
||||||
- Add Document `get_index_from_location` / `get_location_from_index` https://github.com/Textualize/textual/pull/3410
|
- Add Document `get_index_from_location` / `get_location_from_index` https://github.com/Textualize/textual/pull/3410
|
||||||
- Add setter for `TextArea.text` https://github.com/Textualize/textual/discussions/3525
|
- Add setter for `TextArea.text` https://github.com/Textualize/textual/discussions/3525
|
||||||
|
- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable) https://github.com/Textualize/textual/pull/3090
|
||||||
|
- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566
|
||||||
|
- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571
|
||||||
|
- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498
|
||||||
|
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
@@ -49,15 +54,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
|
|||||||
- Improved startup time by caching CSS parsing https://github.com/Textualize/textual/pull/3575
|
- Improved startup time by caching CSS parsing https://github.com/Textualize/textual/pull/3575
|
||||||
- Workers are now created/run in a thread-safe way https://github.com/Textualize/textual/pull/3586
|
- Workers are now created/run in a thread-safe way https://github.com/Textualize/textual/pull/3586
|
||||||
|
|
||||||
### Added
|
|
||||||
|
|
||||||
- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566
|
|
||||||
- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571
|
|
||||||
|
|
||||||
### Added
|
|
||||||
|
|
||||||
- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498
|
|
||||||
|
|
||||||
## [0.40.0] - 2023-10-11
|
## [0.40.0] - 2023-10-11
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
@@ -251,7 +247,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
|
|||||||
|
|
||||||
- DescendantBlur and DescendantFocus can now be used with @on decorator
|
- DescendantBlur and DescendantFocus can now be used with @on decorator
|
||||||
|
|
||||||
|
|
||||||
## [0.32.0] - 2023-08-03
|
## [0.32.0] - 2023-08-03
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|||||||
92
docs/examples/widgets/data_table_sort.py
Normal file
92
docs/examples/widgets/data_table_sort.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
from textual.app import App, ComposeResult
|
||||||
|
from textual.widgets import DataTable, Footer
|
||||||
|
|
||||||
|
ROWS = [
|
||||||
|
("lane", "swimmer", "country", "time 1", "time 2"),
|
||||||
|
(4, "Joseph Schooling", Text("Singapore", style="italic"), 50.39, 51.84),
|
||||||
|
(2, "Michael Phelps", Text("United States", style="italic"), 50.39, 51.84),
|
||||||
|
(5, "Chad le Clos", Text("South Africa", style="italic"), 51.14, 51.73),
|
||||||
|
(6, "László Cseh", Text("Hungary", style="italic"), 51.14, 51.58),
|
||||||
|
(3, "Li Zhuhao", Text("China", style="italic"), 51.26, 51.26),
|
||||||
|
(8, "Mehdy Metella", Text("France", style="italic"), 51.58, 52.15),
|
||||||
|
(7, "Tom Shields", Text("United States", style="italic"), 51.73, 51.12),
|
||||||
|
(1, "Aleksandr Sadovnikov", Text("Russia", style="italic"), 51.84, 50.85),
|
||||||
|
(10, "Darren Burns", Text("Scotland", style="italic"), 51.84, 51.55),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TableApp(App):
|
||||||
|
BINDINGS = [
|
||||||
|
("a", "sort_by_average_time", "Sort By Average Time"),
|
||||||
|
("n", "sort_by_last_name", "Sort By Last Name"),
|
||||||
|
("c", "sort_by_country", "Sort By Country"),
|
||||||
|
("d", "sort_by_columns", "Sort By Columns (Only)"),
|
||||||
|
]
|
||||||
|
|
||||||
|
current_sorts: set = set()
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
yield DataTable()
|
||||||
|
yield Footer()
|
||||||
|
|
||||||
|
def on_mount(self) -> None:
|
||||||
|
table = self.query_one(DataTable)
|
||||||
|
for col in ROWS[0]:
|
||||||
|
table.add_column(col, key=col)
|
||||||
|
table.add_rows(ROWS[1:])
|
||||||
|
|
||||||
|
def sort_reverse(self, sort_type: str):
|
||||||
|
"""Determine if `sort_type` is ascending or descending."""
|
||||||
|
reverse = sort_type in self.current_sorts
|
||||||
|
if reverse:
|
||||||
|
self.current_sorts.remove(sort_type)
|
||||||
|
else:
|
||||||
|
self.current_sorts.add(sort_type)
|
||||||
|
return reverse
|
||||||
|
|
||||||
|
def action_sort_by_average_time(self) -> None:
|
||||||
|
"""Sort DataTable by average of times (via a function) and
|
||||||
|
passing of column data through positional arguments."""
|
||||||
|
|
||||||
|
def sort_by_average_time_then_last_name(row_data):
|
||||||
|
name, *scores = row_data
|
||||||
|
return (sum(scores) / len(scores), name.split()[-1])
|
||||||
|
|
||||||
|
table = self.query_one(DataTable)
|
||||||
|
table.sort(
|
||||||
|
"swimmer",
|
||||||
|
"time 1",
|
||||||
|
"time 2",
|
||||||
|
key=sort_by_average_time_then_last_name,
|
||||||
|
reverse=self.sort_reverse("time"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def action_sort_by_last_name(self) -> None:
|
||||||
|
"""Sort DataTable by last name of swimmer (via a lambda)."""
|
||||||
|
table = self.query_one(DataTable)
|
||||||
|
table.sort(
|
||||||
|
"swimmer",
|
||||||
|
key=lambda swimmer: swimmer.split()[-1],
|
||||||
|
reverse=self.sort_reverse("swimmer"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def action_sort_by_country(self) -> None:
|
||||||
|
"""Sort DataTable by country which is a `Rich.Text` object."""
|
||||||
|
table = self.query_one(DataTable)
|
||||||
|
table.sort(
|
||||||
|
"country",
|
||||||
|
key=lambda country: country.plain,
|
||||||
|
reverse=self.sort_reverse("country"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def action_sort_by_columns(self) -> None:
|
||||||
|
"""Sort DataTable without a key."""
|
||||||
|
table = self.query_one(DataTable)
|
||||||
|
table.sort("swimmer", "lane", reverse=self.sort_reverse("columns"))
|
||||||
|
|
||||||
|
|
||||||
|
app = TableApp()
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app.run()
|
||||||
@@ -143,11 +143,22 @@ visible as you scroll through the data table.
|
|||||||
|
|
||||||
### Sorting
|
### Sorting
|
||||||
|
|
||||||
The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method.
|
The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass one more column keys to the `sort` method to sort by one or more columns.
|
||||||
In order to sort your data by a column, you must have supplied a `key` to the `add_column` method
|
|
||||||
when you added it.
|
Additionally, you can sort your `DataTable` with a custom function (or other callable) via the `key` argument. Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and return a key to use for sorting purposes.
|
||||||
You can then pass this key to the `sort` method to sort by that column.
|
|
||||||
Additionally, you can sort by multiple columns by passing multiple keys to `sort`.
|
Providing both `columns` and `key` will limit the row information sent to your `key` function (or other callable) to only the columns specified.
|
||||||
|
|
||||||
|
=== "Output"
|
||||||
|
|
||||||
|
```{.textual path="docs/examples/widgets/data_table_sort.py"}
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "data_table_sort.py"
|
||||||
|
|
||||||
|
```python
|
||||||
|
--8<-- "docs/examples/widgets/data_table_sort.py"
|
||||||
|
```
|
||||||
|
|
||||||
### Labelled rows
|
### Labelled rows
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import functools
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from itertools import chain, zip_longest
|
from itertools import chain, zip_longest
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Any, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast
|
from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast
|
||||||
|
|
||||||
import rich.repr
|
import rich.repr
|
||||||
from rich.console import RenderableType
|
from rich.console import RenderableType
|
||||||
@@ -2348,30 +2348,40 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
|
|||||||
def sort(
|
def sort(
|
||||||
self,
|
self,
|
||||||
*columns: ColumnKey | str,
|
*columns: ColumnKey | str,
|
||||||
|
key: Callable[[Any], Any] | None = None,
|
||||||
reverse: bool = False,
|
reverse: bool = False,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""Sort the rows in the `DataTable` by one or more column keys.
|
"""Sort the rows in the `DataTable` by one or more column keys or a
|
||||||
|
key function (or other callable). If both columns and a key function
|
||||||
|
are specified, only data from those columns will sent to the key function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
columns: One or more columns to sort by the values in.
|
columns: One or more columns to sort by the values in.
|
||||||
|
key: A function (or other callable) that returns a key to
|
||||||
|
use for sorting purposes.
|
||||||
reverse: If True, the sort order will be reversed.
|
reverse: If True, the sort order will be reversed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The `DataTable` instance.
|
The `DataTable` instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def sort_by_column_keys(
|
def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any:
|
||||||
row: tuple[RowKey, dict[ColumnKey | str, CellType]]
|
|
||||||
) -> Any:
|
|
||||||
_, row_data = row
|
_, row_data = row
|
||||||
result = itemgetter(*columns)(row_data)
|
if columns:
|
||||||
|
result = itemgetter(*columns)(row_data)
|
||||||
|
else:
|
||||||
|
result = tuple(row_data.values())
|
||||||
|
if key is not None:
|
||||||
|
return key(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
ordered_rows = sorted(
|
ordered_rows = sorted(
|
||||||
self._data.items(), key=sort_by_column_keys, reverse=reverse
|
self._data.items(),
|
||||||
|
key=key_wrapper,
|
||||||
|
reverse=reverse,
|
||||||
)
|
)
|
||||||
self._row_locations = TwoWayDict(
|
self._row_locations = TwoWayDict(
|
||||||
{key: new_index for new_index, (key, _) in enumerate(ordered_rows)}
|
{row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)}
|
||||||
)
|
)
|
||||||
self._update_count += 1
|
self._update_count += 1
|
||||||
self.refresh()
|
self.refresh()
|
||||||
|
|||||||
@@ -1197,6 +1197,100 @@ async def test_unset_hover_highlight_when_no_table_cell_under_mouse():
|
|||||||
assert not table._show_hover_cursor
|
assert not table._show_hover_cursor
|
||||||
|
|
||||||
|
|
||||||
|
async def test_sort_by_all_columns_no_key():
|
||||||
|
"""Test sorting a `DataTable` by all columns."""
|
||||||
|
|
||||||
|
app = DataTableApp()
|
||||||
|
async with app.run_test():
|
||||||
|
table = app.query_one(DataTable)
|
||||||
|
a, b, c = table.add_columns("A", "B", "C")
|
||||||
|
table.add_row(1, 3, 8)
|
||||||
|
table.add_row(2, 9, 5)
|
||||||
|
table.add_row(1, 1, 9)
|
||||||
|
assert table.get_row_at(0) == [1, 3, 8]
|
||||||
|
assert table.get_row_at(1) == [2, 9, 5]
|
||||||
|
assert table.get_row_at(2) == [1, 1, 9]
|
||||||
|
|
||||||
|
table.sort()
|
||||||
|
assert table.get_row_at(0) == [1, 1, 9]
|
||||||
|
assert table.get_row_at(1) == [1, 3, 8]
|
||||||
|
assert table.get_row_at(2) == [2, 9, 5]
|
||||||
|
|
||||||
|
table.sort(reverse=True)
|
||||||
|
assert table.get_row_at(0) == [2, 9, 5]
|
||||||
|
assert table.get_row_at(1) == [1, 3, 8]
|
||||||
|
assert table.get_row_at(2) == [1, 1, 9]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_sort_by_multiple_columns_no_key():
|
||||||
|
"""Test sorting a `DataTable` by multiple columns."""
|
||||||
|
|
||||||
|
app = DataTableApp()
|
||||||
|
async with app.run_test():
|
||||||
|
table = app.query_one(DataTable)
|
||||||
|
a, b, c = table.add_columns("A", "B", "C")
|
||||||
|
table.add_row(1, 3, 8)
|
||||||
|
table.add_row(2, 9, 5)
|
||||||
|
table.add_row(1, 1, 9)
|
||||||
|
|
||||||
|
table.sort(a, b, c)
|
||||||
|
assert table.get_row_at(0) == [1, 1, 9]
|
||||||
|
assert table.get_row_at(1) == [1, 3, 8]
|
||||||
|
assert table.get_row_at(2) == [2, 9, 5]
|
||||||
|
|
||||||
|
table.sort(a, c, b)
|
||||||
|
assert table.get_row_at(0) == [1, 3, 8]
|
||||||
|
assert table.get_row_at(1) == [1, 1, 9]
|
||||||
|
assert table.get_row_at(2) == [2, 9, 5]
|
||||||
|
|
||||||
|
table.sort(c, a, b, reverse=True)
|
||||||
|
assert table.get_row_at(0) == [1, 1, 9]
|
||||||
|
assert table.get_row_at(1) == [1, 3, 8]
|
||||||
|
assert table.get_row_at(2) == [2, 9, 5]
|
||||||
|
|
||||||
|
table.sort(a, c)
|
||||||
|
assert table.get_row_at(0) == [1, 3, 8]
|
||||||
|
assert table.get_row_at(1) == [1, 1, 9]
|
||||||
|
assert table.get_row_at(2) == [2, 9, 5]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_sort_by_function_sum():
|
||||||
|
"""Test sorting a `DataTable` using a custom sort function."""
|
||||||
|
|
||||||
|
def custom_sort(row_data):
|
||||||
|
return sum(row_data)
|
||||||
|
|
||||||
|
row_data = (
|
||||||
|
[1, 3, 8], # SUM=12
|
||||||
|
[2, 9, 5], # SUM=16
|
||||||
|
[1, 1, 9], # SUM=11
|
||||||
|
)
|
||||||
|
|
||||||
|
app = DataTableApp()
|
||||||
|
async with app.run_test():
|
||||||
|
table = app.query_one(DataTable)
|
||||||
|
a, b, c = table.add_columns("A", "B", "C")
|
||||||
|
for i, row in enumerate(row_data):
|
||||||
|
table.add_row(*row)
|
||||||
|
|
||||||
|
# Sorting by all columns
|
||||||
|
table.sort(a, b, c, key=custom_sort)
|
||||||
|
sorted_row_data = sorted(row_data, key=sum)
|
||||||
|
for i, row in enumerate(sorted_row_data):
|
||||||
|
assert table.get_row_at(i) == row
|
||||||
|
|
||||||
|
# Passing a sort function but no columns also sorts by all columns
|
||||||
|
table.sort(key=custom_sort)
|
||||||
|
sorted_row_data = sorted(row_data, key=sum)
|
||||||
|
for i, row in enumerate(sorted_row_data):
|
||||||
|
assert table.get_row_at(i) == row
|
||||||
|
|
||||||
|
table.sort(a, b, c, key=custom_sort, reverse=True)
|
||||||
|
sorted_row_data = sorted(row_data, key=sum, reverse=True)
|
||||||
|
for i, row in enumerate(sorted_row_data):
|
||||||
|
assert table.get_row_at(i) == row
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
["cell", "height"],
|
["cell", "height"],
|
||||||
[
|
[
|
||||||
|
|||||||
Reference in New Issue
Block a user