Merge pull request #1 from mbsantiago/main
Making checks pass
AntonDeMeester authored Sep 18, 2023
2 parents 40bcdfe + 3005495 commit 244c947
Showing 23 changed files with 580 additions and 168 deletions.
2 changes: 1 addition & 1 deletion docs_src/tutorial/fastapi/delete/tutorial001.py
@@ -79,7 +79,7 @@ def update_hero(hero_id: int, hero: HeroUpdate):
db_hero = session.get(Hero, hero_id)
if not db_hero:
raise HTTPException(status_code=404, detail="Hero not found")
hero_data = hero.dict(exclude_unset=True)
hero_data = hero.model_dump(exclude_unset=True)
for key, value in hero_data.items():
setattr(db_hero, key, value)
session.add(db_hero)
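The only change in this file is the Pydantic v2 rename: `hero.dict(exclude_unset=True)` becomes `hero.model_dump(exclude_unset=True)`. A minimal sketch of the partial-update pattern this supports (the payload value is illustrative, not from this PR):

    from typing import Optional

    from sqlmodel import SQLModel


    class HeroUpdate(SQLModel):
        name: Optional[str] = None
        secret_name: Optional[str] = None
        age: Optional[int] = None


    # Only fields the client actually sent are included when
    # exclude_unset=True, so unset columns are left untouched.
    hero = HeroUpdate(name="Deadpond")
    print(hero.model_dump(exclude_unset=True))  # {'name': 'Deadpond'}
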
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -33,7 +33,7 @@ classifiers = [
[tool.poetry.dependencies]
python = "^3.7"
SQLAlchemy = ">=2.0.0,<=2.0.11"
pydantic = "^2.1.1"
pydantic = { version = ">=2.1.1,<=2.4", extras = ["email"] }

[tool.poetry.dev-dependencies]
pytest = "^7.0.1"
@@ -52,6 +52,7 @@ autoflake = "^1.4"
isort = "^5.9.3"
async_generator = {version = "*", python = "~3.7"}
async-exit-stack = {version = "*", python = "~3.7"}
importlib-metadata = { version = "*", python = ">3.7" }
httpx = "^0.24.1"

[build-system]
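The pydantic constraint now also pulls in the `email` extra (the email-validator package), presumably because the updated `sqlmodel/main.py` below maps `EmailStr` fields to string columns. A minimal sketch of what that extra enables, assuming email-validator is installed:

    from pydantic import BaseModel, EmailStr


    class User(BaseModel):
        email: EmailStr


    print(User(email="hero@example.com").email)  # hero@example.com
    # User(email="not-an-email") raises a ValidationError instead.
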
11 changes: 9 additions & 2 deletions sqlmodel/ext/asyncio/session.py
@@ -1,9 +1,11 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
from typing import Any, Dict, Mapping, Optional, Sequence, Type, TypeVar, Union

from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio import engine
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.orm import Mapper
from sqlalchemy.sql.expression import TableClause
from sqlalchemy.util.concurrency import greenlet_spawn
from sqlmodel.sql.base import Executable

@@ -14,13 +16,18 @@
_T = TypeVar("_T")


BindsType = Dict[
Union[Type[Any], Mapper[Any], TableClause, str], Union[AsyncEngine, AsyncConnection]
]


class AsyncSession(_AsyncSession):
sync_session: Session

def __init__(
self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
binds: Optional[BindsType] = None,
**kw: Any,
):
# All the same code of the original AsyncSession
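`BindsType` spells out what the `binds` mapping may contain: keys are ORM classes, mappers, tables, or strings, and values are async engines or connections. A sketch of routing two tables to separate engines; the model names and database URLs are illustrative, and the aiosqlite driver is assumed to be installed:

    from typing import Optional

    from sqlalchemy.ext.asyncio import create_async_engine
    from sqlmodel import Field, SQLModel
    from sqlmodel.ext.asyncio.session import AsyncSession


    class Hero(SQLModel, table=True):
        id: Optional[int] = Field(default=None, primary_key=True)


    class Team(SQLModel, table=True):
        id: Optional[int] = Field(default=None, primary_key=True)


    heroes_engine = create_async_engine("sqlite+aiosqlite:///heroes.db")
    teams_engine = create_async_engine("sqlite+aiosqlite:///teams.db")

    # Each model class is bound to its own engine.
    session = AsyncSession(binds={Hero: heroes_engine, Team: teams_engine})
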
141 changes: 122 additions & 19 deletions sqlmodel/main.py
@@ -1,4 +1,8 @@
from __future__ import annotations

import ipaddress
import sys
import types
import uuid
import weakref
from datetime import date, datetime, time, timedelta
@@ -22,11 +26,11 @@
TypeVar,
Union,
cast,
get_args,
get_origin,
)

from pydantic import BaseModel
import pydantic
from annotated_types import MaxLen
from pydantic import BaseModel, EmailStr, ImportString, NameEmail
from pydantic._internal._fields import PydanticGeneralMetadata
from pydantic._internal._model_construction import ModelMetaclass
from pydantic._internal._repr import Representation
@@ -39,12 +43,21 @@
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.orm.properties import MappedColumn
from sqlalchemy.sql import false, true
from sqlalchemy.sql.schema import DefaultClause, MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time

from .sql.sqltypes import GUID, AutoString
from .typing import SQLModelConfig

if sys.version_info >= (3, 8):
from typing import get_args, get_origin
else:
from typing_extensions import get_args, get_origin

from typing_extensions import Annotated, _AnnotatedAlias

_T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any]
NoneType = type(None)
@@ -61,6 +74,8 @@ def __dataclass_transform__(


class FieldInfo(PydanticFieldInfo):
nullable: Union[bool, PydanticUndefinedType]

def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
primary_key = kwargs.pop("primary_key", False)
nullable = kwargs.pop("nullable", PydanticUndefined)
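`FieldInfo` now declares the `nullable` attribute it pops out of the keyword arguments, so an explicit `nullable=` on `Field()` is carried through to column creation and to `_is_field_noneable` further down. A minimal sketch, not part of the diff (model and field names are illustrative):

    from typing import Optional

    from sqlmodel import Field, SQLModel


    class Hero(SQLModel, table=True):
        id: Optional[int] = Field(default=None, primary_key=True)
        # nullable=False forces a NOT NULL column even though the Python
        # annotation is Optional; the explicit flag wins over inference.
        secret_name: Optional[str] = Field(default=None, nullable=False)
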
@@ -150,14 +165,40 @@ def Field(
unique: bool = False,
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[
Mapping[str, Any], PydanticUndefinedType
] = PydanticUndefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
if default is PydanticUndefined:
if isinstance(sa_column, types.FunctionType): # lambda
sa_column_ = sa_column()
else:
sa_column_ = sa_column

# server_default -> default
if isinstance(sa_column_, Column) and isinstance(
sa_column_.server_default, DefaultClause
):
default_value = sa_column_.server_default.arg
if issubclass(type(sa_column_.type), Integer) and isinstance(
default_value, str
):
default = int(default_value)
elif issubclass(type(sa_column_.type), Boolean):
if default_value is false():
default = False
elif default_value is true():
default = True
elif isinstance(default_value, str):
if default_value == "1":
default = True
elif default_value == "0":
default = False

field_info = FieldInfo(
default,
default_factory=default_factory,
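The new block above tries to mirror a column's `server_default` into the model's Python-side default, so instantiating the model gives the same value the database would fill in; it understands string defaults on `Integer` columns and the usual spellings of boolean defaults. A sketch of the intended effect as I read this hunk, not verified against a release (names are illustrative):

    from typing import Optional

    from sqlalchemy import Column, Integer
    from sqlmodel import Field, SQLModel


    class Counter(SQLModel, table=True):
        id: Optional[int] = Field(default=None, primary_key=True)
        # The DDL default is the string "0"; the logic above converts it
        # to int(0) and uses it as the field's Python-level default.
        hits: int = Field(sa_column=Column(Integer, server_default="0"))


    print(Counter().hits)  # expected: 0
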
@@ -236,7 +277,6 @@ def __new__(
class_dict: Dict[str, Any],
**kwargs: Any,
) -> Any:

relationships: Dict[str, RelationshipInfo] = {}
dict_for_pydantic = {}
original_annotations = class_dict.get("__annotations__", {})
@@ -398,23 +438,50 @@ def __init__(
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)


def _is_optional_or_union(type_: Optional[type]) -> bool:
if sys.version_info >= (3, 10):
return get_origin(type_) in (types.UnionType, Union)
else:
return get_origin(type_) is Union


def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_: type | None = field.annotation
type_: Optional[type] | _AnnotatedAlias = field.annotation

# Resolve Optional/Union fields

# Resolve Optional fields
if type_ is not None and get_origin(type_) is Union:
if type_ is not None and _is_optional_or_union(type_):
bases = get_args(type_)
if len(bases) > 2:
raise RuntimeError(
"Cannot have a (non-optional) union as a SQL alchemy field"
)
type_ = bases[0]
# Resolve Annotated fields,
# like typing.Annotated[pydantic_core._pydantic_core.Url,
# UrlConstraints(max_length=512,
# allowed_schemes=['smb', 'ftp', 'file']) ]
if type_ is pydantic.AnyUrl:
if field.metadata:
meta = field.metadata[0]
return AutoString(length=meta.max_length)
else:
return AutoString

org_type = get_origin(type_)
if org_type is Annotated:
type2 = get_args(type_)[0]
if type2 is pydantic.AnyUrl:
meta = get_args(type_)[1]
return AutoString(length=meta.max_length)
elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias:
return AutoString(type_.__metadata__[0].max_length)

# The 3rd is PydanticGeneralMetadata
metadata = _get_field_metadata(field)
if type_ is None:
raise ValueError("Missing field type")
if issubclass(type_, str):
if issubclass(type_, str) or type_ in (EmailStr, NameEmail, ImportString):
max_length = getattr(metadata, "max_length", None)
if max_length:
return AutoString(length=max_length)
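`get_sqlalchemy_type` now unwraps `Annotated` metadata (for example the `UrlConstraints` that Pydantic v2 attaches to its URL types) and also accepts `EmailStr`, `NameEmail` and `ImportString`, mapping all of them to `AutoString`. A minimal sketch of the kind of fields this makes possible (model name is illustrative):

    from typing import Optional

    from pydantic import AnyUrl, EmailStr
    from sqlmodel import Field, SQLModel


    class Site(SQLModel, table=True):
        id: Optional[int] = Field(default=None, primary_key=True)
        url: AnyUrl        # stored as AutoString, length taken from UrlConstraints if present
        contact: EmailStr  # stored as AutoString
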
@@ -458,9 +525,18 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:


def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
"""
sa_column > field attributes > annotation info
"""
sa_column = getattr(field, "sa_column", PydanticUndefined)
if isinstance(sa_column, Column):
return sa_column
if isinstance(sa_column, MappedColumn):
return sa_column.column
if isinstance(sa_column, types.FunctionType):
col = sa_column()
assert isinstance(col, Column)
return col
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field, "primary_key", False)
index = getattr(field, "index", PydanticUndefined)
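`get_column_from_field` now also accepts a `MappedColumn` or a zero-argument callable (typically a `lambda`) for `sa_column`. Because a plain `Column` instance can only ever be attached to one table, the callable form lets the column be constructed fresh each time the field definition is used. A sketch, not part of the diff (field and column choices are illustrative):

    from datetime import datetime
    from typing import Optional

    from sqlalchemy import Column, DateTime, func
    from sqlmodel import Field, SQLModel


    class Event(SQLModel, table=True):
        id: Optional[int] = Field(default=None, primary_key=True)
        # The lambda defers Column construction until the model class is
        # assembled, so this Field definition could be shared or re-used.
        created_at: Optional[datetime] = Field(
            default=None,
            sa_column=lambda: Column(DateTime(), server_default=func.now()),
        )
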
@@ -484,7 +560,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
"index": index,
"unique": unique,
}
sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined
sa_default: Union[PydanticUndefinedType, Callable[[], Any]] = PydanticUndefined
if field.default_factory:
sa_default = field.default_factory
elif field.default is not PydanticUndefined:
@@ -524,12 +600,16 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
# in the Pydantic model so that when SQLAlchemy sets attributes that are
# added (e.g. when querying from DB) to the __fields_set__, this already exists
object.__setattr__(new_object, "__pydantic_fields_set__", set())
if not hasattr(new_object, "__pydantic_extra__"):
object.__setattr__(new_object, "__pydantic_extra__", None)
if not hasattr(new_object, "__pydantic_private__"):
object.__setattr__(new_object, "__pydantic_private__", None)
return new_object

def __init__(__pydantic_self__, **data: Any) -> None:
old_dict = __pydantic_self__.__dict__.copy()
super().__init__(**data)
__pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__
__pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__}
non_pydantic_keys = data.keys() - __pydantic_self__.model_fields
for key in non_pydantic_keys:
if key in __pydantic_self__.__sqlmodel_relationships__:
@@ -558,33 +638,54 @@ def __tablename__(cls) -> str:

@classmethod
def model_validate(
cls: type[_TSQLModel],
cls: Type[_TSQLModel],
obj: Any,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
strict: Optional[bool] = None,
from_attributes: Optional[bool] = None,
context: Optional[Dict[str, Any]] = None,
) -> _TSQLModel:
# Somehow model validate doesn't call __init__ so it would remove our init logic
validated = super().model_validate(
obj, strict=strict, from_attributes=from_attributes, context=context
)
return cls(**{key: value for key, value in validated})

# remove defaults so they don't get validated
data = {}
for key, value in validated:
field = cls.model_fields.get(key)

if field is None:
continue

if (
hasattr(field, "default")
and field.default is not PydanticUndefined
and value == field.default
):
continue

data[key] = value

return cls(**data)
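Because the overridden `model_validate` routes the validated data back through `__init__`, values that merely equal their field defaults are dropped first; otherwise defaults would be re-validated and would show up as explicitly set fields. A sketch of the observable effect as I read this hunk (model name is illustrative):

    from typing import Optional

    from sqlmodel import SQLModel


    class Hero(SQLModel):
        name: str
        age: Optional[int] = None


    hero = Hero.model_validate({"name": "Deadpond"})
    print(hero.model_fields_set)  # expected: {'name'}; age stays an unset default
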


def _is_field_noneable(field: FieldInfo) -> bool:
if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined:
if hasattr(field, "nullable") and not isinstance(
field.nullable, PydanticUndefinedType
):
return field.nullable
if not field.is_required():
default = getattr(field, "original_default", field.default)
if default is PydanticUndefined:
return False
if field.annotation is None or field.annotation is NoneType:
return True
if get_origin(field.annotation) is Union:
if _is_optional_or_union(field.annotation):
for base in get_args(field.annotation):
if base is NoneType:
return True

return False
return False

Expand All @@ -593,4 +694,6 @@ def _get_field_metadata(field: FieldInfo) -> object:
for meta in field.metadata:
if isinstance(meta, PydanticGeneralMetadata):
return meta
if isinstance(meta, MaxLen):
return meta
return object()
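With `MaxLen` now recognized by `_get_field_metadata`, a plain `max_length=` on a string field should flow through `get_sqlalchemy_type` into the generated column's length. A minimal sketch, not part of the diff (model name is illustrative, and the printed form depends on the dialect):

    from typing import Optional

    from sqlmodel import Field, SQLModel


    class Tag(SQLModel, table=True):
        id: Optional[int] = Field(default=None, primary_key=True)
        # max_length is stored as annotated_types.MaxLen metadata in
        # Pydantic v2 and becomes the AutoString/VARCHAR length.
        name: str = Field(max_length=64)


    print(Tag.__table__.c.name.type)  # expected: an AutoString with length=64
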