diff --git a/unsloth_zoo/temporary_patches/utils.py b/unsloth_zoo/temporary_patches/utils.py index b35ecac7b..97a6ed58e 100644 --- a/unsloth_zoo/temporary_patches/utils.py +++ b/unsloth_zoo/temporary_patches/utils.py @@ -292,6 +292,9 @@ def _canonicalize_annotation(annotation: Any) -> Any: if origin is not None: args = t.get_args(annotation) args = tuple(canonicalize_annotation(arg) for arg in args) + # Map origin to canonical form (e.g., types.UnionType -> typing.Union) + # so that `int | str` and `Union[int, str]` are equivalent + origin = TYPE_MAPPINGS.get(origin, origin) return (origin, args) return TYPE_MAPPINGS.get(annotation, annotation) pass