diff --git a/CHANGES.rst b/CHANGES.rst index 2e9bbbef..91f9d108 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,11 @@ Version history =============== +**UNRELEASED** + +- Type annotations for ARRAY column attributes now include the Python type of + the array elements + **3.0.0** - Dropped support for Python 3.8 diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index 0fb67001..862c2338 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -43,6 +43,7 @@ from sqlalchemy.exc import CompileError from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.type_api import UserDefinedType +from sqlalchemy.types import TypeEngine from .models import ( ColumnAttribute, @@ -63,11 +64,6 @@ uses_default_name, ) -if sys.version_info < (3, 10): - pass -else: - pass - _re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)") _re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2') _re_enum_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \((.+)\)") @@ -1201,22 +1197,40 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str: column = column_attr.column rendered_column = self.render_column(column, column_attr.name != column.name) - try: - python_type = column.type.python_type + def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]: + column_type = column.type + pre: list[str] = [] + post_size = 0 + if column.nullable: + self.add_literal_import("typing", "Optional") + pre.append("Optional[") + post_size += 1 + + if isinstance(column_type, ARRAY): + dim = getattr(column_type, "dimensions", None) or 1 + pre.extend("list[" for _ in range(dim)) + post_size += dim + + column_type = column_type.item_type + + return "".join(pre), column_type, "]" * post_size + + def render_python_type(column_type: TypeEngine[Any]) -> str: + python_type = column_type.python_type python_type_name = python_type.__name__ - if python_type.__module__ == "builtins": - column_python_type = python_type_name - else: - python_type_module = python_type.__module__ - column_python_type = f"{python_type_module}.{python_type_name}" + python_type_module = python_type.__module__ + if python_type_module == "builtins": + return python_type_name + + try: self.add_module_import(python_type_module) - except NotImplementedError: - self.add_literal_import("typing", "Any") - column_python_type = "Any" + return f"{python_type_module}.{python_type_name}" + except NotImplementedError: + self.add_literal_import("typing", "Any") + return "Any" - if column.nullable: - self.add_literal_import("typing", "Optional") - column_python_type = f"Optional[{column_python_type}]" + pre, col_type, post = get_type_qualifiers() + column_python_type = f"{pre}{render_python_type(col_type)}{post}" return f"{column_attr.name}: Mapped[{column_python_type}] = {rendered_column}" def render_relationship(self, relationship: RelationshipAttribute) -> str: @@ -1297,8 +1311,7 @@ def render_join(terms: list[JoinType]) -> str: relationship_type: str if relationship.type == RelationshipType.ONE_TO_MANY: - self.add_literal_import("typing", "List") - relationship_type = f"List['{relationship.target.name}']" + relationship_type = f"list['{relationship.target.name}']" elif relationship.type in ( RelationshipType.ONE_TO_ONE, RelationshipType.MANY_TO_ONE, @@ -1310,8 +1323,7 @@ def render_join(terms: list[JoinType]) -> str: self.add_literal_import("typing", "Optional") relationship_type = f"Optional[{relationship_type}]" elif relationship.type == RelationshipType.MANY_TO_MANY: - self.add_literal_import("typing", "List") - relationship_type = f"List['{relationship.target.name}']" + relationship_type = f"list['{relationship.target.name}']" else: self.add_literal_import("typing", "Any") relationship_type = "Any" @@ -1409,13 +1421,6 @@ def collect_imports_for_model(self, model: Model) -> None: if model.relationships: self.add_literal_import("sqlmodel", "Relationship") - for relationship_attr in model.relationships: - if relationship_attr.type in ( - RelationshipType.ONE_TO_MANY, - RelationshipType.MANY_TO_MANY, - ): - self.add_literal_import("typing", "List") - def collect_imports_for_column(self, column: Column[Any]) -> None: super().collect_imports_for_column(column) try: @@ -1487,8 +1492,7 @@ def render_relationship(self, relationship: RelationshipAttribute) -> str: RelationshipType.ONE_TO_MANY, RelationshipType.MANY_TO_MANY, ): - self.add_literal_import("typing", "List") - annotation = f"List[{annotation}]" + annotation = f"list[{annotation}]" else: self.add_literal_import("typing", "Optional") annotation = f"Optional[{annotation}]" diff --git a/tests/test_generator_dataclass.py b/tests/test_generator_dataclass.py index d3894c0e..1cf9f0e3 100644 --- a/tests/test_generator_dataclass.py +++ b/tests/test_generator_dataclass.py @@ -101,7 +101,7 @@ def test_onetomany_optional(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ - from typing import List, Optional + from typing import Optional from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \ @@ -116,7 +116,7 @@ class SimpleContainers(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ back_populates='container') @@ -152,8 +152,6 @@ def test_manytomany(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ - from typing import List - from sqlalchemy import Column, ForeignKey, Integer, Table from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \ mapped_column, relationship @@ -167,7 +165,7 @@ class SimpleContainers(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - item: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + item: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ secondary='container_items', back_populates='container') @@ -176,7 +174,7 @@ class SimpleItems(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - container: Mapped[List['SimpleContainers']] = \ + container: Mapped[list['SimpleContainers']] = \ relationship('SimpleContainers', secondary='container_items', back_populates='item') @@ -208,7 +206,7 @@ def test_named_foreign_key_constraints(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ - from typing import List, Optional + from typing import Optional from sqlalchemy import ForeignKeyConstraint, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, \ @@ -223,7 +221,7 @@ class SimpleContainers(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ back_populates='container') diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index 548ec71e..6bdee3d8 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -15,7 +15,7 @@ UniqueConstraint, ) from sqlalchemy.sql.expression import text -from sqlalchemy.types import INTEGER, VARCHAR, Text +from sqlalchemy.types import ARRAY, INTEGER, VARCHAR, Text from sqlacodegen.generators import CodeGenerator, DeclarativeGenerator @@ -123,7 +123,7 @@ def test_onetomany(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List, Optional +from typing import Optional from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -137,7 +137,7 @@ class SimpleContainers(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ back_populates='container') @@ -166,7 +166,7 @@ def test_onetomany_selfref(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List, Optional +from typing import Optional from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -184,7 +184,7 @@ class SimpleItems(Base): parent_item: Mapped[Optional['SimpleItems']] = relationship('SimpleItems', \ remote_side=[id], back_populates='parent_item_reverse') - parent_item_reverse: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + parent_item_reverse: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ remote_side=[parent_item_id], back_populates='parent_item') """, ) @@ -204,7 +204,7 @@ def test_onetomany_selfref_multi(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List, Optional +from typing import Optional from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -223,12 +223,12 @@ class SimpleItems(Base): parent_item: Mapped[Optional['SimpleItems']] = relationship('SimpleItems', \ remote_side=[id], foreign_keys=[parent_item_id], back_populates='parent_item_reverse') - parent_item_reverse: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + parent_item_reverse: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ remote_side=[parent_item_id], foreign_keys=[parent_item_id], \ back_populates='parent_item') top_item: Mapped[Optional['SimpleItems']] = relationship('SimpleItems', remote_side=[id], \ foreign_keys=[top_item_id], back_populates='top_item_reverse') - top_item_reverse: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + top_item_reverse: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ remote_side=[top_item_id], foreign_keys=[top_item_id], back_populates='top_item') """, ) @@ -258,7 +258,7 @@ def test_onetomany_composite(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List, Optional +from typing import Optional from sqlalchemy import ForeignKeyConstraint, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -273,7 +273,7 @@ class SimpleContainers(Base): id1: Mapped[int] = mapped_column(Integer, primary_key=True) id2: Mapped[int] = mapped_column(Integer, primary_key=True) - simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ back_populates='simple_containers') @@ -314,7 +314,7 @@ def test_onetomany_multiref(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List, Optional +from typing import Optional from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -328,9 +328,9 @@ class SimpleContainers(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ foreign_keys='[SimpleItems.parent_container_id]', back_populates='parent_container') - simple_items_: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + simple_items_: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ foreign_keys='[SimpleItems.top_container_id]', back_populates='top_container') @@ -413,7 +413,7 @@ def test_onetomany_noinflect(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List, Optional +from typing import Optional from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -427,7 +427,7 @@ class Fehwiuhfiw(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - oglkrogk: Mapped[List['Oglkrogk']] = relationship('Oglkrogk', \ + oglkrogk: Mapped[list['Oglkrogk']] = relationship('Oglkrogk', \ back_populates='fehwiuhfiw') @@ -461,7 +461,7 @@ def test_onetomany_conflicting_column(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List, Optional +from typing import Optional from sqlalchemy import ForeignKey, Integer, Text from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -476,7 +476,7 @@ class SimpleContainers(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) relationship_: Mapped[Optional[str]] = mapped_column('relationship', Text) - simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ back_populates='container') @@ -506,7 +506,7 @@ def test_onetomany_conflicting_relationship(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List, Optional +from typing import Optional from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -520,7 +520,7 @@ class Relationship(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ back_populates='relationship_') @@ -601,8 +601,6 @@ def test_manytomany(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List - from sqlalchemy import Column, ForeignKey, Integer, Table from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -615,7 +613,7 @@ class LeftTable(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - right: Mapped[List['RightTable']] = relationship('RightTable', \ + right: Mapped[list['RightTable']] = relationship('RightTable', \ secondary='association_table', back_populates='left') @@ -624,7 +622,7 @@ class RightTable(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - left: Mapped[List['LeftTable']] = relationship('LeftTable', \ + left: Mapped[list['LeftTable']] = relationship('LeftTable', \ secondary='association_table', back_populates='right') @@ -657,8 +655,6 @@ def test_manytomany_nobidi(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List - from sqlalchemy import Column, ForeignKey, Integer, Table from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -671,7 +667,7 @@ class SimpleContainers(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - item: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + item: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ secondary='container_items') @@ -705,8 +701,6 @@ def test_manytomany_selfref(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List - from sqlalchemy import Column, ForeignKey, Integer, Table from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -719,12 +713,12 @@ class SimpleItems(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - parent: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + parent: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ secondary='otherschema.child_items', primaryjoin=lambda: SimpleItems.id \ == t_child_items.c.child_id, \ secondaryjoin=lambda: SimpleItems.id == \ t_child_items.c.parent_id, back_populates='child') - child: Mapped[List['SimpleItems']] = \ + child: Mapped[list['SimpleItems']] = \ relationship('SimpleItems', secondary='otherschema.child_items', \ primaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, \ secondaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, \ @@ -773,8 +767,6 @@ def test_manytomany_composite(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List - from sqlalchemy import Column, ForeignKeyConstraint, Integer, Table from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -788,7 +780,7 @@ class SimpleContainers(Base): id1: Mapped[int] = mapped_column(Integer, primary_key=True) id2: Mapped[int] = mapped_column(Integer, primary_key=True) - simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ secondary='container_items', back_populates='simple_containers') @@ -798,7 +790,7 @@ class SimpleItems(Base): id1: Mapped[int] = mapped_column(Integer, primary_key=True) id2: Mapped[int] = mapped_column(Integer, primary_key=True) - simple_containers: Mapped[List['SimpleContainers']] = \ + simple_containers: Mapped[list['SimpleContainers']] = \ relationship('SimpleContainers', secondary='container_items', \ back_populates='simple_items') @@ -1091,7 +1083,7 @@ def test_foreign_key_schema(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List, Optional +from typing import Optional from sqlalchemy import ForeignKey, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -1106,7 +1098,7 @@ class OtherItems(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ back_populates='other_item') @@ -1444,7 +1436,7 @@ def test_named_foreign_key_constraints(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ -from typing import List, Optional +from typing import Optional from sqlalchemy import ForeignKeyConstraint, Integer from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -1458,7 +1450,7 @@ class SimpleContainers(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) - simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ + simple_items: Mapped[list['SimpleItems']] = relationship('SimpleItems', \ back_populates='container') @@ -1509,3 +1501,34 @@ class Simple(Base): server_default=text("'test'")) """, ) + + +def test_table_with_arrays(generator: CodeGenerator) -> None: + Table( + "with_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("int_items_not_optional", ARRAY(INTEGER()), nullable=False), + Column("str_matrix", ARRAY(VARCHAR(), dimensions=2)), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import ARRAY, INTEGER, Integer, VARCHAR +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class WithItems(Base): + __tablename__ = 'with_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + int_items_not_optional: Mapped[list[int]] = mapped_column(ARRAY(INTEGER())) + str_matrix: Mapped[Optional[list[list[str]]]] = mapped_column(ARRAY(VARCHAR(), dimensions=2)) +""", + ) diff --git a/tests/test_generator_sqlmodel.py b/tests/test_generator_sqlmodel.py index baf92dd6..5d891d9e 100644 --- a/tests/test_generator_sqlmodel.py +++ b/tests/test_generator_sqlmodel.py @@ -116,7 +116,7 @@ def test_onetomany(generator: CodeGenerator) -> None: validate_code( generator.generate(), """\ - from typing import List, Optional + from typing import Optional from sqlalchemy import Column, ForeignKey, Integer from sqlmodel import Field, Relationship, SQLModel @@ -127,7 +127,7 @@ class SimpleContainers(SQLModel, table=True): id: Optional[int] = Field(default=None, sa_column=Column(\ 'id', Integer, primary_key=True)) - simple_goods: List['SimpleGoods'] = Relationship(\ + simple_goods: list['SimpleGoods'] = Relationship(\ back_populates='container')