Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] Fix Circular imports on inherited discriminators. #17882

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,7 @@ private ModelsMap postProcessModelsMap(ModelsMap objs) {
// TODO: migrate almost (all?) everything to the `PythonImports` class.
TreeSet<String> modelImports = new TreeSet<>();
TreeSet<String> postponedModelImports = new TreeSet<>();
TreeSet<String> discriminatorModelImports = new TreeSet<>();

for (ModelMap m : objs.getModels()) {
TreeSet<String> exampleImports = new TreeSet<>();
Expand Down Expand Up @@ -929,7 +930,7 @@ private ModelsMap postProcessModelsMap(ModelsMap objs) {
moduleImports.add("typing", "Union");
Set<CodegenDiscriminator.MappedModel> discriminator = model.getDiscriminator().getMappedModels();
for (CodegenDiscriminator.MappedModel mappedModel : discriminator) {
postponedModelImports.add(mappedModel.getModelName());
discriminatorModelImports.add(mappedModel.getMappingName());
}
}
}
Expand Down Expand Up @@ -1032,6 +1033,19 @@ private ModelsMap postProcessModelsMap(ModelsMap objs) {

model.getVendorExtensions().putIfAbsent("x-py-postponed-model-imports", modelsToImport);
}

if (!discriminatorModelImports.isEmpty()) {
Set<String> modelsToImport = new TreeSet<>();
for (String modelImport : discriminatorModelImports) {
if (modelImport.equals(model.classname)) {
// skip self import
continue;
}
modelsToImport.add("globals()[\"" + modelImport + "\"] = importlib.import_module(\"" + packageName + ".models." + underscore(modelImport) + "\")." + modelImport);
}

model.discriminator.getVendorExtensions().putIfAbsent("x-py-discriminator-model-imports", modelsToImport);
}
}

return objs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ from __future__ import annotations
import pprint
import re # noqa: F401
import json
import importlib

{{#vendorExtensions.x-py-other-imports}}
{{{.}}}
Expand Down Expand Up @@ -88,16 +89,27 @@ class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}}
__discriminator_property_name: ClassVar[str] = '{{discriminator.propertyBaseName}}'

# discriminator mappings
__discriminator_value_class_map: ClassVar[Dict[str, str]] = {
{{#mappedModels}}'{{{mappingName}}}': '{{{modelName}}}'{{^-last}},{{/-last}}{{/mappedModels}}
}
__discriminator_value_class_map: ClassVar[Union[Dict[str, str], None]] = None

@classmethod
def _get_discriminator_value_class_map(cls) -> ClassVar[Dict[str, str]]:
if cls.__discriminator_value_class_map == None:
# Prevent circular imports caused by mutually referencing classes
{{#vendorExtensions.x-py-discriminator-model-imports}}
{{{.}}}
{{/vendorExtensions.x-py-discriminator-model-imports}}

cls.__discriminator_value_class_map = {
{{#mappedModels}}'{{{mappingName}}}': '{{{modelName}}}'{{^-last}},{{/-last}}{{/mappedModels}}
}
return cls.__discriminator_value_class_map

@classmethod
def get_discriminator_value(cls, obj: Dict[str, Any]) -> Optional[str]:
"""Returns the discriminator value (object type) of the data"""
discriminator_value = obj[cls.__discriminator_property_name]
if discriminator_value:
return cls.__discriminator_value_class_map.get(discriminator_value)
return cls._get_discriminator_value_class_map().get(discriminator_value)
else:
return None

Expand Down Expand Up @@ -247,7 +259,7 @@ class {{classname}}({{#parent}}{{{.}}}{{/parent}}{{^parent}}BaseModel{{/parent}}
else:
raise ValueError("{{{classname}}} failed to lookup discriminator value from " +
json.dumps(obj) + ". Discriminator property name: " + cls.__discriminator_property_name +
", mapping: " + json.dumps(cls.__discriminator_value_class_map))
", mapping: " + json.dumps(cls._get_discriminator_value_class_map()))
{{/discriminator}}
{{/hasChildren}}
{{^hasChildren}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictInt, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from datetime import datetime
from pydantic import Field, StrictStr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictInt, StrictStr, field_validator
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field, StrictFloat, StrictInt
from typing import Any, ClassVar, Dict, List, Optional, Union
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field, StrictInt, StrictStr, field_validator
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field, StrictInt, StrictStr, field_validator
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictInt, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictInt, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictInt, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from datetime import datetime
from pydantic import Field, StrictStr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictInt, StrictStr, field_validator
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field, StrictFloat, StrictInt
from typing import Any, ClassVar, Dict, List, Optional, Union
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field, StrictInt, StrictStr, field_validator
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field, StrictInt, StrictStr, field_validator
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictInt, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictInt, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field, StrictStr
from typing import Any, ClassVar, Dict, List, Optional, Union
Expand All @@ -41,16 +42,26 @@ class Animal(BaseModel):
__discriminator_property_name: ClassVar[str] = 'className'

# discriminator mappings
__discriminator_value_class_map: ClassVar[Dict[str, str]] = {
'Cat': 'Cat','Dog': 'Dog'
}
__discriminator_value_class_map: ClassVar[Union[Dict[str, str], None]] = None

@classmethod
def _get_discriminator_value_class_map(cls) -> ClassVar[Dict[str, str]]:
if cls.__discriminator_value_class_map == None:
# Prevent circular imports caused by mutually referencing classes
globals()["Cat"] = importlib.import_module("petstore_api.models.cat").Cat
globals()["Dog"] = importlib.import_module("petstore_api.models.dog").Dog

cls.__discriminator_value_class_map = {
'Cat': 'Cat','Dog': 'Dog'
}
return cls.__discriminator_value_class_map

@classmethod
def get_discriminator_value(cls, obj: Dict[str, Any]) -> Optional[str]:
"""Returns the discriminator value (object type) of the data"""
discriminator_value = obj[cls.__discriminator_property_name]
if discriminator_value:
return cls.__discriminator_value_class_map.get(discriminator_value)
return cls._get_discriminator_value_class_map().get(discriminator_value)
else:
return None

Expand Down Expand Up @@ -99,10 +110,6 @@ def from_dict(cls, obj: Dict[str, Any]) -> Optional[Union[Self, Self]]:
else:
raise ValueError("Animal failed to lookup discriminator value from " +
json.dumps(obj) + ". Discriminator property name: " + cls.__discriminator_property_name +
", mapping: " + json.dumps(cls.__discriminator_value_class_map))
", mapping: " + json.dumps(cls._get_discriminator_value_class_map()))

from petstore_api.models.cat import Cat
from petstore_api.models.dog import Dog
# TODO: Rewrite to not use raise_errors
Animal.model_rebuild(raise_errors=False)

Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field, StrictInt, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field, StrictStr
from typing import Any, ClassVar, Dict, List
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, StrictStr, field_validator
from typing import Any, ClassVar, Dict, List
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re # noqa: F401
import json
import importlib

from pydantic import BaseModel, Field, StrictStr
from typing import Any, ClassVar, Dict, List, Optional
Expand Down
Loading