Skip to content

Commit d3d7f87

Browse files
lhutton1pfk-beta
authored andcommitted
[AOT] Get input name from module/prim func (apache#10731)
The input name generated in each of these test cases changes depending on the version of tensorflow being used. v2.4 = "x_int8", while v2.6 = "x". Making these tests agnostic of input name so that they work with both v2.4 and v2.6. Change-Id: I843a655b3bf4e018624e5757c653b1d85058991e
1 parent dbd3834 commit d3d7f87

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

tests/python/relay/aot/test_c_device_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def compile_to_main_func(interface_api="c", use_unpacked_api=True):
133133
def test_device_api_hooks_unpacked_api(device_api_main_func):
134134
"""Check for Device API hooks with unpacked internal calls"""
135135
main_func = device_api_main_func(interface_api="c", use_unpacked_api=True)
136+
input_name = main_func.params[0].name
136137

137138
# Activate Device
138139
assert (
@@ -153,7 +154,7 @@ def test_device_api_hooks_unpacked_api(device_api_main_func):
153154
str(main_func.body[1][0][0][1])
154155
== "tir.tvm_check_return(0, -1, tir.call_extern("
155156
+ '"tvmgen_default_ethos_u_main_0",'
156-
+ " x_int8_buffer_var, output_buffer_var, device_context_ethos_u))\n"
157+
+ f" {input_name}_buffer_var, output_buffer_var, device_context_ethos_u))\n"
157158
)
158159
# Close Device
159160
assert (

tests/python/relay/aot/test_crt_aot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -905,7 +905,8 @@ def representative_dataset():
905905

906906
in_min, in_max = (-128, 127)
907907
data = np.random.randint(in_min, high=in_max, size=ifm_shape, dtype="int8")
908-
inputs = {"x_int8": data}
908+
input_name = mod["main"].params[0].name_hint
909+
inputs = {input_name: data}
909910
output_list = generate_ref_data(mod, inputs, params)
910911
compile_and_run(
911912
AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params),

0 commit comments

Comments
 (0)