Commit 86de85f

Refactor IR value creation in tests to use ir.Value for consistency and improved clarity
1 parent a94f8b9 commit 86de85f
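
This commit replaces the `ir.val(...)` helper calls in the test suite with direct `ir.Value(...)` construction, and switches the import from `from onnxscript import ir` to `import onnx_ir as ir`. The element type is now passed up front via `ir.TensorType` instead of being patched onto `.dtype` afterwards. A minimal sketch of the before/after pattern (assuming the `onnx_ir` package is installed; the shape here is illustrative):

```python
import onnx_ir as ir

# Before (as in the old tests): create the value, then patch the dtype.
#   v0 = ir.val(name="v0", shape=ir.Shape([2, 3, 4]))
#   v0.dtype = ir.DataType.BFLOAT16

# After: the ir.Value is fully specified at construction time.
v0 = ir.Value(
    name="v0",
    shape=ir.Shape([2, 3, 4]),
    type=ir.TensorType(ir.DataType.BFLOAT16),
)
assert v0.dtype == ir.DataType.BFLOAT16  # dtype is derived from the TensorType
```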

File tree

5 files changed: 56 additions, 29 deletions


onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py

Lines changed: 10 additions & 7 deletions
@@ -6,20 +6,23 @@
 import onnx
 import onnx.checker
 import onnx.shape_inference
+import onnx_ir as ir
 import onnxruntime

-from onnxscript import ir
 from onnxscript.rewriter.onnxruntime.bfloat16_utils import bfloat16_converter


 class Bfloat16ConversionTest(unittest.TestCase):
     def setUp(self) -> None:
-        self.v0 = ir.val(name="v0", shape=ir.Shape([2, 3, 4]))
-        self.v0.dtype = ir.DataType.BFLOAT16
-        self.v1 = ir.val(name="v1", shape=ir.Shape([2, 3, 4]))
-        self.v1.dtype = ir.DataType.BFLOAT16
-        self.v2 = ir.val(name="v2", shape=ir.Shape([2, 3, 4]))
-        self.v2.dtype = ir.DataType.BFLOAT16
+        self.v0 = ir.Value(
+            name="v0", shape=ir.Shape([2, 3, 4]), type=ir.TensorType(ir.DataType.BFLOAT16)
+        )
+        self.v1 = ir.Value(
+            name="v1", shape=ir.Shape([2, 3, 4]), type=ir.TensorType(ir.DataType.BFLOAT16)
+        )
+        self.v2 = ir.Value(
+            name="v2", shape=ir.Shape([2, 3, 4]), type=ir.TensorType(ir.DataType.BFLOAT16)
+        )

         self.add_node = ir.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1)
         self.add_node.outputs[0].dtype = ir.DataType.BFLOAT16

onnxscript/rewriter/rules/common/_basic_rules_test.py

Lines changed: 12 additions & 6 deletions
@@ -8,11 +8,11 @@
 import numpy as np
 import onnx
 import onnx.reference
+import onnx_ir as ir
 import parameterized

 import onnxscript
 import onnxscript.onnx_types as ot
-from onnxscript import ir
 from onnxscript.onnx_opset import opset18
 from onnxscript.rewriter import MatchingTracer, testing
 from onnxscript.rewriter import pattern as orp
@@ -421,14 +421,18 @@ def _convert_shape(shape, name):
             if isinstance(shape, np.ndarray):
                 shape = tape.initializer(ir.Tensor(shape, name=name))
             elif isinstance(shape, (list, tuple)):
-                shape = ir.val(name, ir.DataType.INT64, ir.Shape(shape))
+                shape = ir.Value(
+                    name=name, type=ir.TensorType(ir.DataType.INT64), shape=ir.Shape(shape)
+                )
                 tape.graph_like.inputs.append(shape)
             else:
                 raise TypeError(f"Unsupported type {type(shape)} for shape.")
             return shape

-        x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape))
-        y = ir.val("Y", ir.DataType.FLOAT)
+        x = ir.Value(
+            name="X", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape(input_shape)
+        )
+        y = ir.Value(name="Y", type=ir.TensorType(ir.DataType.FLOAT))
         tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))

         # Build the graph.
@@ -554,8 +558,10 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg):
 class Flatten2ReshapeTest(unittest.TestCase):
     @staticmethod
     def create_model(input_shape, axis=1):
-        x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape))
-        y = ir.val("Y", ir.DataType.FLOAT)
+        x = ir.Value(
+            name="X", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape(input_shape)
+        )
+        y = ir.Value(name="Y", type=ir.TensorType(ir.DataType.FLOAT))
         tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))

         # Build the graph.
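
These helpers build test graphs with `ir.tape.Tape`, whose constructor takes an `ir.Graph` seeded with the input and output values. A hedged sketch of that wiring, using only calls visible in this diff (the Identity op and the shapes are illustrative):

```python
import onnx_ir as ir

# Graph inputs/outputs are now plain ir.Value objects created up front.
x = ir.Value(name="X", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([2, 3]))
y = ir.Value(name="Y", type=ir.TensorType(ir.DataType.FLOAT))

# The Tape wraps the graph and records each emitted op into it.
tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))
out = tape.op("Identity", inputs=[x])  # returns the op's output ir.Value
```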

onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py

Lines changed: 4 additions & 4 deletions
@@ -61,13 +61,13 @@ def build_model(

         # Register operations in the tape
         idtype = ir.DataType.UINT8 if op_type == "ConvInteger" else ir.DataType.FLOAT
-        x = ir.val("X", shape=input_shape, type=ir.TensorType(idtype))
+        x = ir.Value(name="X", shape=input_shape, type=ir.TensorType(idtype))
         y = tape.op("Pad", inputs=[x, *pad_inputs], attributes=pad_attributes)
         y = tape.op(
             op_type,
             inputs=[y, self.get_conv_weights(weight_shape, tape)],
             attributes=conv_attributes,
-            output=ir.val("Y", shape=output_shape, type=ir.TensorType(x.dtype)),
+            output=ir.Value(name="Y", shape=output_shape, type=ir.TensorType(x.dtype)),
         )
         if op_type == "ConvInteger":
             y.dtype = ir.DataType.INT32
@@ -290,12 +290,12 @@ def build_model(
             raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.")

         # Register operations in the tape
-        x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
+        x = ir.Value(name="X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
         y = tape.op(
             "Conv",
             inputs=[x, *conv_inputs],
             attributes=conv_attributes,
-            output=ir.val("Y", shape=output_shape, type=x.type),
+            output=ir.Value(name="Y", shape=output_shape, type=x.type),
         )

         # Build the model
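
Note that `tape.op` also accepts a pre-constructed `output=ir.Value(...)`, which these tests use to pin the output's name, shape, and type rather than letting the tape generate them. A small sketch (the Relu op and shapes are illustrative):

```python
import onnx_ir as ir

x = ir.Value(name="X", shape=ir.Shape([1, 3, 8, 8]), type=ir.TensorType(ir.DataType.FLOAT))
tape = ir.tape.Tape(ir.Graph([x], [], nodes=[], opset_imports={"": 20}))

# Supplying output= pins the result's name, shape, and type explicitly.
y = tape.op(
    "Relu",
    inputs=[x],
    output=ir.Value(name="Y", shape=ir.Shape([1, 3, 8, 8]), type=x.type),
)
assert y.name == "Y"  # the returned value is the one supplied via output=
```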

onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py

Lines changed: 8 additions & 6 deletions
@@ -5,10 +5,10 @@

 import numpy as np
 import onnx
+import onnx_ir as ir
 from onnx_ir.passes.common import onnx_checker, shape_inference
 from parameterized import parameterized

-from onnxscript import ir
 from onnxscript.rewriter import MatchingTracer, MatchStatus, testing
 from onnxscript.rewriter.rules.common import _matmul_add_to_gemm

@@ -46,10 +46,10 @@ def get_test_model(
         bias_shape = weight_shape[0] if transB else weight_shape[-1]
         output_shape = ir.Shape(("?",) * input_shape.rank())

-        x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
+        x = ir.Value(name="X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))

         if weight_as_inputs:
-            w = ir.val("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
+            w = ir.Value(name="W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
             inputs.append(w)
         else:
             w = ir.tensor(
@@ -58,8 +58,8 @@ def get_test_model(
             w = tape.initializer(w)

         if bias_as_inputs:
-            b = ir.val(
-                "B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT)
+            b = ir.Value(
+                name="B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT)
             )
             inputs.append(b)
         else:
@@ -77,7 +77,9 @@ def get_test_model(
         y = tape.op(
             "Add",
             inputs=[y, b],
-            output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)),
+            output=ir.Value(
+                name="Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)
+            ),
         )

         # Build the model
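
For the weight and bias, this helper takes two paths: when they are graph inputs it creates typed `ir.Value` placeholders, otherwise it registers concrete tensors through `tape.initializer(...)`. A hedged sketch of both paths (the array contents, shape, and name are illustrative):

```python
import numpy as np
import onnx_ir as ir

tape = ir.tape.Tape(ir.Graph([], [], nodes=[], opset_imports={"": 20}))

# Path 1: weight as a graph input, i.e. a typed placeholder with no data.
w_input = ir.Value(name="W", shape=ir.Shape([4, 8]), type=ir.TensorType(ir.DataType.FLOAT))

# Path 2: weight as an initializer, i.e. a concrete tensor stored in the graph.
w_init = tape.initializer(ir.tensor(np.zeros((4, 8), dtype=np.float32), name="W"))
```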

onnxscript/rewriter/rules/common/_remove_zero_bias_test.py

Lines changed: 22 additions & 6 deletions
@@ -61,9 +61,17 @@ def _apply_rule_and_check_optimization(
             assert attr_name in target_node.attributes, f"Attribute {attr_name} was lost"
             original_value = attr_value.value
             new_value = target_node.attributes[attr_name].value
-            assert new_value == original_value, (
-                f"Attribute {attr_name} value changed from {original_value} to {new_value}"
-            )
+            # Convert both to same type for comparison to handle list vs tuple differences
+            if isinstance(original_value, (list, tuple)) and isinstance(
+                new_value, (list, tuple)
+            ):
+                assert list(original_value) == list(new_value), (
+                    f"Attribute {attr_name} value changed from {original_value} to {new_value}"
+                )
+            else:
+                assert new_value == original_value, (
+                    f"Attribute {attr_name} value changed from {original_value} to {new_value}"
+                )

         # Check specific expected attributes if provided
         if expected_attributes:
@@ -73,9 +81,17 @@ def _apply_rule_and_check_optimization(
                 )
                 actual_attr = target_node.attributes[attr_name]
                 actual_value = actual_attr.value
-                assert actual_value == expected_value, (
-                    f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
-                )
+                # Convert both to same type for comparison to handle list vs tuple differences
+                if isinstance(actual_value, (list, tuple)) and isinstance(
+                    expected_value, (list, tuple)
+                ):
+                    assert list(actual_value) == list(expected_value), (
+                        f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
+                    )
+                else:
+                    assert actual_value == expected_value, (
+                        f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
+                    )

         # Compare outputs to ensure correctness (only for supported input types)
         if expected_count > 0:
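
The branching added here works around the fact that Python never treats a list as equal to a tuple, even with identical elements, so attribute values that round-trip through the IR as different sequence types would fail a naive `==` check:

```python
# A list and a tuple with the same elements compare unequal in Python,
# which is why both sides are coerced to list before asserting equality.
stored = [3, 3]      # e.g. a kernel_shape attribute kept as a list
restored = (3, 3)    # the same values read back as a tuple
assert stored != restored                 # naive comparison would fail
assert list(stored) == list(restored)     # normalized comparison passes
```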
