Skip to content

Commit

Permalink
DataTable sort by function (or other callable) (#3090)
Browse files Browse the repository at this point in the history
* DataTable sort by function (or other callable)

The `DataTable` widget now takes the `by` argument instead of `columns`, allowing the table to also be sorted using a custom function (or other callable). This is a breaking change since it requires all calls to the `sort` method to include an iterable of key(s) (or a singular function/callable). Covers #2261 using [suggested function signature](#2512 (comment)) from @darrenburns on PR #2512.

* argument change and functionaloty update

Changed back to orinal `columns` argument and added a new `key` argument
which takes a function (or other callable). This allows the PR to NOT BE
a breaking change.

* better example for docs

- Updated the example file for the docs to better show the functionality
of the change (especially when using `columns` and `key` together).
- Added one new tests to cover a similar situation to the example
  changes

* removed unecessary code from example

- the sort by clicked column function was bloat in my opinion

* requested changes

* simplify method and terminology

* combine key_wrapper and default sort

* Removing some tests from DataTable.sort as duplicates. Ensure there is test coverage of the case where a key, but no columns, is passed to DataTable.sort.

* Remove unused import

* Fix merge issues in CHANGELOG, update DataTable sort-by-key changelog PR link

---------

Co-authored-by: Darren Burns <[email protected]>
Co-authored-by: Darren Burns <[email protected]>
  • Loading branch information
3 people authored Oct 31, 2023
1 parent 665dca9 commit 4f95d30
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 23 deletions.
15 changes: 5 additions & 10 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

- Add Document `get_index_from_location` / `get_location_from_index` https://github.com/Textualize/textual/pull/3410
- Add setter for `TextArea.text` https://github.com/Textualize/textual/discussions/3525
- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable) https://github.com/Textualize/textual/pull/3090
- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566
- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571
- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498


### Changed

Expand All @@ -49,15 +54,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Improved startup time by caching CSS parsing https://github.com/Textualize/textual/pull/3575
- Workers are now created/run in a thread-safe way https://github.com/Textualize/textual/pull/3586

### Added

- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566
- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571

### Added

- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498

## [0.40.0] - 2023-10-11

### Added
Expand Down Expand Up @@ -251,7 +247,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

- DescendantBlur and DescendantFocus can now be used with @on decorator


## [0.32.0] - 2023-08-03

### Added
Expand Down
92 changes: 92 additions & 0 deletions docs/examples/widgets/data_table_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from rich.text import Text

from textual.app import App, ComposeResult
from textual.widgets import DataTable, Footer

ROWS = [
("lane", "swimmer", "country", "time 1", "time 2"),
(4, "Joseph Schooling", Text("Singapore", style="italic"), 50.39, 51.84),
(2, "Michael Phelps", Text("United States", style="italic"), 50.39, 51.84),
(5, "Chad le Clos", Text("South Africa", style="italic"), 51.14, 51.73),
(6, "László Cseh", Text("Hungary", style="italic"), 51.14, 51.58),
(3, "Li Zhuhao", Text("China", style="italic"), 51.26, 51.26),
(8, "Mehdy Metella", Text("France", style="italic"), 51.58, 52.15),
(7, "Tom Shields", Text("United States", style="italic"), 51.73, 51.12),
(1, "Aleksandr Sadovnikov", Text("Russia", style="italic"), 51.84, 50.85),
(10, "Darren Burns", Text("Scotland", style="italic"), 51.84, 51.55),
]


class TableApp(App):
BINDINGS = [
("a", "sort_by_average_time", "Sort By Average Time"),
("n", "sort_by_last_name", "Sort By Last Name"),
("c", "sort_by_country", "Sort By Country"),
("d", "sort_by_columns", "Sort By Columns (Only)"),
]

current_sorts: set = set()

def compose(self) -> ComposeResult:
yield DataTable()
yield Footer()

def on_mount(self) -> None:
table = self.query_one(DataTable)
for col in ROWS[0]:
table.add_column(col, key=col)
table.add_rows(ROWS[1:])

def sort_reverse(self, sort_type: str):
"""Determine if `sort_type` is ascending or descending."""
reverse = sort_type in self.current_sorts
if reverse:
self.current_sorts.remove(sort_type)
else:
self.current_sorts.add(sort_type)
return reverse

def action_sort_by_average_time(self) -> None:
"""Sort DataTable by average of times (via a function) and
passing of column data through positional arguments."""

def sort_by_average_time_then_last_name(row_data):
name, *scores = row_data
return (sum(scores) / len(scores), name.split()[-1])

table = self.query_one(DataTable)
table.sort(
"swimmer",
"time 1",
"time 2",
key=sort_by_average_time_then_last_name,
reverse=self.sort_reverse("time"),
)

def action_sort_by_last_name(self) -> None:
"""Sort DataTable by last name of swimmer (via a lambda)."""
table = self.query_one(DataTable)
table.sort(
"swimmer",
key=lambda swimmer: swimmer.split()[-1],
reverse=self.sort_reverse("swimmer"),
)

def action_sort_by_country(self) -> None:
"""Sort DataTable by country which is a `Rich.Text` object."""
table = self.query_one(DataTable)
table.sort(
"country",
key=lambda country: country.plain,
reverse=self.sort_reverse("country"),
)

def action_sort_by_columns(self) -> None:
"""Sort DataTable without a key."""
table = self.query_one(DataTable)
table.sort("swimmer", "lane", reverse=self.sort_reverse("columns"))


