Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy linting #1

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/actions/comment-docs-preview-in-pr/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ class PartialGithubEvent(BaseModel):
use_pr = pr
break
if not use_pr:
logging.error(
f"No PR found for hash: {event.workflow_run.head_commit.id}"
)
logging.error(f"No PR found for hash: {event.workflow_run.head_commit.id}")
sys.exit(0)
github_headers = {
"Authorization": f"token {settings.input_token.get_secret_value()}"
Expand Down
2 changes: 1 addition & 1 deletion sqlmodel/engine/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,4 @@ def create_engine(
if not isinstance(query_cache_size, _DefaultPlaceholder):
current_kwargs["query_cache_size"] = query_cache_size
current_kwargs.update(kwargs)
return _create_engine(url, **current_kwargs)
return _create_engine(url, **current_kwargs) # type: ignore
26 changes: 13 additions & 13 deletions sqlmodel/engine/result.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
from typing import Generic, Iterator, List, Optional, Sequence, TypeVar
from typing import Any, Generic, Iterator, List, Optional, TypeVar

from sqlalchemy.engine.result import Result as _Result
from sqlalchemy.engine.result import ScalarResult as _ScalarResult

_T = TypeVar("_T")


class ScalarResult(_ScalarResult[_T], Generic[_T]):
def all(self) -> Sequence[_T]:
class ScalarResult(_ScalarResult, Generic[_T]):
def all(self) -> List[_T]:
return super().all()

def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_T]]:
def partitions(self, size: Optional[int] = None) -> Iterator[List[Any]]:
return super().partitions(size)

def fetchall(self) -> Sequence[_T]:
def fetchall(self) -> List[_T]:
return super().fetchall()

def fetchmany(self, size: Optional[int] = None) -> Sequence[_T]:
def fetchmany(self, size: Optional[int] = None) -> List[_T]:
return super().fetchmany(size)

def __iter__(self) -> Iterator[_T]:
return super().__iter__()

def __next__(self) -> _T:
def __next__(self) -> Any:
return super().__next__()

def first(self) -> Optional[_T]:
Expand All @@ -31,11 +31,11 @@ def first(self) -> Optional[_T]:
def one_or_none(self) -> Optional[_T]:
return super().one_or_none()

def one(self) -> _T:
def one(self) -> Any:
return super().one()


class Result(_Result[_T], Generic[_T]):
class Result(_Result, Generic[_T]):
def scalars(self, index: int = 0) -> ScalarResult[_T]:
return super().scalars(index) # type: ignore

Expand All @@ -45,8 +45,8 @@ def __iter__(self) -> Iterator[_T]: # type: ignore
def __next__(self) -> _T: # type: ignore
return super().__next__() # type: ignore

def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore
return super().partitions(size) # type: ignore
def partitions(self, size: Optional[int] = None) -> Iterator[List[Any]]:
return super().partitions(size)

def fetchall(self) -> List[_T]: # type: ignore
return super().fetchall() # type: ignore
Expand All @@ -55,7 +55,7 @@ def fetchone(self) -> Optional[_T]: # type: ignore
return super().fetchone() # type: ignore

def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore
return super().fetchmany() # type: ignore
return super().fetchmany(size) # type: ignore

def all(self) -> List[_T]: # type: ignore
return super().all() # type: ignore
Expand All @@ -76,4 +76,4 @@ def one(self) -> _T: # type: ignore
return super().one() # type: ignore

def scalar(self) -> Optional[_T]:
return super().scalar() # type: ignore
return super().scalar()
4 changes: 2 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass, validate_model
from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations
from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations # type: ignore
from pydantic.utils import ROOT_KEY, Representation
from sqlalchemy import Boolean, Column, Date, DateTime
from sqlalchemy import Enum as sa_Enum
Expand Down Expand Up @@ -522,7 +522,7 @@ def __setattr__(self, name: str, value: Any) -> None:
return
else:
# Set in SQLAlchemy, before Pydantic to trigger events and updates
if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore
if getattr(self.__config__, "table", False) and is_instrumented(self, name):
set_attribute(self, name, value)
# Set in Pydantic model to trigger possible validation changes, only for
# non relationship values
Expand Down
2 changes: 1 addition & 1 deletion sqlmodel/orm/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]":
Or otherwise you might want to use `session.execute()` instead of
`session.query()`.
"""
return super().query(*entities, **kwargs) # type: ignore
return super().query(*entities, **kwargs)

def get(
self,
Expand Down
6 changes: 4 additions & 2 deletions sqlmodel/sql/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@

_TSelect = TypeVar("_TSelect")

class Select(_Select[_TSelect], Generic[_TSelect]):

class Select(_Select, Generic[_TSelect]):
inherit_cache = True


# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
# purpose. This is the same as a normal SQLAlchemy Select class where there's only one
# entity, so the result will be converted to a scalar by default. This way writing
# for loops on the results will feel natural.
class SelectOfScalar(_Select[_TSelect], Generic[_TSelect]):
class SelectOfScalar(_Select, Generic[_TSelect]):
inherit_cache = True


Expand Down
29 changes: 16 additions & 13 deletions sqlmodel/sql/sqltypes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid
from typing import Any, Optional, cast
from typing import Any, Optional, Union, cast

from sqlalchemy import CHAR, types
from sqlalchemy.dialects.postgresql import UUID
Expand All @@ -8,15 +8,14 @@


class AutoString(types.TypeDecorator): # type: ignore

impl = types.String
cache_ok = True
mysql_default_length = 255

def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
impl = cast(types.String, self.impl)
if impl.length is None and dialect.name == "mysql":
return dialect.type_descriptor(types.String(self.mysql_default_length))
return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore[arg-type, no-any-return]
return super().load_dialect_impl(dialect)


Expand All @@ -33,13 +32,15 @@ class GUID(types.TypeDecorator): # type: ignore
impl = CHAR
cache_ok = True

def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore
def load_dialect_impl(self, dialect: Dialect) -> "TypeEngine[Any]":
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID())
return dialect.type_descriptor(UUID()) # type: ignore[arg-type, no-any-return]
else:
return dialect.type_descriptor(CHAR(32))
return dialect.type_descriptor(CHAR(32)) # type: ignore[arg-type, no-any-return]

def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
def process_bind_param(
self, value: Optional[Union[str, uuid.UUID]], dialect: Dialect
) -> Optional[str]:
if value is None:
return value
elif dialect.name == "postgresql":
Expand All @@ -51,10 +52,12 @@ def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
# hexstring
return value.hex

def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UUID]:
def process_result_value(
self, value: Optional[Union[str, uuid.UUID]], dialect: Dialect
) -> Optional[uuid.UUID]:
if value is None:
return value
else:
if not isinstance(value, uuid.UUID):
value = uuid.UUID(value)
return cast(uuid.UUID, value)

if not isinstance(value, uuid.UUID):
value = uuid.UUID(value)
return value