44
44
from flytekit .models .literals import Binary , Literal , Primitive , Scalar
45
45
from flytekit .models .task import Resources
46
46
from flytekit .models .types import SimpleType
47
+ from flytekit .utils .asyn import loop_manager , run_sync
47
48
48
49
49
- def translate_inputs_to_literals (
50
+ async def _translate_inputs_to_literals (
50
51
ctx : FlyteContext ,
51
52
incoming_values : Dict [str , Any ],
52
53
flyte_interface_types : Dict [str , _interface_models .Variable ],
@@ -94,16 +95,19 @@ def my_wf(in1: int, in2: int) -> int:
94
95
t = native_types [k ]
95
96
try :
96
97
if type (v ) is Promise :
97
- v = resolve_attr_path_in_promise (v )
98
- result [k ] = TypeEngine .to_literal (ctx , v , t , var .type )
98
+ v = await resolve_attr_path_in_promise (v )
99
+ result [k ] = await TypeEngine .async_to_literal (ctx , v , t , var .type )
99
100
except TypeTransformerFailedError as exc :
100
101
exc .args = (f"Failed argument '{ k } ': { exc .args [0 ]} " ,)
101
102
raise
102
103
103
104
return result
104
105
105
106
106
- def resolve_attr_path_in_promise (p : Promise ) -> Promise :
107
+ translate_inputs_to_literals = loop_manager .synced (_translate_inputs_to_literals )
108
+
109
+
110
+ async def resolve_attr_path_in_promise (p : Promise ) -> Promise :
107
111
"""
108
112
resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value
109
113
This is for local execution only. The remote execution will be resolved in flytepropeller.
@@ -145,7 +149,9 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
145
149
new_st = resolve_attr_path_in_pb_struct (st , attr_path = p .attr_path [used :])
146
150
literal_type = TypeEngine .to_literal_type (type (new_st ))
147
151
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
148
- curr_val = TypeEngine .to_literal (FlyteContextManager .current_context (), new_st , type (new_st ), literal_type )
152
+ curr_val = await TypeEngine .async_to_literal (
153
+ FlyteContextManager .current_context (), new_st , type (new_st ), literal_type
154
+ )
149
155
elif type (curr_val .value .value ) is Binary :
150
156
binary_idl_obj = curr_val .value .value
151
157
if binary_idl_obj .tag == _common_constants .MESSAGEPACK :
@@ -786,7 +792,7 @@ def __rshift__(self, other: Any):
786
792
return Output (* promises ) # type: ignore
787
793
788
794
789
- def binding_data_from_python_std (
795
+ async def binding_data_from_python_std (
790
796
ctx : _flyte_context .FlyteContext ,
791
797
expected_literal_type : _type_models .LiteralType ,
792
798
t_value : Any ,
@@ -821,7 +827,8 @@ def binding_data_from_python_std(
821
827
# If the value is not a container type, then we can directly convert it to a scalar in the Union case.
822
828
# This pushes the handling of the Union types to the type engine.
823
829
if not isinstance (t_value , list ) and not isinstance (t_value , dict ):
824
- scalar = TypeEngine .to_literal (ctx , t_value , t_value_type or type (t_value ), expected_literal_type ).scalar
830
+ lit = await TypeEngine .async_to_literal (ctx , t_value , t_value_type or type (t_value ), expected_literal_type )
831
+ scalar = lit .scalar
825
832
return _literals_models .BindingData (scalar = scalar )
826
833
827
834
# If it is a container type, then we need to iterate over the variants in the Union type, try each one. This is
@@ -831,7 +838,7 @@ def binding_data_from_python_std(
831
838
try :
832
839
lt_type = expected_literal_type .union_type .variants [i ]
833
840
python_type = get_args (t_value_type )[i ] if t_value_type else None
834
- return binding_data_from_python_std (ctx , lt_type , t_value , python_type , nodes )
841
+ return await binding_data_from_python_std (ctx , lt_type , t_value , python_type , nodes )
835
842
except Exception :
836
843
logger .debug (
837
844
f"failed to bind data { t_value } with literal type { expected_literal_type .union_type .variants [i ]} ."
@@ -844,7 +851,9 @@ def binding_data_from_python_std(
844
851
sub_type : Optional [type ] = ListTransformer .get_sub_type_or_none (t_value_type )
845
852
collection = _literals_models .BindingDataCollection (
846
853
bindings = [
847
- binding_data_from_python_std (ctx , expected_literal_type .collection_type , t , sub_type or type (t ), nodes )
854
+ await binding_data_from_python_std (
855
+ ctx , expected_literal_type .collection_type , t , sub_type or type (t ), nodes
856
+ )
848
857
for t in t_value
849
858
]
850
859
)
@@ -860,13 +869,13 @@ def binding_data_from_python_std(
860
869
f"this should be a Dictionary type and it is not: { type (t_value )} vs { expected_literal_type } "
861
870
)
862
871
if expected_literal_type .simple == _type_models .SimpleType .STRUCT :
863
- lit = TypeEngine .to_literal (ctx , t_value , type (t_value ), expected_literal_type )
872
+ lit = await TypeEngine .async_to_literal (ctx , t_value , type (t_value ), expected_literal_type )
864
873
return _literals_models .BindingData (scalar = lit .scalar )
865
874
else :
866
875
_ , v_type = DictTransformer .extract_types_or_metadata (t_value_type )
867
876
m = _literals_models .BindingDataMap (
868
877
bindings = {
869
- k : binding_data_from_python_std (
878
+ k : await binding_data_from_python_std (
870
879
ctx , expected_literal_type .map_value_type , v , v_type or type (v ), nodes
871
880
)
872
881
for k , v in t_value .items ()
@@ -883,8 +892,8 @@ def binding_data_from_python_std(
883
892
)
884
893
885
894
# This is the scalar case - e.g. my_task(in1=5)
886
- scalar = TypeEngine .to_literal (ctx , t_value , t_value_type or type (t_value ), expected_literal_type ). scalar
887
- return _literals_models .BindingData (scalar = scalar )
895
+ lit = await TypeEngine .async_to_literal (ctx , t_value , t_value_type or type (t_value ), expected_literal_type )
896
+ return _literals_models .BindingData (scalar = lit . scalar )
888
897
889
898
890
899
def binding_from_python_std (
@@ -895,7 +904,8 @@ def binding_from_python_std(
895
904
t_value_type : type ,
896
905
) -> Tuple [_literals_models .Binding , List [Node ]]:
897
906
nodes : List [Node ] = []
898
- binding_data = binding_data_from_python_std (
907
+ binding_data = run_sync (
908
+ binding_data_from_python_std ,
899
909
ctx ,
900
910
expected_literal_type ,
901
911
t_value ,
0 commit comments