Skip to content

Commit f759a3c

Browse files
Async type engine (#2752)
Signed-off-by: Yee Hing Tong <[email protected]>
1 parent 11585d1 commit f759a3c

23 files changed: +429 -127 lines changed

Diff for: flytekit/core/array_node_map_task.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
from flytekit.tools.module_loader import load_object_from_module
3030
from flytekit.types.pickle import pickle
3131
from flytekit.types.pickle.pickle import FlytePickleTransformer
32+
from flytekit.utils.asyn import loop_manager
3233

3334

3435
class ArrayNodeMapTask(PythonTask):
@@ -253,7 +254,7 @@ def _literal_map_to_python_input(
253254
v = literal_map.literals[k]
254255
# If the input is offloaded, we need to unwrap it
255256
if v.offloaded_metadata:
256-
v = TypeEngine.unwrap_offloaded_literal(ctx, v)
257+
v = loop_manager.run_sync(TypeEngine.unwrap_offloaded_literal, ctx, v)
257258
if k not in self.bound_inputs:
258259
# assert that v.collection is not None
259260
if not v.collection or not isinstance(v.collection.literals, list):

Diff for: flytekit/core/base_task.py

+25-13
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,7 @@
8686
from flytekit.models.documentation import Description, Documentation
8787
from flytekit.models.interface import Variable
8888
from flytekit.models.security import SecurityContext
89+
from flytekit.utils.asyn import run_sync
8990

9091
DYNAMIC_PARTITIONS = "_uap"
9192
MODEL_CARD = "_ucm"
@@ -608,7 +609,7 @@ def _literal_map_to_python_input(
608609
) -> Dict[str, Any]:
609610
return TypeEngine.literal_map_to_kwargs(ctx, literal_map, self.python_interface.inputs)
610611

611-
def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteContext):
612+
async def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteContext):
612613
expected_output_names = list(self._outputs_interface.keys())
613614
if len(expected_output_names) == 1:
614615
# Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
@@ -629,27 +630,35 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte
629630
with timeit("Translate the output to literals"):
630631
literals = {}
631632
omt = ctx.output_metadata_tracker
633+
# Here is where we iterate through the outputs, need to call new type engine.
632634
for i, (k, v) in enumerate(native_outputs_as_map.items()):
633635
literal_type = self._outputs_interface[k].type
634636
py_type = self.get_type_for_output_var(k, v)
635637

636638
if isinstance(v, tuple):
637639
raise TypeError(f"Output({k}) in task '{self.name}' received a tuple {v}, instead of {py_type}")
638-
try:
639-
lit = TypeEngine.to_literal(ctx, v, py_type, literal_type)
640-
literals[k] = lit
641-
except Exception as e:
640+
literals[k] = asyncio.create_task(TypeEngine.async_to_literal(ctx, v, py_type, literal_type))
641+
642+
await asyncio.gather(*literals.values(), return_exceptions=True)
643+
644+
for i, (k2, v2) in enumerate(literals.items()):
645+
if v2.exception() is not None:
642646
# only show the name of output key if it's user-defined (by default Flyte names these as "o<n>")
643-
key = k if k != f"o{i}" else i
647+
key = k2 if k2 != f"o{i}" else i
648+
e: BaseException = v2.exception() # type: ignore # we know this is not optional
649+
py_type = self.get_type_for_output_var(k2, native_outputs_as_map[k2])
644650
e.args = (
645651
f"Failed to convert outputs of task '{self.name}' at position {key}.\n"
646652
f"Failed to convert type {type(native_outputs_as_map[expected_output_names[i]])} to type {py_type}.\n"
647653
f"Error Message: {e.args[0]}.",
648654
)
649-
raise
650-
# Now check if there is any output metadata associated with this output variable and attach it to the
651-
# literal
652-
if omt is not None:
655+
raise e
656+
literals[k2] = v2.result()
657+
658+
if omt is not None:
659+
for i, (k, v) in enumerate(native_outputs_as_map.items()):
660+
# Now check if there is any output metadata associated with this output variable and attach it to the
661+
# literal
653662
om = omt.get(v)
654663
if om:
655664
metadata = {}
@@ -669,7 +678,7 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte
669678
encoded = b64encode(s).decode("utf-8")
670679
metadata[DYNAMIC_PARTITIONS] = encoded
671680
if metadata:
672-
lit.set_metadata(metadata)
681+
literals[k].set_metadata(metadata) # type: ignore # we know these have been resolved
673682

674683
return _literal_models.LiteralMap(literals=literals), native_outputs_as_map
675684

@@ -697,7 +706,7 @@ def _write_decks(self, native_inputs, native_outputs_as_map, ctx, new_user_param
697706
async def _async_execute(self, native_inputs, native_outputs, ctx, exec_ctx, new_user_params):
698707
native_outputs = await native_outputs
699708
native_outputs = self.post_execute(new_user_params, native_outputs)
700-
literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx)
709+
literals_map, native_outputs_as_map = await self._output_to_literal_map(native_outputs, exec_ctx)
701710
self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params)
702711
return literals_map
703712

@@ -787,7 +796,10 @@ def dispatch_execute(
787796
return native_outputs
788797

789798
try:
790-
literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx)
799+
with timeit("dispatch execute"):
800+
literals_map, native_outputs_as_map = run_sync(
801+
self._output_to_literal_map, native_outputs, exec_ctx
802+
)
791803
self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params)
792804
except (FlyteUploadDataException, FlyteDownloadDataException):
793805
raise

Diff for: flytekit/core/promise.py

