Skip to content

Commit

Permalink
pyflyte run now supports json/yaml files (#1606)
Browse files Browse the repository at this point in the history
* pyflyte run now supports json files

Signed-off-by: Ketan Umare <[email protected]>

* added yaml support

Signed-off-by: Ketan Umare <[email protected]>

* fixed parsing

Signed-off-by: Ketan Umare <[email protected]>

* fixed windows test

Signed-off-by: Ketan Umare <[email protected]>

---------

Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 authored Apr 26, 2023
1 parent f5c5abe commit 36fc151
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 61 deletions.
158 changes: 110 additions & 48 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import cast

import rich_click as click
import yaml
from dataclasses_json import DataClassJsonMixin
from pytimeparse import parse
from typing_extensions import get_args
Expand Down Expand Up @@ -55,10 +56,6 @@ def remove_prefix(text, prefix):
return text


class JsonParamType(click.ParamType):
name = "json object"


@dataclass
class Directory(object):
dir_path: str
Expand Down Expand Up @@ -134,6 +131,33 @@ def convert(
return datetime.timedelta(seconds=parse(value))


class JsonParamType(click.ParamType):
name = "json object OR json/yaml file path"

def convert(
self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> typing.Any:
if value is None:
raise click.BadParameter("None value cannot be converted to a Json type.")
if type(value) == dict or type(value) == list:
return value
try:
return json.loads(value)
except Exception: # noqa
try:
# We failed to load the json, so we'll try to load it as a file
if os.path.exists(value):
# if the value is a yaml file, we'll try to load it as yaml
if value.endswith(".yaml") or value.endswith(".yml"):
with open(value, "r") as f:
return yaml.safe_load(f)
with open(value, "r") as f:
return json.load(f)
raise
except json.JSONDecodeError as e:
raise click.BadParameter(f"parameter {param} should be a valid json object, {value}, error: {e}")


@dataclass
class DefaultConverter(object):
click_type: click.ParamType
Expand Down Expand Up @@ -299,6 +323,68 @@ def convert_to_union(
logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e)
raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}")

def convert_to_list(
self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: list
) -> Literal:
"""
Convert a python list into a Flyte Literal
"""
if not value:
raise click.BadParameter("Expected non-empty list")
if not isinstance(value, list):
raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(value)}")
converter = FlyteLiteralConverter(
ctx,
self._flyte_ctx,
self._literal_type.collection_type,
type(value[0]),
self._create_upload_fn,
)
lt = Literal(collection=LiteralCollection([]))
for v in value:
click_val = converter._click_type.convert(v, param, ctx)
lt.collection.literals.append(converter.convert_to_literal(ctx, param, click_val))
return lt

def convert_to_map(
self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: dict
) -> Literal:
"""
Convert a python dict into a Flyte Literal.
It is assumed that the click parameter type is a JsonParamType. The map is also assumed to be univariate.
"""
if not value:
raise click.BadParameter("Expected non-empty dict")
if not isinstance(value, dict):
raise click.BadParameter(f"Expected json dict '{{...}}', parsed value is {type(value)}")
converter = FlyteLiteralConverter(
ctx,
self._flyte_ctx,
self._literal_type.map_value_type,
type(value[list(value.keys())[0]]),
self._create_upload_fn,
)
lt = Literal(map=LiteralMap({}))
for k, v in value.items():
click_val = converter._click_type.convert(v, param, ctx)
lt.map.literals[k] = converter.convert_to_literal(ctx, param, click_val)
return lt

def convert_to_struct(
self,
ctx: typing.Optional[click.Context],
param: typing.Optional[click.Parameter],
value: typing.Union[dict, typing.Any],
) -> Literal:
"""
Convert the loaded json object to a Flyte Literal struct type.
"""
if type(value) != self._python_type:
o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value))
else:
o = value
return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type)

def convert_to_literal(
self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any
) -> Literal:
Expand All @@ -309,53 +395,17 @@ def convert_to_literal(
return self.convert_to_blob(ctx, param, value)

if self._literal_type.collection_type:
python_value = json.loads(value) if isinstance(value, str) else value
if not isinstance(python_value, list):
raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(python_value)}")
converter = FlyteLiteralConverter(
ctx,
self._flyte_ctx,
self._literal_type.collection_type,
type(python_value[0]),
self._create_upload_fn,
)
lt = Literal(collection=LiteralCollection([]))
for v in python_value:
click_val = converter._click_type.convert(v, param, ctx)
lt.collection.literals.append(converter.convert_to_literal(ctx, param, click_val))
return lt
return self.convert_to_list(ctx, param, value)

