Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get_column_from_field support functional sa_column #2

Merged
merged 30 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f3e7811
get_column_from_field support function
honglei Aug 15, 2023
9e07c1c
fix type check for sa_column
honglei Aug 15, 2023
8cc628c
Merge branch 'mbsantiago:main' into main
honglei Aug 15, 2023
5b49f77
get_column_from_field:sa_column>field attribute>field annotation
honglei Aug 16, 2023
650534e
Merge branch 'main' of https://github.com/honglei/sqlmodel
honglei Aug 16, 2023
e919dc8
Merge branch 'mbsantiago:main' into main
honglei Aug 16, 2023
1c626cb
Merge branch 'main' of https://github.com/honglei/sqlmodel
honglei Aug 16, 2023
6a5f373
Revert "get_column_from_field:sa_column>field attribute>field annotat…
honglei Aug 16, 2023
72dc89d
field is required by default, while nullable=True for Column
honglei Aug 16, 2023
fa8955c
field required
honglei Aug 16, 2023
045f9bc
add test for pydantic.AnyURL
honglei Aug 16, 2023
6e89ad3
black/isort for test_nullable.py
honglei Aug 16, 2023
6b7925d
fix _is_field_noneable
honglei Aug 20, 2023
4e89361
add test for Class hierarchy
honglei Aug 20, 2023
7752780
fix isort
honglei Aug 20, 2023
499bc18
add reason for skipif
honglei Aug 20, 2023
8e2b363
annotation not null
honglei Aug 20, 2023
f5fd850
fix model_copy
50Bytes-dev Aug 24, 2023
45ab472
fix
50Bytes-dev Aug 24, 2023
006ec7d
Merge pull request #1 from 50Bytes-dev/main
honglei Aug 25, 2023
f266da7
black/isort test_model_copy.py
honglei Aug 25, 2023
ef56d08
remove unused import in test_model_copy.py
honglei Aug 25, 2023
e0d32fb
try fix py3.8/test_nullable.py
honglei Aug 25, 2023
aa3325b
ugly way to fix py3.8/Annotation
honglei Aug 25, 2023
e824730
miss import _AnnotatedAlias
honglei Aug 25, 2023
dcb406f
fix py3.9+ _AnnotatedAlias
honglei Aug 25, 2023
d13fb74
only use typing_extensions to import _AnnotatedAlias
honglei Aug 25, 2023
c02b579
support AnyURL
honglei Aug 25, 2023
cb6ccf4
forgot black it
honglei Aug 25, 2023
4213c97
support AnyURL
honglei Aug 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 37 additions & 17 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@
else:
from typing_extensions import get_args, get_origin

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated
from typing_extensions import Annotated, _AnnotatedAlias

