From ea62acbe3c40ae84f398b8141f65b89215eca2c7 Mon Sep 17 00:00:00 2001 From: zhangxingzhi Date: Tue, 18 Oct 2022 08:56:51 +0800 Subject: [PATCH 1/3] fix: avoid serialize input value from int string to float --- .../azure/ai/ml/_schema/core/fields.py | 17 +++++++++++++---- .../_schema/job/input_output_fields_provider.py | 5 +++-- .../unittests/test_pipeline_job_schema.py | 1 + 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py index 36ce7848209a..8d1861270425 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py @@ -6,6 +6,7 @@ import copy import logging +import numbers import os import re import typing @@ -675,15 +676,23 @@ def _deserialize(self, value, attr, data, **kwargs) -> str: class DumpableIntegerField(fields.Integer): def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, _T]]: if self.strict and not isinstance(value, int): - raise ValidationError("Given value is not an integer") + # this implementation can serialize bool to bool + raise self.make_error("invalid", input=value) return super()._serialize(value, attr, obj, **kwargs) class DumpableFloatField(fields.Float): + def __init__(self, *, strict: bool = False, allow_nan: bool = False, as_string: bool = False, **kwargs): + self.strict = strict + super().__init__(allow_nan=allow_nan, as_string=as_string, **kwargs) + + def _validated(self, value): + if self.strict and not isinstance(value, float): + raise self.make_error("invalid", input=value) + return super()._validated(value) + def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, _T]]: - if not isinstance(value, float): - raise ValidationError("Given value is not a float") - return super()._serialize(value, attr, obj, **kwargs) + return super()._serialize(self._validated(value), attr, obj, **kwargs) class DumpableStringField(fields.String): diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py index 2201191ad6e4..4a2df59b6008 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py @@ -31,9 +31,10 @@ def InputsField(**kwargs): # https://github.com/marshmallow-code/marshmallow/pull/755 # Use DumpableIntegerField to make sure there will be validation error when # loading/dumping a float to int. + # note that this field can serialize bool instance but cannot deserialize bool instance. DumpableIntegerField(strict=True), - # Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float) - DumpableFloatField(), + # Use DumpableFloatField with strict of True to avoid '1'(str) serialized to 1.0(float) + DumpableFloatField(strict=True), # put string schema after Int and Float to make sure they won't dump to string fields.Str(), # fields.Bool comes last since it'll parse anything non-falsy to True diff --git a/sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py b/sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py index 1f70ada8114c..e7b0bfbe2f4f 100644 --- a/sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py +++ b/sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py @@ -1729,6 +1729,7 @@ def test_dump_pipeline_inputs(self): "integer_input": 15, "bool_input": False, "string_input": "hello", + "string_integer_input": "43", } job = load_job(test_path, params_override=[{"inputs": expected_inputs}]) From ab7f5af1dca5acc4d8edb43a3be61805284219bc Mon Sep 17 00:00:00 2001 From: zhangxingzhi Date: Tue, 18 Oct 2022 12:01:56 +0800 Subject: [PATCH 2/3] fix: fix ci --- .../ai/ml/_internal/_schema/input_output.py | 28 ++----------- .../job/input_output_fields_provider.py | 42 ++++++++++--------- .../unittests/test_pipeline_job_schema.py | 1 - 3 files changed, 27 insertions(+), 44 deletions(-) diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py index a71e354e8e27..3eec5bd89027 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py @@ -6,8 +6,8 @@ from azure.ai.ml._schema import StringTransformedEnum, UnionField, PatchedSchemaMeta from azure.ai.ml._schema.component.input_output import InputPortSchema, ParameterSchema -from azure.ai.ml._schema.core.fields import DumpableFloatField, DumpableIntegerField, DumpableEnumField - +from azure.ai.ml._schema.core.fields import DumpableEnumField +from azure.ai.ml._schema.job.input_output_fields_provider import PrimitiveValueField SUPPORTED_INTERNAL_PARAM_TYPES = [ "integer", @@ -74,29 +74,9 @@ class InternalEnumParameterSchema(ParameterSchema): required=True, data_key="type", ) - default = UnionField( - [ - DumpableIntegerField(strict=True), - # Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float) - DumpableFloatField(), - # put string schema after Int and Float to make sure they won't dump to string - fields.Str(), - # fields.Bool comes last since it'll parse anything non-falsy to True - fields.Bool(), - ], - ) + default = PrimitiveValueField() enum = fields.List( - UnionField( - [ - DumpableIntegerField(strict=True), - # Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float) - DumpableFloatField(), - # put string schema after Int and Float to make sure they won't dump to string - fields.Str(), - # fields.Bool comes last since it'll parse anything non-falsy to True - fields.Bool(), - ] - ), + PrimitiveValueField(), required=True, ) diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py index 4a2df59b6008..ee4527552a24 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py @@ -23,25 +23,7 @@ def InputsField(**kwargs): NestedField(ModelInputSchema), NestedField(MLTableInputSchema), NestedField(InputLiteralValueSchema), - UnionField( - [ - # Note: order matters here - to make sure value parsed correctly. - # By default when strict is false, marshmallow downcasts float to int. - # Setting it to true will throw a validation error when loading a float to int. - # https://github.com/marshmallow-code/marshmallow/pull/755 - # Use DumpableIntegerField to make sure there will be validation error when - # loading/dumping a float to int. - # note that this field can serialize bool instance but cannot deserialize bool instance. - DumpableIntegerField(strict=True), - # Use DumpableFloatField with strict of True to avoid '1'(str) serialized to 1.0(float) - DumpableFloatField(strict=True), - # put string schema after Int and Float to make sure they won't dump to string - fields.Str(), - # fields.Bool comes last since it'll parse anything non-falsy to True - fields.Bool(), - ], - is_strict=False, - ), + PrimitiveValueField(is_strict=False), # This ordering of types for the values keyword is intentional. The ordering of types # determines what order schema values are matched and cast in. Changing the current ordering can # result in values being mis-cast such as 1.0 translating into True. @@ -60,3 +42,25 @@ def OutputsField(**kwargs): metadata={"description": "Outputs of a job."}, **kwargs ) + + +def PrimitiveValueField(**kwargs): + return UnionField( + [ + # Note: order matters here - to make sure value parsed correctly. + # By default when strict is false, marshmallow downcasts float to int. + # Setting it to true will throw a validation error when loading a float to int. + # https://github.com/marshmallow-code/marshmallow/pull/755 + # Use DumpableIntegerField to make sure there will be validation error when + # loading/dumping a float to int. + # note that this field can serialize bool instance but cannot deserialize bool instance. + DumpableIntegerField(strict=True), + # Use DumpableFloatField with strict of True to avoid '1'(str) serialized to 1.0(float) + DumpableFloatField(strict=True), + # put string schema after Int and Float to make sure they won't dump to string + fields.Str(), + # fields.Bool comes last since it'll parse anything non-falsy to True + fields.Bool(), + ], + **kwargs, + ) diff --git a/sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py b/sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py index e7b0bfbe2f4f..c23da0c9e575 100644 --- a/sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py +++ b/sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_schema.py @@ -20,7 +20,6 @@ LearningRateScheduler, StochasticOptimizer, ) -from azure.ai.ml._restclient.v2022_10_01_preview.models import JobService as RestJobService from azure.ai.ml._utils.utils import camel_to_snake, dump_yaml_to_file, is_data_binding_expression, load_yaml from azure.ai.ml.constants._common import ARM_ID_PREFIX from azure.ai.ml.constants._component import ComponentJobConstants From a285419137f9d683801755c3d328af0c09085232 Mon Sep 17 00:00:00 2001 From: zhangxingzhi Date: Tue, 18 Oct 2022 13:00:30 +0800 Subject: [PATCH 3/3] fix: pylint --- sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py index 8d1861270425..8d2e000b9158 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py @@ -6,7 +6,6 @@ import copy import logging -import numbers import os import re import typing