Skip to content

Commit

Permalink
Allow sort by related model field
Browse files Browse the repository at this point in the history
  • Loading branch information
aminalaee committed Oct 23, 2023
1 parent f01d865 commit 4b90dd3
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 12 deletions.
11 changes: 9 additions & 2 deletions sqladmin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,10 +1064,17 @@ def sort_query(self, stmt: Select, request: Request) -> Select:
sort_fields = self._get_default_sort()

for sort_field, is_desc in sort_fields:
model = self.model

parts = sort_field.split(".")
for part in parts[:-1]:
model = getattr(model, part).mapper.class_
stmt = stmt.join(model)

if is_desc:
stmt = stmt.order_by(desc(sort_field))
stmt = stmt.order_by(desc(getattr(model, parts[-1])))
else:
stmt = stmt.order_by(asc(sort_field))
stmt = stmt.order_by(asc(getattr(model, parts[-1])))

return stmt

Expand Down
39 changes: 29 additions & 10 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
from markupsafe import Markup
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, select
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
from sqlalchemy.sql.expression import Select
Expand Down Expand Up @@ -51,7 +51,7 @@ def name_with_id(self) -> str:
class Address(Base):
__tablename__ = "addresses"

pk = Column(Integer, primary_key=True)
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("users.id"))

user = relationship("User", back_populates="addresses")
Expand Down Expand Up @@ -121,9 +121,9 @@ class UserAdmin(ModelView, model=User):

def test_column_list_by_str_name() -> None:
class AddressAdmin(ModelView, model=Address):
column_list = ["pk", "user_id"]
column_list = ["id", "user_id"]

assert AddressAdmin().get_list_columns() == ["pk", "user_id"]
assert AddressAdmin().get_list_columns() == ["id", "user_id"]


def test_column_list_both_include_and_exclude() -> None:
Expand Down Expand Up @@ -242,9 +242,9 @@ class UserAdmin(ModelView, model=User):

def test_form_columns_by_str_name() -> None:
class AddressAdmin(ModelView, model=Address):
form_columns = ["pk", "user_id"]
form_columns = ["id", "user_id"]

assert AddressAdmin().get_form_columns() == ["pk", "user_id"]
assert AddressAdmin().get_form_columns() == ["id", "user_id"]


def test_form_columns_both_include_and_exclude() -> None:
Expand Down Expand Up @@ -299,9 +299,9 @@ class UserAdmin(ModelView, model=User):

def test_export_columns_by_str_name() -> None:
class AddressAdmin(ModelView, model=Address):
column_export_list = ["pk", "user_id"]
column_export_list = ["id", "user_id"]

assert AddressAdmin().get_export_columns() == ["pk", "user_id"]
assert AddressAdmin().get_export_columns() == ["id", "user_id"]


def test_export_columns_both_include_and_exclude() -> None:
Expand Down Expand Up @@ -386,8 +386,8 @@ class AddressAdmin(ModelView, model=Address):
column_list = "__all__"
column_details_list = "__all__"

assert AddressAdmin().get_list_columns() == ["user", "pk", "user_id"]
assert AddressAdmin().get_details_columns() == ["user", "pk", "user_id"]
assert AddressAdmin().get_list_columns() == ["user", "id", "user_id"]
assert AddressAdmin().get_details_columns() == ["user", "id", "user_id"]


async def test_get_prop_value() -> None:
Expand Down Expand Up @@ -415,3 +415,22 @@ class UserAdmin(ModelView, model=User):
assert UserAdmin().get_list_columns() == ["id", "name", "name_with_id"]
assert UserAdmin().get_details_columns() == ["addresses", "profile", "id", "name"]
assert await UserAdmin().get_prop_value(user, "name_with_id") == "batman - 1"


def test_sort_query() -> None:
class AddressAdmin(ModelView, model=Address):
...

query = select(Address)

request = Request({"type": "http", "query_string": "sortBy=id&sort=asc"})
stmt = AddressAdmin().sort_query(query, request)
assert "ORDER BY addresses.id ASC" in str(stmt)

request = Request({"type": "http", "query_string": b"sortBy=user.name&sort=desc"})
stmt = AddressAdmin().sort_query(query, request)
assert "ORDER BY users.name DESC" in str(stmt)

request = Request({"type": "http", "query_string": b"sortBy=user.profile.role&sort=desc"})
stmt = AddressAdmin().sort_query(query, request)
assert "ORDER BY profiles.role DESC" in str(stmt)

0 comments on commit 4b90dd3

Please sign in to comment.