Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making checks pass #1

Merged
merged 48 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
972ee56
Fix fastapi tests
mbsantiago Aug 9, 2023
179183c
Fix linting issues
mbsantiago Aug 9, 2023
c99c1a9
Make sure tests pass in all supported python versions
mbsantiago Aug 9, 2023
347e052
support str|None , mapped_column, AnyURL
honglei Aug 12, 2023
9da0407
support 3.9/use get_args
honglei Aug 14, 2023
a22dac8
avoid get_args directly
honglei Aug 14, 2023
ce0064b
python version for types.UnionType/Annotated
honglei Aug 14, 2023
a0b84c5
avoid compare types:FunctionType
honglei Aug 14, 2023
fa8902c
add type hints for func is_optional_or_union
honglei Aug 14, 2023
710e92b
black it
honglei Aug 14, 2023
46b130d
fix isort error for `import types`
honglei Aug 14, 2023
f67b414
Merge pull request #1 from honglei/main
mbsantiago Aug 14, 2023
f3e7811
get_column_from_field support function
honglei Aug 15, 2023
9e07c1c
fix type check for sa_column
honglei Aug 15, 2023
8cc628c
Merge branch 'mbsantiago:main' into main
honglei Aug 15, 2023
63e2692
Add importlib-metadata to dev-dependencies
mbsantiago Aug 16, 2023
5b49f77
get_column_from_field:sa_column>field attribute>field annotation
honglei Aug 16, 2023
650534e
Merge branch 'main' of https://github.com/honglei/sqlmodel
honglei Aug 16, 2023
e919dc8
Merge branch 'mbsantiago:main' into main
honglei Aug 16, 2023
1c626cb
Merge branch 'main' of https://github.com/honglei/sqlmodel
honglei Aug 16, 2023
6a5f373
Revert "get_column_from_field:sa_column>field attribute>field annotat…
honglei Aug 16, 2023
72dc89d
field is required by default, while nullable=True for Column
honglei Aug 16, 2023
fa8955c
field required
honglei Aug 16, 2023
045f9bc
add test for pydantic.AnyURL
honglei Aug 16, 2023
6e89ad3
black/isort for test_nullable.py
honglei Aug 16, 2023
6b7925d
fix _is_field_noneable
honglei Aug 20, 2023
4e89361
add test for Class hierarchy
honglei Aug 20, 2023
7752780
fix isort
honglei Aug 20, 2023
499bc18
add reason for skipif
honglei Aug 20, 2023
8e2b363
annotation not null
honglei Aug 20, 2023
f5fd850
fix model_copy
50Bytes-dev Aug 24, 2023
45ab472
fix
50Bytes-dev Aug 24, 2023
006ec7d
Merge pull request #1 from 50Bytes-dev/main
honglei Aug 25, 2023
f266da7
black/isort test_model_copy.py
honglei Aug 25, 2023
ef56d08
remove unused import in test_model_copy.py
honglei Aug 25, 2023
e0d32fb
try fix py3.8/test_nullable.py
honglei Aug 25, 2023
aa3325b
ugly way to fix py3.8/Annotation
honglei Aug 25, 2023
e824730
miss import _AnnotatedAlias
honglei Aug 25, 2023
dcb406f
fix py3.9+ _AnnotatedAlias
honglei Aug 25, 2023
d13fb74
only use typing_extensions to import _AnnotatedAlias
honglei Aug 25, 2023
c02b579
support AnyURL
honglei Aug 25, 2023
cb6ccf4
forgot black it
honglei Aug 25, 2023
4213c97
support AnyURL
honglei Aug 25, 2023
bcb6f32
Merge pull request #2 from honglei/main
mbsantiago Aug 25, 2023
8b92179
support EmailStr
honglei Sep 4, 2023
ff8ed0b
Update pyproject.toml to support email
honglei Sep 4, 2023
80bc2a1
Merge branch 'mbsantiago:main' into main
honglei Sep 4, 2023
3005495
Merge pull request #3 from honglei/main
mbsantiago Sep 8, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs_src/tutorial/fastapi/delete/tutorial001.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def update_hero(hero_id: int, hero: HeroUpdate):
db_hero = session.get(Hero, hero_id)
if not db_hero:
raise HTTPException(status_code=404, detail="Hero not found")
hero_data = hero.dict(exclude_unset=True)
hero_data = hero.model_dump(exclude_unset=True)
for key, value in hero_data.items():
setattr(db_hero, key, value)
session.add(db_hero)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ classifiers = [
[tool.poetry.dependencies]
python = "^3.7"
SQLAlchemy = ">=2.0.0,<=2.0.11"
pydantic = "^2.1.1"
pydantic = { version = ">=2.1.1,<=2.4", extras = ["email"] }

[tool.poetry.dev-dependencies]
pytest = "^7.0.1"
Expand All @@ -52,6 +52,7 @@ autoflake = "^1.4"
isort = "^5.9.3"
async_generator = {version = "*", python = "~3.7"}
async-exit-stack = {version = "*", python = "~3.7"}
importlib-metadata = { version = "*", python = ">3.7" }
httpx = "^0.24.1"

[build-system]
Expand Down
11 changes: 9 additions & 2 deletions sqlmodel/ext/asyncio/session.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
from typing import Any, Dict, Mapping, Optional, Sequence, Type, TypeVar, Union

from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio import engine
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.orm import Mapper
from sqlalchemy.sql.expression import TableClause
from sqlalchemy.util.concurrency import greenlet_spawn
from sqlmodel.sql.base import Executable

Expand All @@ -14,13 +16,18 @@
_T = TypeVar("_T")


BindsType = Dict[
Union[Type[Any], Mapper[Any], TableClause, str], Union[AsyncEngine, AsyncConnection]
]


class AsyncSession(_AsyncSession):
sync_session: Session

def __init__(
self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
binds: Optional[BindsType] = None,
**kw: Any,
):
# All the same code of the original AsyncSession
Expand Down
141 changes: 122 additions & 19 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from __future__ import annotations

import ipaddress
import sys
import types
import uuid
import weakref
from datetime import date, datetime, time, timedelta
Expand All @@ -22,11 +26,11 @@
TypeVar,
Union,
cast,
get_args,
get_origin,
)