if self._literal_type.map_value_type:
python_value = json.loads(value) if isinstance(value, str) else value
if not isinstance(python_value, dict):
raise click.BadParameter("Expected json map '{}', parsed value is {%s}" % type(python_value))
converter = FlyteLiteralConverter(
ctx,
self._flyte_ctx,
self._literal_type.map_value_type,
type(python_value[next(iter(python_value))]),
self._create_upload_fn,
)
lt = Literal(map=LiteralMap({}))
for k, v in python_value.items():
click_val = converter._click_type.convert(v, param, ctx)
lt.map.literals[k] = converter.convert_to_literal(ctx, param, click_val)
return lt
return self.convert_to_map(ctx, param, value)

if self._literal_type.union_type:
return self.convert_to_union(ctx, param, value)

if self._literal_type.simple or self._literal_type.enum_type:
if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT:
if self._python_type == dict:
if type(value) != str:
# The type of default value is dict, so we have to convert it to json string
value = json.dumps(value)
o = json.loads(value)
elif type(value) != self._python_type:
o = cast(DataClassJsonMixin, self._python_type).from_json(value)
else:
o = value
return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type)
return self.convert_to_struct(ctx, param, value)
return Literal(scalar=self._converter.convert(value, self._python_type))

if self._literal_type.schema:
Expand All @@ -366,10 +416,15 @@ def convert_to_literal(
)

def convert(self, ctx, param, value) -> typing.Union[Literal, typing.Any]:
lit = self.convert_to_literal(ctx, param, value)
if not self._remote:
return TypeEngine.to_python_value(self._flyte_ctx, lit, self._python_type)
return lit
try:
lit = self.convert_to_literal(ctx, param, value)
if not self._remote:
return TypeEngine.to_python_value(self._flyte_ctx, lit, self._python_type)
return lit
except click.BadParameter:
raise
except Exception as e:
raise click.BadParameter(f"Failed to convert param {param}, {value} to {self._python_type}") from e


def to_click_option(
Expand All @@ -392,6 +447,13 @@ def to_click_option(
if literal_converter.is_bool() and not default_val:
default_val = False

if literal_var.type.simple == SimpleType.STRUCT:
if default_val:
if type(default_val) == dict or type(default_val) == list:
default_val = json.dumps(default_val)
else:
default_val = cast(DataClassJsonMixin, default_val).to_json()

return click.Option(
param_decls=[f"--{input_name}"],
type=literal_converter.click_type,
Expand Down
61 changes: 48 additions & 13 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import json
import os
import pathlib
import tempfile
import typing
from datetime import datetime, timedelta
from enum import Enum

import click
import mock
import pytest
import yaml
from click.testing import CliRunner

from flytekit import FlyteContextManager
Expand All @@ -22,6 +24,7 @@
DurationParamType,
FileParamType,
FlyteLiteralConverter,
JsonParamType,
get_entities_in_file,
run_command,
)
Expand Down Expand Up @@ -155,19 +158,19 @@ def test_union_type2(input):

def test_union_type_with_invalid_input():
runner = CliRunner()
with pytest.raises(ValueError, match="Failed to convert python type typing.Union"):
runner.invoke(
pyflyte.main,
[
"--verbose",
"run",
os.path.join(DIR_NAME, "workflow.py"),
"test_union2",
"--a",
"hello",
],
catch_exceptions=False,
)
result = runner.invoke(
pyflyte.main,
[
"--verbose",
"run",
os.path.join(DIR_NAME, "workflow.py"),
"test_union2",
"--a",
"hello",
],
catch_exceptions=False,
)
assert result.exit_code == 2


def test_get_entities_in_file():
Expand Down Expand Up @@ -223,6 +226,7 @@ def test_list_default_arguments(wf_path):
],
catch_exceptions=False,
)
print(result.stdout)
assert result.exit_code == 0


Expand Down Expand Up @@ -387,3 +391,34 @@ def test_datetime_type():
v = t.convert("now", None, None)
assert v.day == now.day
assert v.month == now.month


def test_json_type():
t = JsonParamType()
assert t.convert(value='{"a": "b"}', param=None, ctx=None) == {"a": "b"}

with pytest.raises(click.BadParameter):
t.convert(None, None, None)

# test that it loads a json file
with tempfile.NamedTemporaryFile("w", delete=False) as f:
json.dump({"a": "b"}, f)
f.flush()
assert t.convert(value=f.name, param=None, ctx=None) == {"a": "b"}

# test that if the file is not a valid json, it raises an error
with tempfile.NamedTemporaryFile("w", delete=False) as f:
f.write("asdf")
f.flush()
with pytest.raises(click.BadParameter):
t.convert(value=f.name, param="asdf", ctx=None)

# test if the file does not exist
with pytest.raises(click.BadParameter):
t.convert(value="asdf", param=None, ctx=None)

# test if the file is yaml and ends with .yaml it works correctly
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
yaml.dump({"a": "b"}, f)
f.flush()
assert t.convert(value=f.name, param=None, ctx=None) == {"a": "b"}

0 comments on commit 36fc151

Please sign in to comment.