Skip to content

Commit

Permalink
fix(OpenAPI): Correctly handle typing.NewType (#3580)
Browse files Browse the repository at this point in the history
* Unwrap NewType for OpenAPI schema
* Support nested NewType

(cherry picked from commit 2e4b820)
  • Loading branch information
provinzkraut committed Jun 21, 2024
1 parent 4b06e76 commit 3cf5e32
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 18 deletions.
14 changes: 13 additions & 1 deletion litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from litestar.utils.typing import (
get_origin_or_inner_type,
make_non_optional_union,
unwrap_new_type,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -325,7 +326,9 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re

result: Schema | Reference

if plugin_for_annotation := self.get_plugin_for(field_definition):
if field_definition.is_new_type:
result = self.for_new_type(field_definition)
elif plugin_for_annotation := self.get_plugin_for(field_definition):
result = self.for_plugin(field_definition, plugin_for_annotation)
elif _should_create_enum_schema(field_definition):
annotation = _type_or_first_not_none_inner_type(field_definition)
Expand Down Expand Up @@ -354,6 +357,15 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re

return self.process_schema_result(field_definition, result) if isinstance(result, Schema) else result

def for_new_type(self, field_definition: FieldDefinition) -> Schema | Reference:
return self.for_field_definition(
FieldDefinition.from_kwarg(
annotation=unwrap_new_type(field_definition.raw),
name=field_definition.name,
default=field_definition.default,
)
)

@staticmethod
def for_upload_file(field_definition: FieldDefinition) -> Schema:
"""Create schema for UploadFile.
Expand Down
27 changes: 12 additions & 15 deletions litestar/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,10 @@
from copy import deepcopy
from dataclasses import dataclass, is_dataclass, replace
from inspect import Parameter, Signature
from typing import (
Any,
AnyStr,
Callable,
Collection,
ForwardRef,
Literal,
Mapping,
Protocol,
Sequence,
TypeVar,
cast,
)
from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, Protocol, Sequence, TypeVar, cast

from msgspec import UnsetType
from typing_extensions import NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict
from typing_extensions import NewType, NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict

from litestar.exceptions import ImproperlyConfiguredException, LitestarWarning
from litestar.openapi.spec import Example
Expand Down Expand Up @@ -314,7 +302,12 @@ def is_generic(self) -> bool:
def is_simple_type(self) -> bool:
"""Check if the field type is a singleton value (e.g. int, str etc.)."""
return not (
self.is_generic or self.is_optional or self.is_union or self.is_mapping or self.is_non_string_iterable
self.is_generic
or self.is_optional
or self.is_union
or self.is_mapping
or self.is_non_string_iterable
or self.is_new_type
)

@property
Expand Down Expand Up @@ -366,6 +359,10 @@ def is_tuple(self) -> bool:
"""Whether the annotation is a ``tuple`` or not."""
return self.is_subclass_of(tuple)

@property
def is_new_type(self) -> bool:
return isinstance(self.annotation, NewType)

@property
def is_type_var(self) -> bool:
"""Whether the annotation is a TypeVar or not."""
Expand Down
10 changes: 9 additions & 1 deletion litestar/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
cast,
)

from typing_extensions import Annotated, NotRequired, Required, get_args, get_origin, get_type_hints
from typing_extensions import Annotated, NewType, NotRequired, Required, get_args, get_origin, get_type_hints

from litestar.types.builtin_types import NoneType, UnionTypes

Expand Down Expand Up @@ -174,6 +174,14 @@ def unwrap_annotation(annotation: Any) -> tuple[Any, tuple[Any, ...], set[Any]]:
return annotation, tuple(metadata), wrappers


def unwrap_new_type(new_type: Any) -> Any:
"""Unwrap a (nested) ``typing.NewType``"""
inner = new_type
while isinstance(inner, NewType):
inner = inner.__supertype__
return inner


def get_origin_or_inner_type(annotation: Any) -> Any:
"""Get origin or unwrap it. Returns None for non-generic types.
Expand Down
40 changes: 39 additions & 1 deletion tests/unit/test_openapi/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

import pytest
from typing_extensions import Annotated
from typing_extensions import Annotated, NewType

from litestar import Controller, Litestar, Router, get
from litestar._openapi.datastructures import OpenAPIContext
Expand Down Expand Up @@ -380,3 +380,41 @@ async def uuid_path(id: Annotated[UUID, Parameter(description="UUID ID")]) -> UU
response = client.get("/schema/openapi.json")
assert response.json()["paths"]["/str/{id}"]["get"]["parameters"][0]["description"] == "String ID"
assert response.json()["paths"]["/uuid/{id}"]["get"]["parameters"][0]["description"] == "UUID ID"


def test_unwrap_new_type() -> None:
FancyString = NewType("FancyString", str)

@get("/{path_param:str}")
async def handler(
param: FancyString,
optional_param: Optional[FancyString],
path_param: FancyString,
) -> FancyString:
return FancyString("")

app = Litestar([handler])
assert app.openapi_schema.paths["/{path_param}"].get.parameters[0].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
assert app.openapi_schema.paths["/{path_param}"].get.parameters[1].schema.one_of == [ # type: ignore[index, union-attr]
Schema(type=OpenAPIType.NULL),
Schema(type=OpenAPIType.STRING),
]
assert app.openapi_schema.paths["/{path_param}"].get.parameters[2].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
assert (
app.openapi_schema.paths["/{path_param}"].get.responses["200"].content["application/json"].schema.type # type: ignore[index, union-attr]
== OpenAPIType.STRING
)


def test_unwrap_nested_new_type() -> None:
FancyString = NewType("FancyString", str)
FancierString = NewType("FancierString", FancyString) # pyright: ignore

@get("/")
async def handler(
param: FancierString,
) -> None:
return None

app = Litestar([handler])
assert app.openapi_schema.paths["/"].get.parameters[0].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]

0 comments on commit 3cf5e32

Please sign in to comment.