diff --git a/sqladmin/models.py b/sqladmin/models.py index 37ad0405..95993a84 100644 --- a/sqladmin/models.py +++ b/sqladmin/models.py @@ -900,19 +900,15 @@ def _stmt_by_identifier(self, identifier: str) -> Select: return stmt.where(*conditions) def get_prop_value( - self, obj: type, prop: Union[Column, ColumnProperty, RelationshipProperty] + self, obj: Any, prop: Union[Column, ColumnProperty, RelationshipProperty] ) -> Any: - result = None - - if isinstance(prop, Column): - result = getattr(obj, prop.name) - else: - result = getattr(obj, prop.key) - result = result.value if isinstance(result, Enum) else result + result = getattr(obj, prop.key, None) + if result and isinstance(result, Enum): + result = result.name return result - def get_list_value(self, obj: type, prop: MODEL_PROPERTY) -> Tuple[Any, Any]: + def get_list_value(self, obj: Any, prop: MODEL_PROPERTY) -> Tuple[Any, Any]: """Get tuple of (value, formatted_value) for the list view.""" value = self.get_prop_value(obj, prop) formatted_value = self._default_formatter(value) @@ -922,7 +918,7 @@ def get_list_value(self, obj: type, prop: MODEL_PROPERTY) -> Tuple[Any, Any]: formatted_value = formatter(obj, prop) return value, formatted_value - def get_detail_value(self, obj: type, prop: MODEL_PROPERTY) -> Tuple[Any, Any]: + def get_detail_value(self, obj: Any, prop: MODEL_PROPERTY) -> Tuple[Any, Any]: """Get tuple of (value, formatted_value) for the detail view.""" value = self.get_prop_value(obj, prop) formatted_value = self._default_formatter(value) diff --git a/tests/test_models.py b/tests/test_models.py index de52c7f8..d153e4c5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,9 +1,10 @@ +import enum from typing import Generator from unittest.mock import Mock, call, patch import pytest from markupsafe import Markup -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, select +from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, select from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship, sessionmaker @@ -25,6 +26,16 @@ admin = Admin(app=app, engine=engine) +class Status(enum.Enum): + ACTIVE = "ACTIVE" + DEACTIVE = "DEACTIVE" + + +class Role(int, enum.Enum): + ADMIN = 1 + USER = 2 + + class User(Base): __tablename__ = "users" @@ -49,6 +60,8 @@ class Profile(Base): id = Column(Integer, primary_key=True) is_active = Column(Boolean) + role = Column(Enum(Role)) + status = Column(Enum(Status)) user_id = Column(Integer, ForeignKey("users.id"), unique=True) user = relationship("User", back_populates="profile") @@ -506,3 +519,14 @@ class AddressAdmin(ModelView, model=Address): assert AddressAdmin().get_list_columns() == all_columns assert AddressAdmin().get_details_columns() == all_columns + + +def test_get_prop_value() -> None: + class ProfileAdmin(ModelView, model=Profile): + ... + + profile = Profile(is_active=True, role=Role.ADMIN, status=Status.ACTIVE) + + assert ProfileAdmin().get_prop_value(profile, Profile.is_active) is True + assert ProfileAdmin().get_prop_value(profile, Profile.role) == "ADMIN" + assert ProfileAdmin().get_prop_value(profile, Profile.status) == "ACTIVE"