@@ -35,20 +35,36 @@ def test_pass_with_clear_metadata_and_docstring(self):
3535 metadata_props = {"mul_key" : "mul_value" },
3636 doc_string = "This is a Mul node" ,
3737 )
38- func_inputs = [
39- ir .Value (
40- name = "input_a" , type = ir .TensorType (ir .DataType .FLOAT ), shape = ir .Shape ((2 , 3 ))
41- ),
42- ir .Value (
43- name = "input_b" , type = ir .TensorType (ir .DataType .FLOAT ), shape = ir .Shape ((2 , 3 ))
44- ),
45- ]
4638 function = ir .Function (
4739 graph = ir .Graph (
4840 name = "my_function" ,
49- inputs = func_inputs ,
50- outputs = mul_node .outputs ,
51- nodes = [add_node , mul_node ],
41+ inputs = [
42+ input_a := ir .Value (
43+ name = "input_a" ,
44+ type = ir .TensorType (ir .DataType .FLOAT ),
45+ shape = ir .Shape ((2 , 3 )),
46+ ),
47+ input_b := ir .Value (
48+ name = "input_b" ,
49+ type = ir .TensorType (ir .DataType .FLOAT ),
50+ shape = ir .Shape ((2 , 3 )),
51+ ),
52+ ],
53+ nodes = [
54+ add_node_func := ir .node (
55+ "Add" ,
56+ inputs = [input_a , input_b ],
57+ metadata_props = {"add_key" : "add_value" },
58+ doc_string = "This is an Add node" ,
59+ ),
60+ mul_node_func := ir .node (
61+ "Mul" ,
62+ inputs = [add_node_func .o (), input_b ],
63+ metadata_props = {"mul_key" : "mul_value" },
64+ doc_string = "This is a Mul node" ,
65+ ),
66+ ],
67+ outputs = mul_node_func .outputs ,
5268 opset_imports = {"" : 20 },
5369 doc_string = "This is a function docstring" ,
5470 metadata_props = {"function_key" : "function_value" },
@@ -59,8 +75,8 @@ def test_pass_with_clear_metadata_and_docstring(self):
5975 )
6076 func_node = ir .node (
6177 "my_function" ,
62- inputs = [add_node .o (), inputs [ 1 ] ],
63- domain = "my_domain" ,
78+ inputs = [inputs [ 0 ], mul_node .o ()],
79+ domain = "my_domain" ,
6480 metadata_props = {"mul_key" : "mul_value" },
6581 doc_string = "This is a Mul node" ,
6682 )
@@ -77,7 +93,7 @@ def test_pass_with_clear_metadata_and_docstring(self):
7793 )
7894 sub_node = ir .node (
7995 "Sub" ,
80- inputs = [function .o (), const_node .o ()],
96+ inputs = [func_node .o (), const_node .o ()],
8197 num_outputs = 1 ,
8298 metadata_props = {"sub_key" : "sub_value" },
8399 doc_string = "This is a Sub node" ,
0 commit comments