diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py index a71e354e8e27..3eec5bd89027 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py @@ -6,8 +6,8 @@ from azure.ai.ml._schema import StringTransformedEnum, UnionField, PatchedSchemaMeta from azure.ai.ml._schema.component.input_output import InputPortSchema, ParameterSchema -from azure.ai.ml._schema.core.fields import DumpableFloatField, DumpableIntegerField, DumpableEnumField - +from azure.ai.ml._schema.core.fields import DumpableEnumField +from azure.ai.ml._schema.job.input_output_fields_provider import PrimitiveValueField SUPPORTED_INTERNAL_PARAM_TYPES = [ "integer", @@ -74,29 +74,9 @@ class InternalEnumParameterSchema(ParameterSchema): required=True, data_key="type", ) - default = UnionField( - [ - DumpableIntegerField(strict=True), - # Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float) - DumpableFloatField(), - # put string schema after Int and Float to make sure they won't dump to string - fields.Str(), - # fields.Bool comes last since it'll parse anything non-falsy to True - fields.Bool(), - ], - ) + default = PrimitiveValueField() enum = fields.List( - UnionField( - [ - DumpableIntegerField(strict=True), - # Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float) - DumpableFloatField(), - # put string schema after Int and Float to make sure they won't dump to string - fields.Str(), - # fields.Bool comes last since it'll parse anything non-falsy to True - fields.Bool(), - ] - ), + PrimitiveValueField(), required=True, ) diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py index 36ce7848209a..8d2e000b9158 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py @@ -675,15 +675,23 @@ def _deserialize(self, value, attr, data, **kwargs) -> str: class DumpableIntegerField(fields.Integer): def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, _T]]: if self.strict and not isinstance(value, int): - raise ValidationError("Given value is not an integer") + # this implementation can serialize bool to bool + raise self.make_error("invalid", input=value) return super()._serialize(value, attr, obj, **kwargs) class DumpableFloatField(fields.Float): + def __init__(self, *, strict: bool = False, allow_nan: bool = False, as_string: bool = False, **kwargs): + self.strict = strict + super().__init__(allow_nan=allow_nan, as_string=as_string, **kwargs) + + def _validated(self, value): + if self.strict and not isinstance(value, float): + raise self.make_error("invalid", input=value) + return super()._validated(value) + def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, _T]]: - if not isinstance(value, float): - raise ValidationError("Given value is not a float") - return super()._serialize(value, attr, obj, **kwargs) + return super()._serialize(self._validated(value), attr, obj, **kwargs) class DumpableStringField(fields.String): diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py index 2201191ad6e4..ee4527552a24 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py @@ -23,24 +23,7 @@ def InputsField(**kwargs): NestedField(ModelInputSchema), NestedField(MLTableInputSchema), NestedField(InputLiteralValueSchema), - UnionField( - [ - # Note: order matters here - to make sure value parsed correctly. - # By default when strict is false, marshmallow downcasts float to int. - # Setting it to true will throw a validation error when loading a float to int. - # https://github.com/marshmallow-code/marshmallow/pull/755 - # Use DumpableIntegerField to make sure there will be validation error when - # loading/dumping a float to int. - DumpableIntegerField(strict=True), - # Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float) - DumpableFloatField(), - # put string schema after Int and Float to make sure they won't dump to string - fields.Str(), - # fields.Bool comes last since it'll parse anything non-falsy to True - fields.Bool(), - ], - is_strict=False, - ), + PrimitiveValueField(is_strict=False), # This ordering of types for the values keyword is intentional. The ordering of types # determines what order schema values are matched and cast in. Changing the current ordering can # result in values being mis-cast such as 1.0 translating into True. @@ -59,3 +42,25 @@ def OutputsField(**kwargs): metadata={"description": "Outputs of a job."}, **kwargs ) + + +def PrimitiveValueField(**kwargs): + return UnionField( + [ + # Note: order matters here - to make sure value parsed correctly. + # By default when strict is false, marshmallow downcasts float to int. + # Setting it to true will throw a validation error when loading a float to int. + # https://github.com/marshmallow-code/marshmallow/pull/755 + # Use DumpableIntegerField to make sure there will be validation error when + # loading/dumping a float to int. + # note that this field can serialize bool instance but cannot deserialize bool instance. + DumpableIntegerField(strict=True), + # Use DumpableFloatField with strict of True to avoid '1'(str) serialized to 1.0(float) + DumpableFloatField(strict=True), + # put string schema after Int and Float to make sure they won't dump to string + fields.Str(), + # fields.Bool comes last since it'll parse anything non-falsy to True + fields.Bool(), + ], + **kwargs, + ) diff --git a/sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py b/sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py index 1f70ada8114c..c23da0c9e575 100644 --- a/sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py +++ b/sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py @@ -20,7 +20,6 @@ LearningRateScheduler, StochasticOptimizer, ) -from azure.ai.ml._restclient.v2022_10_01_preview.models import JobService as RestJobService from azure.ai.ml._utils.utils import camel_to_snake, dump_yaml_to_file, is_data_binding_expression, load_yaml from azure.ai.ml.constants._common import ARM_ID_PREFIX from azure.ai.ml.constants._component import ComponentJobConstants @@ -1729,6 +1728,7 @@ def test_dump_pipeline_inputs(self): "integer_input": 15, "bool_input": False, "string_input": "hello", + "string_integer_input": "43", } job = load_job(test_path, params_override=[{"inputs": expected_inputs}])