Skip to content

Add form-specific functionality to ModelAdmin #97

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/api_reference/model_admin.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@
- column_searchable_list
- search_placeholder
- column_sortable_list
- form
- form_base_class
- form_args
- form_columns
- form_excluded_columns
- form_overrides
60 changes: 42 additions & 18 deletions sqladmin/forms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from enum import Enum
from typing import Any, Callable, Dict, Sequence, Type, Union, no_type_check
from typing import Any, Callable, Dict, Optional, Sequence, Type, Union, no_type_check

import anyio
from sqlalchemy import inspect as sqlalchemy_inspect, select
Expand Down Expand Up @@ -49,7 +49,13 @@ def __init__(self) -> None:

self.converters = converters

def get_converter(self, column: Column) -> Callable:
def get_converter(
self, prop: Union[ColumnProperty, RelationshipProperty]
) -> Callable:
if not isinstance(prop, ColumnProperty):
return self.converters[prop.direction.name]

column = prop.columns[0]
types = inspect.getmro(type(column.type))

# Search by module + name
Expand Down Expand Up @@ -79,16 +85,24 @@ async def convert(
mapper: Mapper,
prop: Union[ColumnProperty, RelationshipProperty],
engine: Union[Engine, AsyncEngine],
field_args: Dict[str, Any] = None,
label: Optional[str] = None,
override: Optional[Type[Field]] = None,
) -> UnboundField:
kwargs: Dict = {
"validators": [],
"filters": [],
"default": None,
"description": prop.doc,
"render_kw": {"class": "form-control"},
}

converter = None
if field_args:
kwargs = field_args.copy()
else:
kwargs = {}

kwargs: Dict[str, Any]
kwargs.setdefault("label", label)
kwargs.setdefault("validators", [])
kwargs.setdefault("filters", [])
kwargs.setdefault("default", None)
kwargs.setdefault("description", prop.doc)
kwargs.setdefault("render_kw", {"class": "form-control"})

converter = self.get_converter(prop)
column = None

