diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index 17020d9995..2e68184660 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -69,8 +69,8 @@ def one_or_none(self) -> Optional[_T]: # type: ignore def scalar_one(self) -> _T: return super().scalar_one() # type: ignore - def scalar_one_or_none(self) -> Optional[_T]: - return super().scalar_one_or_none() + def scalar_one_or_none(self) -> Optional[_T]: # type: ignore + return super().scalar_one_or_none() # type: ignore def one(self) -> _T: # type: ignore return super().one() # type: ignore diff --git a/sqlmodel/main.py b/sqlmodel/main.py index a9a8620dfd..79a233c60d 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -14,6 +14,7 @@ ForwardRef, List, Mapping, + NoneType, Optional, Sequence, Set, @@ -194,7 +195,7 @@ def Relationship( *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty] = None, + sa_relationship: Optional[RelationshipProperty[Any]] = None, sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> Any: @@ -211,7 +212,7 @@ def Relationship( @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] - model_config: Type[SQLModelConfig] + model_config: SQLModelConfig model_fields: Dict[str, FieldInfo] # Replicate SQLAlchemy @@ -280,7 +281,9 @@ def __new__( if dict_used.get(key, PydanticUndefined) is PydanticUndefined: dict_used[key] = None - new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + new_cls: Type["SQLModelMetaclass"] = super().__new__( + cls, name, bases, dict_used, **config_kwargs + ) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, @@ -371,7 +374,7 @@ def __init__( rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty = relationship( + rel_value: RelationshipProperty[Any] = relationship( relationship_to, *rel_args, **rel_kwargs ) dict_used[rel_name] = rel_value @@ -382,7 +385,8 @@ def __init__( def get_sqlalchemy_type(field: FieldInfo) -> Any: - type_ = field.annotation + type_: type | None = field.annotation + # Resolve Optional fields if type_ is not None and get_origin(type_) is Union: bases = get_args(type_) @@ -394,9 +398,12 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any: # The 3rd is PydanticGeneralMetadata metadata = _get_field_metadata(field) + if type_ is None: + raise ValueError("Missing field type") if issubclass(type_, str): - if getattr(metadata, "max_length", None): - return AutoString(length=metadata.max_length) + max_length = getattr(metadata, "max_length", None) + if max_length: + return AutoString(length=max_length) return AutoString if issubclass(type_, float): return Float @@ -463,7 +470,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore "index": index, "unique": unique, } - sa_default = PydanticUndefined + sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined if field.default_factory: sa_default = field.default_factory elif field.default is not PydanticUndefined: @@ -483,14 +490,12 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore default_registry = registry() -_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") - class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) __tablename__: ClassVar[Union[str, Callable[..., str]]] - __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]] __name__: ClassVar[str] metadata: ClassVar[MetaData] __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six @@ -511,7 +516,7 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if self.model_config.get("table", False) and is_instrumented(self, name): + if self.model_config.get("table", False) and is_instrumented(self, name): # type: ignore set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values @@ -529,11 +534,11 @@ def __tablename__(cls) -> str: def _is_field_noneable(field: FieldInfo) -> bool: if not field.is_required(): - if field.annotation is None or field.annotation is type(None): + if field.annotation is None or field.annotation is NoneType: return True if get_origin(field.annotation) is Union: for base in get_args(field.annotation): - if base is type(None): + if base is NoneType: return True return False return False diff --git a/sqlmodel/typing.py b/sqlmodel/typing.py index f2c87503c0..8151f99692 100644 --- a/sqlmodel/typing.py +++ b/sqlmodel/typing.py @@ -3,7 +3,7 @@ from pydantic import ConfigDict -class SQLModelConfig(ConfigDict): +class SQLModelConfig(ConfigDict, total=False): table: Optional[bool] read_from_attributes: Optional[bool] registry: Optional[Any]