diff --git a/.github/actions/comment-docs-preview-in-pr/app/main.py b/.github/actions/comment-docs-preview-in-pr/app/main.py index 3b10e0ee08..c9fb7cbbef 100644 --- a/.github/actions/comment-docs-preview-in-pr/app/main.py +++ b/.github/actions/comment-docs-preview-in-pr/app/main.py @@ -48,9 +48,7 @@ class PartialGithubEvent(BaseModel): use_pr = pr break if not use_pr: - logging.error( - f"No PR found for hash: {event.workflow_run.head_commit.id}" - ) + logging.error(f"No PR found for hash: {event.workflow_run.head_commit.id}") sys.exit(0) github_headers = { "Authorization": f"token {settings.input_token.get_secret_value()}" diff --git a/sqlmodel/engine/create.py b/sqlmodel/engine/create.py index 97481259e2..b2d567b1b1 100644 --- a/sqlmodel/engine/create.py +++ b/sqlmodel/engine/create.py @@ -136,4 +136,4 @@ def create_engine( if not isinstance(query_cache_size, _DefaultPlaceholder): current_kwargs["query_cache_size"] = query_cache_size current_kwargs.update(kwargs) - return _create_engine(url, **current_kwargs) + return _create_engine(url, **current_kwargs) # type: ignore diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index 17020d9995..1d6b40ea13 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -1,4 +1,4 @@ -from typing import Generic, Iterator, List, Optional, Sequence, TypeVar +from typing import Any, Generic, Iterator, List, Optional, TypeVar from sqlalchemy.engine.result import Result as _Result from sqlalchemy.engine.result import ScalarResult as _ScalarResult @@ -6,23 +6,23 @@ _T = TypeVar("_T") -class ScalarResult(_ScalarResult[_T], Generic[_T]): - def all(self) -> Sequence[_T]: +class ScalarResult(_ScalarResult, Generic[_T]): + def all(self) -> List[_T]: return super().all() - def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_T]]: + def partitions(self, size: Optional[int] = None) -> Iterator[List[Any]]: return super().partitions(size) - def fetchall(self) -> Sequence[_T]: + def fetchall(self) -> List[_T]: return super().fetchall() - def fetchmany(self, size: Optional[int] = None) -> Sequence[_T]: + def fetchmany(self, size: Optional[int] = None) -> List[_T]: return super().fetchmany(size) def __iter__(self) -> Iterator[_T]: return super().__iter__() - def __next__(self) -> _T: + def __next__(self) -> Any: return super().__next__() def first(self) -> Optional[_T]: @@ -31,11 +31,11 @@ def first(self) -> Optional[_T]: def one_or_none(self) -> Optional[_T]: return super().one_or_none() - def one(self) -> _T: + def one(self) -> Any: return super().one() -class Result(_Result[_T], Generic[_T]): +class Result(_Result, Generic[_T]): def scalars(self, index: int = 0) -> ScalarResult[_T]: return super().scalars(index) # type: ignore @@ -45,8 +45,8 @@ def __iter__(self) -> Iterator[_T]: # type: ignore def __next__(self) -> _T: # type: ignore return super().__next__() # type: ignore - def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore - return super().partitions(size) # type: ignore + def partitions(self, size: Optional[int] = None) -> Iterator[List[Any]]: + return super().partitions(size) def fetchall(self) -> List[_T]: # type: ignore return super().fetchall() # type: ignore @@ -55,7 +55,7 @@ def fetchone(self) -> Optional[_T]: # type: ignore return super().fetchone() # type: ignore def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore - return super().fetchmany() # type: ignore + return super().fetchmany(size) # type: ignore def all(self) -> List[_T]: # type: ignore return super().all() # type: ignore @@ -76,4 +76,4 @@ def one(self) -> _T: # type: ignore return super().one() # type: ignore def scalar(self) -> Optional[_T]: - return super().scalar() # type: ignore + return super().scalar() diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 658e5384d8..55293bf30f 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -29,7 +29,7 @@ from pydantic.fields import FieldInfo as PydanticFieldInfo from pydantic.fields import ModelField, Undefined, UndefinedType from pydantic.main import ModelMetaclass, validate_model -from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations +from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations # type: ignore from pydantic.utils import ROOT_KEY, Representation from sqlalchemy import Boolean, Column, Date, DateTime from sqlalchemy import Enum as sa_Enum @@ -522,7 +522,7 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore + if getattr(self.__config__, "table", False) and is_instrumented(self, name): set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 64f6ad7967..1692fdcbcb 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -118,7 +118,7 @@ def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]": Or otherwise you might want to use `session.execute()` instead of `session.query()`. """ - return super().query(*entities, **kwargs) # type: ignore + return super().query(*entities, **kwargs) def get( self, diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index dd31897b1d..f473ba4a5a 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -22,14 +22,16 @@ _TSelect = TypeVar("_TSelect") -class Select(_Select[_TSelect], Generic[_TSelect]): + +class Select(_Select, Generic[_TSelect]): inherit_cache = True + # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different # purpose. This is the same as a normal SQLAlchemy Select class where there's only one # entity, so the result will be converted to a scalar by default. This way writing # for loops on the results will feel natural. -class SelectOfScalar(_Select[_TSelect], Generic[_TSelect]): +class SelectOfScalar(_Select, Generic[_TSelect]): inherit_cache = True diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index da6551b790..3a6d4e964d 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, Optional, cast +from typing import Any, Optional, Union, cast from sqlalchemy import CHAR, types from sqlalchemy.dialects.postgresql import UUID @@ -8,15 +8,14 @@ class AutoString(types.TypeDecorator): # type: ignore - impl = types.String cache_ok = True mysql_default_length = 255 - def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: impl = cast(types.String, self.impl) if impl.length is None and dialect.name == "mysql": - return dialect.type_descriptor(types.String(self.mysql_default_length)) + return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore[arg-type, no-any-return] return super().load_dialect_impl(dialect) @@ -33,13 +32,15 @@ class GUID(types.TypeDecorator): # type: ignore impl = CHAR cache_ok = True - def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore + def load_dialect_impl(self, dialect: Dialect) -> "TypeEngine[Any]": if dialect.name == "postgresql": - return dialect.type_descriptor(UUID()) + return dialect.type_descriptor(UUID()) # type: ignore[arg-type, no-any-return] else: - return dialect.type_descriptor(CHAR(32)) + return dialect.type_descriptor(CHAR(32)) # type: ignore[arg-type, no-any-return] - def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: + def process_bind_param( + self, value: Optional[Union[str, uuid.UUID]], dialect: Dialect + ) -> Optional[str]: if value is None: return value elif dialect.name == "postgresql": @@ -51,10 +52,12 @@ def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: # hexstring return value.hex - def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UUID]: + def process_result_value( + self, value: Optional[Union[str, uuid.UUID]], dialect: Dialect + ) -> Optional[uuid.UUID]: if value is None: return value - else: - if not isinstance(value, uuid.UUID): - value = uuid.UUID(value) - return cast(uuid.UUID, value) + + if not isinstance(value, uuid.UUID): + value = uuid.UUID(value) + return value