Skip to content

Commit

Permalink
fix(sdk): Get short name of complex input/output types to ensure we can map to appropriate de|serializer (#6504)
Browse files Browse the repository at this point in the history

Also:
- Simplify _data_passing methods, add in type hints and docstrings.
- Remove get_deserializer_code_for_type.
  • Loading branch information
alexlatchford authored Sep 10, 2021
1 parent 343350a commit a0b18eb
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 31 deletions.
1 change: 1 addition & 0 deletions sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
* Fix: passing in "" to a str parameter caused the parameter to receive None instead. [\#6533](https://github.com/kubeflow/pipelines/pull/6533)
* Depends on `kfp-pipeline-spec>=0.1.10,<0.2.0` [\#6515](https://github.com/kubeflow/pipelines/pull/6515)
* Depends on `kubernetes>=8.0.0,<19` [\#6532](https://github.com/kubeflow/pipelines/pull/6532)
* Get short name of complex input/output types to ensure we can map to appropriate de|serializer. [\#6504](https://github.com/kubeflow/pipelines/pull/6504)

## Documentation Updates

Expand Down
7 changes: 3 additions & 4 deletions sdk/python/kfp/components/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ._naming import _sanitize_file_name, _sanitize_python_function_name, generate_unique_name_conversion_table
from ._yaml_utils import load_yaml
from .structures import *
from ._data_passing import serialize_value, get_canonical_type_for_type_struct
from ._data_passing import serialize_value, get_canonical_type_for_type_name

_default_component_name = 'Component'

Expand Down Expand Up @@ -393,9 +393,8 @@ def component_default_to_func_default(component_default: str,
input_parameters = [
_dynamic.KwParameter(
input_name_to_pythonic[port.name],
annotation=(get_canonical_type_for_type_struct(str(port.type)) or
str(port.type)
if port.type else inspect.Parameter.empty),
annotation=(get_canonical_type_for_type_name(str(port.type)) or str(
port.type) if port.type else inspect.Parameter.empty),
default=component_default_to_func_default(port.default,
port.optional),
) for port in reordered_input_list
Expand Down
63 changes: 44 additions & 19 deletions sdk/python/kfp/components/_data_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,20 @@
# limitations under the License.

__all__ = [
'get_canonical_type_struct_for_type',
'get_canonical_type_for_type_struct',
'get_deserializer_code_for_type',
'get_deserializer_code_for_type_struct',
'get_serializer_func_for_type_struct',
'get_canonical_type_name_for_type',
'get_canonical_type_for_type_name',
'get_deserializer_code_for_type_name',
'get_serializer_func_for_type_name',
]

import inspect
from typing import Any, Callable, NamedTuple, Sequence
from typing import Any, Callable, NamedTuple, Optional, Sequence, Type
import warnings

from kfp.components import type_annotation_utils

Converter = NamedTuple('Converter', [
('types', Sequence[str]),
('types', Sequence[Type]),
('type_names', Sequence[str]),
('serializer', Callable[[Any], str]),
('deserializer_code', str),
Expand Down Expand Up @@ -155,38 +154,64 @@ def _deserialize_base64_pickle(s):
}


def get_canonical_type_struct_for_type(typ) -> str:
def get_canonical_type_name_for_type(typ: Type) -> str:
"""Find the canonical type name for a given type.
Args:
typ: The type to search for.
Returns:
The canonical name of the type found.
"""
try:
return type_to_type_name.get(typ, None)
except:
return None


def get_canonical_type_for_type_struct(type_struct) -> str:
def get_canonical_type_for_type_name(type_name: str) -> Optional[Type]:
"""Find the canonical type for a given type name.
Args:
type_name: The type name to search for.
Returns:
The canonical type found.
"""
try:
return type_name_to_type.get(type_struct, None)
return type_name_to_type.get(type_name, None)
except:
return None


def get_deserializer_code_for_type(typ) -> str:
def get_deserializer_code_for_type_name(type_name: str) -> Optional[str]:
"""Find the deserializer code for the given type name.
Args:
type_name: The type name to search for.
Returns:
The deserializer code needed to deserialize the type.
"""
try:
return type_name_to_deserializer.get(
get_canonical_type_struct_for_type[typ], None)
type_annotation_utils.get_short_type_name(type_name), None)
except:
return None


def get_deserializer_code_for_type_struct(type_struct) -> str:
try:
return type_name_to_deserializer.get(type_struct, None)
except:
return None
def get_serializer_func_for_type_name(type_name: str) -> Optional[Callable]:
"""Find the serializer code for the given type name.
Args:
type_name: The type name to search for.
def get_serializer_func_for_type_struct(type_struct) -> str:
Returns:
The serializer func needed to serialize the type.
"""
try:
return type_name_to_serializer.get(type_struct, None)
return type_name_to_serializer.get(
type_annotation_utils.get_short_type_name(type_name), None)
except:
return None

Expand Down
10 changes: 5 additions & 5 deletions sdk/python/kfp/components/_python_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from ._yaml_utils import dump_yaml
from ._components import _create_task_factory_from_component_spec
from ._data_passing import serialize_value, get_deserializer_code_for_type_struct, get_serializer_func_for_type_struct, get_canonical_type_struct_for_type
from ._data_passing import serialize_value, get_deserializer_code_for_type_name, get_serializer_func_for_type_name, get_canonical_type_name_for_type
from ._naming import _make_name_unique_by_adding_index
from .structures import *
from . import _structures as structures
Expand Down Expand Up @@ -347,7 +347,7 @@ def annotation_to_type_struct(annotation):
if isinstance(annotation, dict):
return annotation
if isinstance(annotation, type):
type_struct = get_canonical_type_struct_for_type(annotation)
type_struct = get_canonical_type_name_for_type(annotation)
if type_struct:
return type_struct
type_name = str(annotation.__name__)
Expand All @@ -359,7 +359,7 @@ def annotation_to_type_struct(annotation):
type_name = str(annotation)

# It's also possible to get the converter by type name
type_struct = get_canonical_type_struct_for_type(type_name)
type_struct = get_canonical_type_name_for_type(type_name)
if type_struct:
return type_struct
return type_name
Expand Down Expand Up @@ -566,7 +566,7 @@ def _func_to_component_spec(func,
definitions = set()

def get_deserializer_and_register_definitions(type_name):
deserializer_code = get_deserializer_code_for_type_struct(type_name)
deserializer_code = get_deserializer_code_for_type_name(type_name)
if deserializer_code:
(deserializer_code_str, definition_str) = deserializer_code
if definition_str:
Expand Down Expand Up @@ -607,7 +607,7 @@ def get_argparse_type_for_input_file(passing_style):
str(passing_style)))

def get_serializer_and_register_definitions(type_name) -> str:
serializer_func = get_serializer_func_for_type_struct(type_name)
serializer_func = get_serializer_func_for_type_name(type_name)
if serializer_func:
# If serializer is not part of the standard python library, then include its code in the generated program
if hasattr(serializer_func,
Expand Down
5 changes: 2 additions & 3 deletions sdk/python/kfp/v2/components/component_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ def _annotation_to_type_struct(annotation):
if isinstance(annotation, dict):
return annotation
if isinstance(annotation, type):
type_struct = _data_passing.get_canonical_type_struct_for_type(
annotation)
type_struct = _data_passing.get_canonical_type_name_for_type(annotation)
if type_struct:
return type_struct
type_name = str(annotation.__name__)
Expand All @@ -99,7 +98,7 @@ def _annotation_to_type_struct(annotation):
type_name = str(annotation)

# It's also possible to get the converter by type name
type_struct = _data_passing.get_canonical_type_struct_for_type(type_name)
type_struct = _data_passing.get_canonical_type_name_for_type(type_name)
if type_struct:
return type_struct
return type_name
Expand Down

0 comments on commit a0b18eb

Please sign in to comment.