app = TableApp()
if __name__ == "__main__":
app.run()
21 changes: 16 additions & 5 deletions docs/widgets/data_table.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,22 @@ visible as you scroll through the data table.

### Sorting

The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method.
In order to sort your data by a column, you must have supplied a `key` to the `add_column` method
when you added it.
You can then pass this key to the `sort` method to sort by that column.
Additionally, you can sort by multiple columns by passing multiple keys to `sort`.
The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass one more column keys to the `sort` method to sort by one or more columns.

Additionally, you can sort your `DataTable` with a custom function (or other callable) via the `key` argument. Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and return a key to use for sorting purposes.

Providing both `columns` and `key` will limit the row information sent to your `key` function (or other callable) to only the columns specified.

=== "Output"

```{.textual path="docs/examples/widgets/data_table_sort.py"}
```

=== "data_table_sort.py"

```python
--8<-- "docs/examples/widgets/data_table_sort.py"
```

### Labelled rows

Expand Down
26 changes: 18 additions & 8 deletions src/textual/widgets/_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from itertools import chain, zip_longest
from operator import itemgetter
from typing import Any, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast
from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast

import rich.repr
from rich.console import RenderableType
Expand Down Expand Up @@ -2348,30 +2348,40 @@ def _get_fixed_offset(self) -> Spacing:
def sort(
self,
*columns: ColumnKey | str,
key: Callable[[Any], Any] | None = None,
reverse: bool = False,
) -> Self:
"""Sort the rows in the `DataTable` by one or more column keys.
"""Sort the rows in the `DataTable` by one or more column keys or a
key function (or other callable). If both columns and a key function
are specified, only data from those columns will sent to the key function.
Args:
columns: One or more columns to sort by the values in.
key: A function (or other callable) that returns a key to
use for sorting purposes.
reverse: If True, the sort order will be reversed.
Returns:
The `DataTable` instance.
"""

def sort_by_column_keys(
row: tuple[RowKey, dict[ColumnKey | str, CellType]]
) -> Any:
def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any:
_, row_data = row
result = itemgetter(*columns)(row_data)
if columns:
result = itemgetter(*columns)(row_data)
else:
result = tuple(row_data.values())
if key is not None:
return key(result)
return result

ordered_rows = sorted(
self._data.items(), key=sort_by_column_keys, reverse=reverse
self._data.items(),
key=key_wrapper,
reverse=reverse,
)
self._row_locations = TwoWayDict(
{key: new_index for new_index, (key, _) in enumerate(ordered_rows)}
{row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)}
)
self._update_count += 1
self.refresh()
Expand Down
94 changes: 94 additions & 0 deletions tests/test_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,100 @@ async def test_unset_hover_highlight_when_no_table_cell_under_mouse():
assert not table._show_hover_cursor


async def test_sort_by_all_columns_no_key():
"""Test sorting a `DataTable` by all columns."""

app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
a, b, c = table.add_columns("A", "B", "C")
table.add_row(1, 3, 8)
table.add_row(2, 9, 5)
table.add_row(1, 1, 9)
assert table.get_row_at(0) == [1, 3, 8]
assert table.get_row_at(1) == [2, 9, 5]
assert table.get_row_at(2) == [1, 1, 9]

table.sort()
assert table.get_row_at(0) == [1, 1, 9]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [2, 9, 5]

table.sort(reverse=True)
assert table.get_row_at(0) == [2, 9, 5]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [1, 1, 9]


async def test_sort_by_multiple_columns_no_key():
"""Test sorting a `DataTable` by multiple columns."""

app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
a, b, c = table.add_columns("A", "B", "C")
table.add_row(1, 3, 8)
table.add_row(2, 9, 5)
table.add_row(1, 1, 9)

table.sort(a, b, c)
assert table.get_row_at(0) == [1, 1, 9]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [2, 9, 5]

table.sort(a, c, b)
assert table.get_row_at(0) == [1, 3, 8]
assert table.get_row_at(1) == [1, 1, 9]
assert table.get_row_at(2) == [2, 9, 5]

table.sort(c, a, b, reverse=True)
assert table.get_row_at(0) == [1, 1, 9]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [2, 9, 5]

table.sort(a, c)
assert table.get_row_at(0) == [1, 3, 8]
assert table.get_row_at(1) == [1, 1, 9]
assert table.get_row_at(2) == [2, 9, 5]


async def test_sort_by_function_sum():
"""Test sorting a `DataTable` using a custom sort function."""

def custom_sort(row_data):
return sum(row_data)

row_data = (
[1, 3, 8], # SUM=12
[2, 9, 5], # SUM=16
[1, 1, 9], # SUM=11
)

app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
a, b, c = table.add_columns("A", "B", "C")
for i, row in enumerate(row_data):
table.add_row(*row)

# Sorting by all columns
table.sort(a, b, c, key=custom_sort)
sorted_row_data = sorted(row_data, key=sum)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row

# Passing a sort function but no columns also sorts by all columns
table.sort(key=custom_sort)
sorted_row_data = sorted(row_data, key=sum)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row

table.sort(a, b, c, key=custom_sort, reverse=True)
sorted_row_data = sorted(row_data, key=sum, reverse=True)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row


@pytest.mark.parametrize(
["cell", "height"],
[
Expand Down

0 comments on commit 4f95d30

Please sign in to comment.