@@ -55,6 +55,27 @@ class MyFactory(DataclassFactory[PydanticDC]):
55
55
assert result .constrained_field >= 100
56
56
57
57
58
+ def test_factory_sqlalchemy_dc () -> None :
59
+ from sqlalchemy .orm import DeclarativeBase , Mapped , MappedAsDataclass , mapped_column
60
+
61
+ class SqlAlchemyDCBase (MappedAsDataclass , DeclarativeBase ):
62
+ pass
63
+
64
+ class SqlAlchemyDC (SqlAlchemyDCBase ):
65
+ __tablename__ = "foo"
66
+
67
+ id : Mapped [int ] = mapped_column (primary_key = True )
68
+ name : Mapped [str ]
69
+
70
+ class SqlAlchemyDCFactory (DataclassFactory [SqlAlchemyDC ]):
71
+ __model__ = SqlAlchemyDC
72
+
73
+ result = SqlAlchemyDCFactory .build ()
74
+
75
+ assert isinstance (result .id , int )
76
+ assert isinstance (result .name , str )
77
+
78
+
58
79
def test_vanilla_dc_with_embedded_model () -> None :
59
80
@vanilla_dataclass
60
81
class VanillaDC :
@@ -168,6 +189,36 @@ class MyFactory(DataclassFactory[example]): # type:ignore[valid-type]
168
189
assert MyFactory .process_kwargs () == {"foo" : ANY }
169
190
170
191
192
+ def test_sqlalchemy_dc_factory_with_future_annotations (create_module : Callable [[str ], ModuleType ]) -> None :
193
+ module = create_module (
194
+ """
195
+ from __future__ import annotations
196
+
197
+ from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
198
+
199
+ class SqlAlchemyDCBase(MappedAsDataclass, DeclarativeBase): # type: ignore
200
+ pass
201
+
202
+ class SqlAlchemyDC(SqlAlchemyDCBase):
203
+ __tablename__ = "foo"
204
+
205
+ id: Mapped[int] = mapped_column(primary_key=True)
206
+ name: Mapped[str]
207
+
208
+ """
209
+ )
210
+
211
+ SqlAlchemyDC : type = module .SqlAlchemyDC
212
+ assert SqlAlchemyDC .__annotations__ == {"id" : "Mapped[int]" , "name" : "Mapped[str]" }
213
+
214
+ class SqlAlchemyDCFactory (DataclassFactory [SqlAlchemyDC ]): # type:ignore[valid-type]
215
+ __model__ = SqlAlchemyDC
216
+
217
+ result = SqlAlchemyDCFactory .build ()
218
+ assert isinstance (result .id , int )
219
+ assert isinstance (result .name , str )
220
+
221
+
171
222
def test_variable_length_tuple_generation__many_type_args () -> None :
172
223
@vanilla_dataclass
173
224
class VanillaDC :
0 commit comments