_T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any]
Expand Down Expand Up @@ -167,7 +164,7 @@ def Field(
unique: bool = False,
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore
sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
sa_column_kwargs: Union[
Mapping[str, Any], PydanticUndefinedType
Expand Down Expand Up @@ -440,17 +437,19 @@ def __init__(
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)


def _is_optional_or_union(type_: Optional[type]) -> bool:
if sys.version_info >= (3, 10):
return get_origin(type_) in (types.UnionType, Union)
else:
return get_origin(type_) is Union


def get_sqlalchemy_type(field: FieldInfo) -> Any:
type_: Optional[type] = field.annotation
type_: Optional[type] | _AnnotatedAlias = field.annotation

# Resolve Optional/Union fields
def is_optional_or_union(type_: Optional[type]) -> bool:
if sys.version_info >= (3, 10):
return get_origin(type_) in (types.UnionType, Union)
else:
return get_origin(type_) is Union

if type_ is not None and is_optional_or_union(type_):
if type_ is not None and _is_optional_or_union(type_):
bases = get_args(type_)
if len(bases) > 2:
raise RuntimeError(
Expand All @@ -462,14 +461,20 @@ def is_optional_or_union(type_: Optional[type]) -> bool:
# UrlConstraints(max_length=512,
# allowed_schemes=['smb', 'ftp', 'file']) ]
if type_ is pydantic.AnyUrl:
meta = field.metadata[0]
return AutoString(length=meta.max_length)
if field.metadata:
meta = field.metadata[0]
return AutoString(length=meta.max_length)
else:
return AutoString

if get_origin(type_) is Annotated:
org_type = get_origin(type_)
if org_type is Annotated:
type2 = get_args(type_)[0]
if type2 is pydantic.AnyUrl:
meta = get_args(type_)[1]
return AutoString(length=meta.max_length)
elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias:
return AutoString(type_.__metadata__[0].max_length)

# The 3rd is PydanticGeneralMetadata
metadata = _get_field_metadata(field)
Expand Down Expand Up @@ -519,11 +524,18 @@ def is_optional_or_union(type_: Optional[type]) -> bool:


def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
"""
sa_column > field attributes > annotation info
"""
sa_column = getattr(field, "sa_column", PydanticUndefined)
if isinstance(sa_column, Column):
return sa_column
if isinstance(sa_column, MappedColumn):
return sa_column.column
if isinstance(sa_column, types.FunctionType):
col = sa_column()
assert isinstance(col, Column)
return col
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field, "primary_key", False)
index = getattr(field, "index", PydanticUndefined)
Expand Down Expand Up @@ -587,6 +599,10 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
# in the Pydantic model so that when SQLAlchemy sets attributes that are
# added (e.g. when querying from DB) to the __fields_set__, this already exists
object.__setattr__(new_object, "__pydantic_fields_set__", set())
if not hasattr(new_object, "__pydantic_extra__"):
object.__setattr__(new_object, "__pydantic_extra__", None)
if not hasattr(new_object, "__pydantic_private__"):
object.__setattr__(new_object, "__pydantic_private__", None)
return new_object

def __init__(__pydantic_self__, **data: Any) -> None:
Expand Down Expand Up @@ -636,7 +652,10 @@ def model_validate(
# remove defaults so they don't get validated
data = {}
for key, value in validated:
field = cls.model_fields[key]
field = cls.model_fields.get(key)

if field is None:
continue

if (
hasattr(field, "default")
Expand All @@ -661,10 +680,11 @@ def _is_field_noneable(field: FieldInfo) -> bool:
return False
if field.annotation is None or field.annotation is NoneType:
return True
if get_origin(field.annotation) is Union:
if _is_optional_or_union(field.annotation):
for base in get_args(field.annotation):
if base is NoneType:
return True

return False
return False

Expand Down
1 change: 0 additions & 1 deletion sqlmodel/sql/sqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


class AutoString(types.TypeDecorator): # type: ignore

impl = types.String
cache_ok = True
mysql_default_length = 255
Expand Down
78 changes: 78 additions & 0 deletions tests/test_class_hierarchy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import datetime
import sys

import pytest
from pydantic import AnyUrl, UrlConstraints
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
Integer,
SQLModel,
String,
create_engine,
)
from typing_extensions import Annotated

MoveSharedUrl = Annotated[
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
]


@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
def test_field_resuse():
class BasicFileLog(SQLModel):
resourceID: int = Field(
sa_column=lambda: Column(Integer, index=True), description=""" """
)
transportID: Annotated[int | None, Field(description=" for ")] = None
fileName: str = Field(
sa_column=lambda: Column(String, index=True), description=""" """
)
fileSize: int | None = Field(
sa_column=lambda: Column(BigInteger), ge=0, description=""" """
)
beginTime: datetime.datetime | None = Field(
sa_column=lambda: Column(
DateTime(timezone=True),
index=True,
),
description="",
)

class SendFileLog(BasicFileLog, table=True):
id: int | None = Field(
sa_column=Column(Integer, primary_key=True, autoincrement=True),
description=""" """,
)
sendUser: str
dstUrl: MoveSharedUrl | None

class RecvFileLog(BasicFileLog, table=True):
id: int | None = Field(
sa_column=Column(Integer, primary_key=True, autoincrement=True),
description=""" """,
)
recvUser: str

sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"

engine = create_engine(sqlite_url, echo=True)
SQLModel.metadata.drop_all(engine)
SQLModel.metadata.create_all(engine)
SendFileLog(
sendUser="j",
resourceID=1,
fileName="a.txt",
fileSize=3234,
beginTime=datetime.datetime.now(),
)
RecvFileLog(
sendUser="j",
resourceID=1,
fileName="a.txt",
fileSize=3234,
beginTime=datetime.datetime.now(),
)
50 changes: 50 additions & 0 deletions tests/test_model_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Optional

from sqlmodel import Field, Session, SQLModel, create_engine


def test_model_copy(clear_sqlmodel):
"""Test validation of implicit and explict None values.

# For consistency with pydantic, validators are not to be called on
# arguments that are not explicitly provided.

https://github.com/tiangolo/sqlmodel/issues/230
https://github.com/samuelcolvin/pydantic/issues/1223

"""

class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None

hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25)

engine = create_engine("sqlite://")

SQLModel.metadata.create_all(engine)

with Session(engine) as session:
session.add(hero)
session.commit()
session.refresh(hero)

model_copy = hero.model_copy(update={"name": "Deadpond Copy"})

assert (
model_copy.name == "Deadpond Copy"
and model_copy.secret_name == "Dive Wilson"
and model_copy.age == 25
)

db_hero = session.get(Hero, hero.id)

db_copy = db_hero.model_copy(update={"name": "Deadpond Copy"})

assert (
db_copy.name == "Deadpond Copy"
and db_copy.secret_name == "Dive Wilson"
and db_copy.age == 25
)
22 changes: 22 additions & 0 deletions tests/test_nullable.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from typing import Optional

import pytest
from pydantic import AnyUrl, UrlConstraints
from sqlalchemy.exc import IntegrityError
from sqlmodel import Field, Session, SQLModel, create_engine
from typing_extensions import Annotated

MoveSharedUrl = Annotated[
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
]


def test_nullable_fields(clear_sqlmodel, caplog):
Expand All @@ -13,6 +19,8 @@ class Hero(SQLModel, table=True):
)
required_value: str
optional_default_ellipsis: Optional[str] = Field(default=...)
optional_no_field: Optional[str]
optional_no_field_default: Optional[str] = Field(description="no default")
optional_default_none: Optional[str] = Field(default=None)
optional_non_nullable: Optional[str] = Field(
nullable=False,
Expand Down Expand Up @@ -49,6 +57,13 @@ class Hero(SQLModel, table=True):
str_default_str_nullable: str = Field(default="default", nullable=True)
str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False)
str_default_ellipsis_nullable: str = Field(default=..., nullable=True)
base_url: AnyUrl
optional_url: Optional[MoveSharedUrl] = Field(default=None, description="")
url: MoveSharedUrl
annotated_url: Annotated[MoveSharedUrl, Field(description="")]
annotated_optional_url: Annotated[
Optional[MoveSharedUrl], Field(description="")
] = None

engine = create_engine("sqlite://", echo=True)
SQLModel.metadata.create_all(engine)
Expand All @@ -59,6 +74,8 @@ class Hero(SQLModel, table=True):
assert "primary_key INTEGER NOT NULL," in create_table_log
assert "required_value VARCHAR NOT NULL," in create_table_log
assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log
assert "optional_no_field VARCHAR," in create_table_log
assert "optional_no_field_default VARCHAR NOT NULL," in create_table_log
assert "optional_default_none VARCHAR," in create_table_log
assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log
assert "optional_nullable VARCHAR," in create_table_log
Expand All @@ -77,6 +94,11 @@ class Hero(SQLModel, table=True):
assert "str_default_str_nullable VARCHAR," in create_table_log
assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log
assert "str_default_ellipsis_nullable VARCHAR," in create_table_log
assert "base_url VARCHAR NOT NULL," in create_table_log
assert "optional_url VARCHAR(512), " in create_table_log
assert "url VARCHAR(512) NOT NULL," in create_table_log
assert "annotated_url VARCHAR(512) NOT NULL," in create_table_log
assert "annotated_optional_url VARCHAR(512)," in create_table_log


# Test for regression in https://github.com/tiangolo/sqlmodel/issues/420
Expand Down
Loading