from pydantic import BaseModel
import pydantic
from annotated_types import MaxLen
from pydantic import BaseModel, EmailStr, ImportString, NameEmail
from pydantic._internal._fields import PydanticGeneralMetadata
from pydantic._internal._model_construction import ModelMetaclass
from pydantic._internal._repr import Representation
Expand All @@ -39,12 +43,21 @@
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.orm.properties import MappedColumn
from sqlalchemy.sql import false, true
from sqlalchemy.sql.schema import DefaultClause, MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time

from .sql.sqltypes import GUID, AutoString
from .typing import SQLModelConfig

if sys.version_info >= (3, 8):
from typing import get_args, get_origin
else:
from typing_extensions import get_args, get_origin

from typing_extensions import Annotated, _AnnotatedAlias

_T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any]
NoneType = type(None)
Expand All @@ -61,6 +74,8 @@ def __dataclass_transform__(


class FieldInfo(PydanticFieldInfo):
nullable: Union[bool, PydanticUndefinedType]

def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
primary_key = kwargs.pop("primary_key", False)
nullable = kwargs.pop("nullable", PydanticUndefined)
Expand Down Expand Up @@ -150,14 +165,40 @@ def Field(
unique: bool = False,
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[
Mapping[str, Any], PydanticUndefinedType
] = PydanticUndefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
if default is PydanticUndefined:
if isinstance(sa_column, types.FunctionType): # lambda
sa_column_ = sa_column()
else:
sa_column_ = sa_column

# server_default -> default
if isinstance(sa_column_, Column) and isinstance(
sa_column_.server_default, DefaultClause
):
default_value = sa_column_.server_default.arg
if issubclass(type(sa_column_.type), Integer) and isinstance(
default_value, str
):
default = int(default_value)
elif issubclass(type(sa_column_.type), Boolean):
if default_value is false():
default = False
elif default_value is true():
default = True
elif isinstance(default_value, str):
if default_value == "1":
default = True
elif default_value == "0":
default = False

field_info = FieldInfo(
default,
default_factory=default_factory,
Expand Down Expand Up @@ -236,7 +277,6 @@ def __new__(
class_dict: Dict[str, Any],
**kwargs: Any,
) -> Any:

relationships: Dict[str, RelationshipInfo] = {}
dict_for_pydantic = {}
original_annotations = class_dict.get("__annotations__", {})
Expand Down Expand Up @@ -398,23 +438,50 @@ def __init__(
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)


def _is_optional_or_union(type_: Optional[type]) -> bool:
if sys.version_info >= (3, 10):
return get_origin(type_) in (types.UnionType, Union)
else:
return get_origin(type_) is Union


def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_: type | None = field.annotation
type_: Optional[type] | _AnnotatedAlias = field.annotation

# Resolve Optional/Union fields

# Resolve Optional fields
if type_ is not None and get_origin(type_) is Union:
if type_ is not None and _is_optional_or_union(type_):
bases = get_args(type_)
if len(bases) > 2:
raise RuntimeError(
"Cannot have a (non-optional) union as a SQL alchemy field"
)
type_ = bases[0]
# Resolve Annoted fields,
# like typing.Annotated[pydantic_core._pydantic_core.Url,
# UrlConstraints(max_length=512,
# allowed_schemes=['smb', 'ftp', 'file']) ]
if type_ is pydantic.AnyUrl:
if field.metadata:
meta = field.metadata[0]
return AutoString(length=meta.max_length)
else:
return AutoString

org_type = get_origin(type_)
if org_type is Annotated:
type2 = get_args(type_)[0]
if type2 is pydantic.AnyUrl:
meta = get_args(type_)[1]
return AutoString(length=meta.max_length)
elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias:
return AutoString(type_.__metadata__[0].max_length)

# The 3rd is PydanticGeneralMetadata
metadata = _get_field_metadata(field)
if type_ is None:
raise ValueError("Missing field type")
if issubclass(type_, str):
if issubclass(type_, str) or type_ in (EmailStr, NameEmail, ImportString):
max_length = getattr(metadata, "max_length", None)
if max_length:
return AutoString(length=max_length)
Expand Down Expand Up @@ -458,9 +525,18 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:


def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
"""
sa_column > field attributes > annotation info
"""
sa_column = getattr(field, "sa_column", PydanticUndefined)
if isinstance(sa_column, Column):
return sa_column
if isinstance(sa_column, MappedColumn):
return sa_column.column
if isinstance(sa_column, types.FunctionType):
col = sa_column()
assert isinstance(col, Column)
return col
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field, "primary_key", False)
index = getattr(field, "index", PydanticUndefined)
Expand All @@ -484,7 +560,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
"index": index,
"unique": unique,
}
sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined
sa_default: Union[PydanticUndefinedType, Callable[[], Any]] = PydanticUndefined
if field.default_factory:
sa_default = field.default_factory
elif field.default is not PydanticUndefined:
Expand Down Expand Up @@ -524,12 +600,16 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
# in the Pydantic model so that when SQLAlchemy sets attributes that are
# added (e.g. when querying from DB) to the __fields_set__, this already exists
object.__setattr__(new_object, "__pydantic_fields_set__", set())
if not hasattr(new_object, "__pydantic_extra__"):
object.__setattr__(new_object, "__pydantic_extra__", None)
if not hasattr(new_object, "__pydantic_private__"):
object.__setattr__(new_object, "__pydantic_private__", None)
return new_object

def __init__(__pydantic_self__, **data: Any) -> None:
old_dict = __pydantic_self__.__dict__.copy()
super().__init__(**data)
__pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__
__pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__}
non_pydantic_keys = data.keys() - __pydantic_self__.model_fields
for key in non_pydantic_keys:
if key in __pydantic_self__.__sqlmodel_relationships__:
Expand Down Expand Up @@ -558,33 +638,54 @@ def __tablename__(cls) -> str:

