Skip to content

Commit 87a6749

Browse files
fix: Dataclass field type not used correctly (#371)
Co-authored-by: Andrew Truong <[email protected]>
1 parent f67f36e commit 87a6749

File tree

3 files changed

+79
-7
lines changed

3 files changed

+79
-7
lines changed

polyfactory/factories/dataclass_factory.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from __future__ import annotations
22

33
from dataclasses import MISSING, fields, is_dataclass
4-
from typing import Any, Generic
5-
6-
from typing_extensions import TypeGuard, get_type_hints
4+
from typing import TYPE_CHECKING, Any, ForwardRef, Generic
75

86
from polyfactory.factories.base import BaseFactory, T
97
from polyfactory.field_meta import FieldMeta, Null
8+
from polyfactory.utils.helpers import evaluate_forwardref
9+
10+
if TYPE_CHECKING:
11+
from typing_extensions import TypeGuard
1012

1113

1214
class DataclassFactory(Generic[T], BaseFactory[T]):
@@ -33,8 +35,6 @@ def get_model_fields(cls) -> list["FieldMeta"]:
3335
"""
3436
fields_meta: list["FieldMeta"] = []
3537

36-
model_type_hints = get_type_hints(cls.__model__, include_extras=True)
37-
3838
for field in fields(cls.__model__): # type: ignore[arg-type]
3939
if field.default_factory and field.default_factory is not MISSING:
4040
default_value = field.default_factory()
@@ -43,9 +43,16 @@ def get_model_fields(cls) -> list["FieldMeta"]:
4343
else:
4444
default_value = Null
4545

46+
if isinstance(field.type, ForwardRef):
47+
annotation = evaluate_forwardref(field.type) # type: ignore[unreachable]
48+
elif isinstance(field.type, str):
49+
annotation = evaluate_forwardref(ForwardRef(field.type)) # type: ignore[unreachable]
50+
else:
51+
annotation = field.type
52+
4653
fields_meta.append(
4754
FieldMeta.from_type(
48-
annotation=model_type_hints[field.name],
55+
annotation=annotation,
4956
name=field.name,
5057
default=default_value,
5158
random=cls.__random__,

polyfactory/utils/helpers.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import sys
4-
from typing import TYPE_CHECKING, Any
4+
from typing import TYPE_CHECKING, Any, ForwardRef
55

66
from typing_extensions import get_args, get_origin
77

@@ -130,3 +130,17 @@ def normalize_annotation(annotation: Any, random: Random) -> Any:
130130
return origin[args] if origin is not type else annotation
131131

132132
return origin
133+
134+
135+
def evaluate_forwardref(ref: ForwardRef) -> Any:
136+
"""Evaluate ForwardRef to get type annotation
137+
138+
:param ref: A ForwardRef object
139+
140+
:returns: A type annotation
141+
142+
"""
143+
if sys.version_info < (3, 9):
144+
return ref._evaluate(globals(), locals())
145+
146+
return ref._evaluate(globals(), locals(), frozenset())

tests/test_dataclass_factory.py

+51
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,27 @@ class MyFactory(DataclassFactory[PydanticDC]):
5555
assert result.constrained_field >= 100
5656

5757

58+
def test_factory_sqlalchemy_dc() -> None:
59+
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
60+
61+
class SqlAlchemyDCBase(MappedAsDataclass, DeclarativeBase):
62+
pass
63+
64+
class SqlAlchemyDC(SqlAlchemyDCBase):
65+
__tablename__ = "foo"
66+
67+
id: Mapped[int] = mapped_column(primary_key=True)
68+
name: Mapped[str]
69+
70+
class SqlAlchemyDCFactory(DataclassFactory[SqlAlchemyDC]):
71+
__model__ = SqlAlchemyDC
72+
73+
result = SqlAlchemyDCFactory.build()
74+
75+
assert isinstance(result.id, int)
76+
assert isinstance(result.name, str)
77+
78+
5879
def test_vanilla_dc_with_embedded_model() -> None:
5980
@vanilla_dataclass
6081
class VanillaDC:
@@ -168,6 +189,36 @@ class MyFactory(DataclassFactory[example]): # type:ignore[valid-type]
168189
assert MyFactory.process_kwargs() == {"foo": ANY}
169190

170191

192+
def test_sqlalchemy_dc_factory_with_future_annotations(create_module: Callable[[str], ModuleType]) -> None:
193+
module = create_module(
194+
"""
195+
from __future__ import annotations
196+
197+
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
198+
199+
class SqlAlchemyDCBase(MappedAsDataclass, DeclarativeBase): # type: ignore
200+
pass
201+
202+
class SqlAlchemyDC(SqlAlchemyDCBase):
203+
__tablename__ = "foo"
204+
205+
id: Mapped[int] = mapped_column(primary_key=True)
206+
name: Mapped[str]
207+
208+
"""
209+
)
210+
211+
SqlAlchemyDC: type = module.SqlAlchemyDC
212+
assert SqlAlchemyDC.__annotations__ == {"id": "Mapped[int]", "name": "Mapped[str]"}
213+
214+
class SqlAlchemyDCFactory(DataclassFactory[SqlAlchemyDC]): # type:ignore[valid-type]
215+
__model__ = SqlAlchemyDC
216+
217+
result = SqlAlchemyDCFactory.build()
218+
assert isinstance(result.id, int)
219+
assert isinstance(result.name, str)
220+
221+
171222
def test_variable_length_tuple_generation__many_type_args() -> None:
172223
@vanilla_dataclass
173224
class VanillaDC:

0 commit comments

Comments
 (0)