+24-14
Original file line number | Diff line number | Diff line change
@@ -44,9 +44,10 @@
4444
from flytekit.models.literals import Binary, Literal, Primitive, Scalar
4545
from flytekit.models.task import Resources
4646
from flytekit.models.types import SimpleType
47+
from flytekit.utils.asyn import loop_manager, run_sync
4748

4849

49-
def translate_inputs_to_literals(
50+
async def _translate_inputs_to_literals(
5051
ctx: FlyteContext,
5152
incoming_values: Dict[str, Any],
5253
flyte_interface_types: Dict[str, _interface_models.Variable],
@@ -94,16 +95,19 @@ def my_wf(in1: int, in2: int) -> int:
9495
t = native_types[k]
9596
try:
9697
if type(v) is Promise:
97-
v = resolve_attr_path_in_promise(v)
98-
result[k] = TypeEngine.to_literal(ctx, v, t, var.type)
98+
v = await resolve_attr_path_in_promise(v)
99+
result[k] = await TypeEngine.async_to_literal(ctx, v, t, var.type)
99100
except TypeTransformerFailedError as exc:
100101
exc.args = (f"Failed argument '{k}': {exc.args[0]}",)
101102
raise
102103

103104
return result
104105

105106

106-
def resolve_attr_path_in_promise(p: Promise) -> Promise:
107+
translate_inputs_to_literals = loop_manager.synced(_translate_inputs_to_literals)
108+
109+
110+
async def resolve_attr_path_in_promise(p: Promise) -> Promise:
107111
"""
108112
resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value
109113
This is for local execution only. The remote execution will be resolved in flytepropeller.
@@ -145,7 +149,9 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
145149
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
146150
literal_type = TypeEngine.to_literal_type(type(new_st))
147151
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
148-
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)
152+
curr_val = await TypeEngine.async_to_literal(
153+
FlyteContextManager.current_context(), new_st, type(new_st), literal_type
154+
)
149155
elif type(curr_val.value.value) is Binary:
150156
binary_idl_obj = curr_val.value.value
151157
if binary_idl_obj.tag == _common_constants.MESSAGEPACK:
@@ -786,7 +792,7 @@ def __rshift__(self, other: Any):
786792
return Output(*promises) # type: ignore
787793

788794

789-
def binding_data_from_python_std(
795+
async def binding_data_from_python_std(
790796
ctx: _flyte_context.FlyteContext,
791797
expected_literal_type: _type_models.LiteralType,
792798
t_value: Any,
@@ -821,7 +827,8 @@ def binding_data_from_python_std(
821827
# If the value is not a container type, then we can directly convert it to a scalar in the Union case.
822828
# This pushes the handling of the Union types to the type engine.
823829
if not isinstance(t_value, list) and not isinstance(t_value, dict):
824-
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar
830+
lit = await TypeEngine.async_to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type)
831+
scalar = lit.scalar
825832
return _literals_models.BindingData(scalar=scalar)
826833

827834
# If it is a container type, then we need to iterate over the variants in the Union type, try each one. This is
@@ -831,7 +838,7 @@ def binding_data_from_python_std(
831838
try:
832839
lt_type = expected_literal_type.union_type.variants[i]
833840
python_type = get_args(t_value_type)[i] if t_value_type else None
834-
return binding_data_from_python_std(ctx, lt_type, t_value, python_type, nodes)
841+
return await binding_data_from_python_std(ctx, lt_type, t_value, python_type, nodes)
835842
except Exception:
836843
logger.debug(
837844
f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}."
@@ -844,7 +851,9 @@ def binding_data_from_python_std(
844851
sub_type: Optional[type] = ListTransformer.get_sub_type_or_none(t_value_type)
845852
collection = _literals_models.BindingDataCollection(
846853
bindings=[
847-
binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type or type(t), nodes)
854+
await binding_data_from_python_std(
855+
ctx, expected_literal_type.collection_type, t, sub_type or type(t), nodes
856+
)
848857
for t in t_value
849858
]
850859
)
@@ -860,13 +869,13 @@ def binding_data_from_python_std(
860869
f"this should be a Dictionary type and it is not: {type(t_value)} vs {expected_literal_type}"
861870
)
862871
if expected_literal_type.simple == _type_models.SimpleType.STRUCT:
863-
lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type)
872+
lit = await TypeEngine.async_to_literal(ctx, t_value, type(t_value), expected_literal_type)
864873
return _literals_models.BindingData(scalar=lit.scalar)
865874
else:
866875
_, v_type = DictTransformer.extract_types_or_metadata(t_value_type)
867876
m = _literals_models.BindingDataMap(
868877
bindings={
869-
k: binding_data_from_python_std(
878+
k: await binding_data_from_python_std(
870879
ctx, expected_literal_type.map_value_type, v, v_type or type(v), nodes
871880
)
872881
for k, v in t_value.items()
@@ -883,8 +892,8 @@ def binding_data_from_python_std(
883892
)
884893

885894
# This is the scalar case - e.g. my_task(in1=5)
886-
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar
887-
return _literals_models.BindingData(scalar=scalar)
895+
lit = await TypeEngine.async_to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type)
896+
return _literals_models.BindingData(scalar=lit.scalar)
888897

889898

890899
def binding_from_python_std(
@@ -895,7 +904,8 @@ def binding_from_python_std(
895904
t_value_type: type,
896905
) -> Tuple[_literals_models.Binding, List[Node]]:
897906
nodes: List[Node] = []
898-
binding_data = binding_data_from_python_std(
907+
binding_data = run_sync(
908+
binding_data_from_python_std,
899909
ctx,
900910
expected_literal_type,
901911
t_value,

0 commit comments

Comments
 (0)