Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Version history
===============

**UNRELEASED**

- Type annotations for ARRAY column attributes now include the Python type of
the array elements

**3.0.0**

- Dropped support for Python 3.8
Expand Down
66 changes: 35 additions & 31 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from sqlalchemy.exc import CompileError
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.type_api import UserDefinedType
from sqlalchemy.types import TypeEngine

from .models import (
ColumnAttribute,
Expand All @@ -63,11 +64,6 @@
uses_default_name,
)

if sys.version_info < (3, 10):
pass
else:
pass

_re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)")
_re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2')
_re_enum_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \((.+)\)")
Expand Down Expand Up @@ -1201,22 +1197,40 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
column = column_attr.column
rendered_column = self.render_column(column, column_attr.name != column.name)

try:
python_type = column.type.python_type
def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
column_type = column.type
pre: list[str] = []
post_size = 0
if column.nullable:
self.add_literal_import("typing", "Optional")
pre.append("Optional[")
post_size += 1

if isinstance(column_type, ARRAY):
dim = getattr(column_type, "dimensions", None) or 1
pre.extend("list[" for _ in range(dim))
post_size += dim

column_type = column_type.item_type

return "".join(pre), column_type, "]" * post_size

def render_python_type(column_type: TypeEngine[Any]) -> str:
python_type = column_type.python_type
python_type_name = python_type.__name__
if python_type.__module__ == "builtins":
column_python_type = python_type_name
else:
python_type_module = python_type.__module__
column_python_type = f"{python_type_module}.{python_type_name}"
python_type_module = python_type.__module__
if python_type_module == "builtins":
return python_type_name

try:
self.add_module_import(python_type_module)
except NotImplementedError:
self.add_literal_import("typing", "Any")
column_python_type = "Any"
return f"{python_type_module}.{python_type_name}"
except NotImplementedError:
self.add_literal_import("typing", "Any")
return "Any"

if column.nullable:
self.add_literal_import("typing", "Optional")
column_python_type = f"Optional[{column_python_type}]"
pre, col_type, post = get_type_qualifiers()
column_python_type = f"{pre}{render_python_type(col_type)}{post}"
return f"{column_attr.name}: Mapped[{column_python_type}] = {rendered_column}"

def render_relationship(self, relationship: RelationshipAttribute) -> str:
Expand Down Expand Up @@ -1297,8 +1311,7 @@ def render_join(terms: list[JoinType]) -> str:

relationship_type: str
if relationship.type == RelationshipType.ONE_TO_MANY:
self.add_literal_import("typing", "List")
relationship_type = f"List['{relationship.target.name}']"
relationship_type = f"list['{relationship.target.name}']"
elif relationship.type in (
RelationshipType.ONE_TO_ONE,
RelationshipType.MANY_TO_ONE,
Expand All @@ -1310,8 +1323,7 @@ def render_join(terms: list[JoinType]) -> str:
self.add_literal_import("typing", "Optional")
relationship_type = f"Optional[{relationship_type}]"
elif relationship.type == RelationshipType.MANY_TO_MANY:
self.add_literal_import("typing", "List")
relationship_type = f"List['{relationship.target.name}']"
relationship_type = f"list['{relationship.target.name}']"
else:
self.add_literal_import("typing", "Any")
relationship_type = "Any"
Expand Down Expand Up @@ -1409,13 +1421,6 @@ def collect_imports_for_model(self, model: Model) -> None:
if model.relationships:
self.add_literal_import("sqlmodel", "Relationship")

for relationship_attr in model.relationships:
if relationship_attr.type in (
RelationshipType.ONE_TO_MANY,
RelationshipType.MANY_TO_MANY,
):
self.add_literal_import("typing", "List")

def collect_imports_for_column(self, column: Column[Any]) -> None:
super().collect_imports_for_column(column)
try:
Expand Down Expand Up @@ -1487,8 +1492,7 @@ def render_relationship(self, relationship: RelationshipAttribute) -> str:
RelationshipType.ONE_TO_MANY,
RelationshipType.MANY_TO_MANY,
):
self.add_literal_import("typing", "List")
annotation = f"List[{annotation}]"
annotation = f"list[{annotation}]"
else:
self.add_literal_import("typing", "Optional")
annotation = f"Optional[{annotation}]"
Expand Down
14 changes: 6 additions & 8 deletions tests/test_generator_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_onetomany_optional(generator: CodeGenerator) -> None:
validate_code(
generator.generate(),
"""\
from typing import List, Optional
from typing import Optional

from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
Expand All @@ -116,7 +116,7 @@ class SimpleContainers(Base):

id: Mapped[int] = mapped_column(Integer, primary_key=True)

simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \
simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
back_populates='container')


Expand Down Expand Up @@ -152,8 +152,6 @@ def test_manytomany(generator: CodeGenerator) -> None:
validate_code(
generator.generate(),
"""\
from typing import List

from sqlalchemy import Column, ForeignKey, Integer, Table
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
mapped_column, relationship
Expand All @@ -167,7 +165,7 @@ class SimpleContainers(Base):

id: Mapped[int] = mapped_column(Integer, primary_key=True)

item: Mapped[List['SimpleItems']] = relationship('SimpleItems', \
item: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
secondary='container_items', back_populates='container')


Expand All @@ -176,7 +174,7 @@ class SimpleItems(Base):

id: Mapped[int] = mapped_column(Integer, primary_key=True)

container: Mapped[List['SimpleContainers']] = \
container: Mapped[list['SimpleContainers']] = \
relationship('SimpleContainers', secondary='container_items', back_populates='item')


Expand Down Expand Up @@ -208,7 +206,7 @@ def test_named_foreign_key_constraints(generator: CodeGenerator) -> None:
validate_code(
generator.generate(),
"""\
from typing import List, Optional
from typing import Optional

from sqlalchemy import ForeignKeyConstraint, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \
Expand All @@ -223,7 +221,7 @@ class SimpleContainers(Base):

id: Mapped[int] = mapped_column(Integer, primary_key=True)

simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \
simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \
back_populates='container')


Expand Down
Loading