@classmethod
def model_validate(
cls: type[_TSQLModel],
cls: Type[_TSQLModel],
obj: Any,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
strict: Optional[bool] = None,
from_attributes: Optional[bool] = None,
context: Optional[Dict[str, Any]] = None,
) -> _TSQLModel:
# Somehow model validate doesn't call __init__ so it would remove our init logic
validated = super().model_validate(
obj, strict=strict, from_attributes=from_attributes, context=context
)
return cls(**{key: value for key, value in validated})

# remove defaults so they don't get validated
data = {}
for key, value in validated:
field = cls.model_fields.get(key)

if field is None:
continue

if (
hasattr(field, "default")
and field.default is not PydanticUndefined
and value == field.default
Comment on lines +662 to +664
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We. need to figure out default_factory as well I think.

Suggested change
hasattr(field, "default")
and field.default is not PydanticUndefined
and value == field.default
hasattr(field, "default")
and field.default is not PydanticUndefined
and field.default_factory is not None
and value == field.default

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default factory is a bit trickier since it is hard to know if the passed value is the output of the default factory or a user provided value. I'm not sure if Pydantic does run field validations in the case it was generated by the default factor, but if it doesn't then we need to avoid SQLModel from running the validation.

):
continue

data[key] = value

return cls(**data)


def _is_field_noneable(field: FieldInfo) -> bool:
if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined:
if hasattr(field, "nullable") and not isinstance(
field.nullable, PydanticUndefinedType
):
return field.nullable
if not field.is_required():
default = getattr(field, "original_default", field.default)
if default is PydanticUndefined:
return False
if field.annotation is None or field.annotation is NoneType:
return True
if get_origin(field.annotation) is Union:
if _is_optional_or_union(field.annotation):
for base in get_args(field.annotation):
if base is NoneType:
return True

return False
return False

Expand All @@ -593,4 +694,6 @@ def _get_field_metadata(field: FieldInfo) -> object:
for meta in field.metadata:
if isinstance(meta, PydanticGeneralMetadata):
return meta
if isinstance(meta, MaxLen):
return meta
return object()
Loading
Loading