Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import copy
import logging
import numbers
import os
import re
import typing
Expand Down Expand Up @@ -675,15 +676,23 @@ def _deserialize(self, value, attr, data, **kwargs) -> str:
class DumpableIntegerField(fields.Integer):
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, _T]]:
if self.strict and not isinstance(value, int):
raise ValidationError("Given value is not an integer")
# this implementation can serialize bool to bool
raise self.make_error("invalid", input=value)
return super()._serialize(value, attr, obj, **kwargs)


class DumpableFloatField(fields.Float):
def __init__(self, *, strict: bool = False, allow_nan: bool = False, as_string: bool = False, **kwargs):
self.strict = strict
super().__init__(allow_nan=allow_nan, as_string=as_string, **kwargs)

def _validated(self, value):
if self.strict and not isinstance(value, float):
raise self.make_error("invalid", input=value)
return super()._validated(value)

def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, _T]]:
if not isinstance(value, float):
raise ValidationError("Given value is not a float")
return super()._serialize(value, attr, obj, **kwargs)
return super()._serialize(self._validated(value), attr, obj, **kwargs)


class DumpableStringField(fields.String):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ def InputsField(**kwargs):
# https://github.com/marshmallow-code/marshmallow/pull/755
# Use DumpableIntegerField to make sure there will be validation error when
# loading/dumping a float to int.
# note that this field can serialize bool instance but cannot deserialize bool instance.
DumpableIntegerField(strict=True),
# Use DumpableFloatField to avoid '1'(str) serialized to 1.0(float)
DumpableFloatField(),
# Use DumpableFloatField with strict of True to avoid '1'(str) serialized to 1.0(float)
DumpableFloatField(strict=True),
# put string schema after Int and Float to make sure they won't dump to string
fields.Str(),
# fields.Bool comes last since it'll parse anything non-falsy to True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,7 @@ def test_dump_pipeline_inputs(self):
"integer_input": 15,
"bool_input": False,
"string_input": "hello",
"string_integer_input": "43",
Comment thread
wangchao1230 marked this conversation as resolved.
}

job = load_job(test_path, params_override=[{"inputs": expected_inputs}])
Expand Down