if isinstance(prop, ColumnProperty):
Expand Down Expand Up @@ -119,8 +133,6 @@ async def convert(
kwargs["validators"].append(validators.Optional())
else:
kwargs["validators"].append(validators.InputRequired())

converter = self.get_converter(column)
else:
nullable = True
for pair in prop.local_remote_pairs:
Expand Down Expand Up @@ -150,9 +162,9 @@ async def convert(
]
kwargs["object_list"] = object_list

converter = self.converters[prop.direction.name]

assert converter is not None
if override is not None:
assert issubclass(override, Field)
return override(**kwargs)

return converter(
model=model, mapper=mapper, prop=prop, column=column, field_args=kwargs
Expand Down Expand Up @@ -256,10 +268,17 @@ async def get_model_form(
engine: Union[Engine, AsyncEngine],
only: Sequence[str] = None,
exclude: Sequence[str] = None,
column_labels: Dict[str, str] = None,
form_args: Dict[str, Dict[str, Any]] = None,
form_class: Type[Form] = Form,
form_overrides: Dict[str, Dict[str, Type[Field]]] = None,
) -> Type[Form]:
type_name = model.__name__ + "Form"
converter = ModelConverter()
mapper = sqlalchemy_inspect(model)
form_args = form_args or {}
column_labels = column_labels or {}
form_overrides = form_overrides or {}

attributes = []
for name, attr in mapper.attrs.items():
Expand All @@ -272,8 +291,13 @@ async def get_model_form(

field_dict = {}
for name, attr in attributes:
field = await converter.convert(model, mapper, attr, engine)
field_args = form_args.get(name, {})
label = column_labels.get(name, None)
override = form_overrides.get(name, None)
field = await converter.convert(
model, mapper, attr, engine, field_args, label, override
)
if field is not None:
field_dict[name] = field

return type(type_name, (Form,), field_dict)
return type(type_name, (form_class,), field_dict)
165 changes: 140 additions & 25 deletions sqladmin/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from enum import Enum
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Expand All @@ -25,7 +27,7 @@
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.elements import ClauseElement
from starlette.requests import Request
from wtforms import Form
from wtforms import Field, Form

from sqladmin.exceptions import InvalidColumnError, InvalidModelError
from sqladmin.forms import get_model_form
Expand Down Expand Up @@ -71,6 +73,9 @@ def __new__(mcls, name, bases, attrs: dict, **kwargs: Any):
cls.icon = attrs.get("icon")

mcls._check_conflicting_options(["column_list", "column_exclude_list"], attrs)
mcls._check_conflicting_options(
["form_columns", "form_excluded_columns"], attrs
)
mcls._check_conflicting_options(
["column_details_list", "column_details_exclude_list"], attrs
)
Expand Down Expand Up @@ -283,6 +288,88 @@ class UserAdmin(ModelAdmin, model=User):
edit_template: ClassVar[str] = "edit.html"
"""Edit view template. Default is `edit.html`."""

# Form
form: ClassVar[Optional[Type[Form]]] = None
"""Form class.
Override if you want to use custom form for your model.
Will completely disable form scaffolding functionality.

???+ example
```python
class MyForm(Form):
name = StringField('Name')

class MyModelAdmin(ModelAdmin, model=User):
form = MyForm
```
"""

form_base_class: ClassVar[Type[Form]] = Form
"""Base form class.
Will be used by form scaffolding function when creating model form.
Useful if you want to have custom constructor or override some fields.

???+ example
```python
class MyBaseForm(Form):
def do_something(self):
pass

class MyModelAdmin(ModelAdmin, model=User):
form_base_class = MyBaseForm
```
"""

form_args: ClassVar[Dict[str, Dict[str, Any]]] = {}
"""Dictionary of form field arguments.
Refer to WTForms documentation for list of possible options.

???+ example
```python
from wtforms.validators import DataRequired

class MyModelAdmin(ModelAdmin, model=User):
form_args = dict(
name=dict(label="User Name", validators=[DataRequired()])
)
```
"""

form_columns: ClassVar[Sequence[Union[str, InstrumentedAttribute]]] = []
"""List of columns to include in the form.
Columns can either be string names or SQLAlchemy columns.

???+ note
By default all columns of Model are included in the form.

???+ example
```python
class UserAdmin(ModelAdmin, model=User):
form_columns = [User.name, User.mail]
```
"""

form_excluded_columns: ClassVar[Sequence[Union[str, InstrumentedAttribute]]] = []
"""List of columns to exclude from the form.
Columns can either be string names or SQLAlchemy columns.

???+ example
```python
class UserAdmin(ModelAdmin, model=User):
form_excluded_columns = [User.id]
```
"""

form_overrides: ClassVar[Dict[str, Type[Field]]] = {}
"""Dictionary of form column overrides.

???+ example
```python
class UserAdmin(ModelAdmin, model=User):
form_overrides = dict(name=wtf.FileField)
```
"""

def __init__(self) -> None:
self._column_labels = self.get_column_labels()

Expand Down Expand Up @@ -445,43 +532,61 @@ def get_model_attr(
def get_model_attributes(self) -> List[Column]:
return list(inspect(self.model).attrs)

def _build_column_list(
self,
include: Optional[Sequence[Union[str, InstrumentedAttribute]]] = None,
exclude: Optional[Sequence[Union[str, InstrumentedAttribute]]] = None,
default: Callable[[], List[Column]] = None,
) -> List[Tuple[str, Column]]:
"""This function generalizes constructing a list of columns
for any sequence of inclusions or exclusions.
"""
if include:
attrs = [self.get_model_attr(attr) for attr in include]
elif exclude:
exclude_columns = [self.get_model_attr(attr) for attr in exclude]
all_attrs = self.get_model_attributes()
attrs = list(set(all_attrs) - set(exclude_columns))
else:
attrs = default()

return [(self._column_labels.get(attr, attr.key), attr) for attr in attrs]

def get_list_columns(self) -> List[Tuple[str, Column]]:
"""Get list of columns to display in List page."""

column_list = getattr(self, "column_list", None)
column_exclude_list = getattr(self, "column_exclude_list", None)

if column_list:
attrs = [self.get_model_attr(attr) for attr in self.column_list]
elif column_exclude_list:
exclude_columns = [
self.get_model_attr(attr) for attr in column_exclude_list
]
all_attrs = self.get_model_attributes()
attrs = list(set(all_attrs) - set(exclude_columns))
else:
attrs = [getattr(self.model, self.pk_column.name).prop]

return [(self._column_labels.get(attr, attr.key), attr) for attr in attrs]
return self._build_column_list(
include=column_list,
exclude=column_exclude_list,
default=lambda: [getattr(self.model, self.pk_column.name).prop],
)

def get_details_columns(self) -> List[Tuple[str, Column]]:
"""Get list of columns to display in Detail page."""

column_details_list = getattr(self, "column_details_list", None)
column_details_exclude_list = getattr(self, "column_details_exclude_list", None)

if column_details_list:
attrs = [self.get_model_attr(attr) for attr in column_details_list]
elif column_details_exclude_list:
exclude_columns = [
self.get_model_attr(attr) for attr in column_details_exclude_list
]
all_attrs = self.get_model_attributes()
attrs = list(set(all_attrs) - set(exclude_columns))
else:
attrs = self.get_model_attributes()
return self._build_column_list(
include=column_details_list,
exclude=column_details_exclude_list,
default=self.get_model_attributes,
)

return [(self._column_labels.get(attr, attr.key), attr) for attr in attrs]
def get_form_columns(self) -> List[Tuple[str, Column]]:
"""Get list of columns to display in the form."""

form_columns = getattr(self, "form_columns", None)
form_excluded_columns = getattr(self, "form_excluded_columns", None)

return self._build_column_list(
include=form_columns,
exclude=form_excluded_columns,
default=self.get_model_attributes,
)

def get_column_labels(self) -> Dict[Column, str]:
return {
Expand Down Expand Up @@ -525,7 +630,17 @@ async def update_model(self, pk: Any, data: Dict[str, Any]) -> None:
await anyio.to_thread.run_sync(self._update_modeL_sync, pk, data)

async def scaffold_form(self) -> Type[Form]:
return await get_model_form(model=self.model, engine=self.engine)
if self.form is not None:
return self.form
return await get_model_form(
model=self.model,
engine=self.engine,
only=[i[1].key for i in self.get_form_columns()],
column_labels={k.key: v for k, v in self._column_labels.items()},
form_args=self.form_args,
form_class=self.form_base_class,
form_overrides=self.form_overrides,
)

def search_placeholder(self) -> str:
"""Return search placeholder text.
Expand Down
Loading