diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py index b5f4bcde0..ddb42a31d 100644 --- a/onnxscript/optimizer/_optimizer.py +++ b/onnxscript/optimizer/_optimizer.py @@ -10,6 +10,7 @@ from onnxscript.rewriter import ( broadcast_to_matmul, cast_constant_of_shape, + collapse_slices, gemm_to_matmul_add, no_op, ) @@ -21,6 +22,7 @@ *broadcast_to_matmul.rules.rules, gemm_to_matmul_add.rule, *cast_constant_of_shape.rules.rules, + *collapse_slices.rules.rules, ] diff --git a/onnxscript/rewriter/collapse_slices.py b/onnxscript/rewriter/collapse_slices.py new file mode 100644 index 000000000..57d9baf28 --- /dev/null +++ b/onnxscript/rewriter/collapse_slices.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import logging + +from onnxscript import ir +from onnxscript.rewriter import pattern + +logger = logging.getLogger(__name__) +_INT64_MAX = 9223372036854775807 + + +def _check_if_redundant_slice( + context, + data: ir.Value, + starts: ir.Value, + ends: ir.Value, + axes: ir.Value, + steps: ir.Value, + **_, +) -> bool: + """If the starts is 0, and the ends is equal to or grater than the shape of the specified axis, then the slice is redundant.""" + del context # Reserved for future extensions + + starts_const = starts.const_value + ends_const = ends.const_value + axes_const = axes.const_value + steps_const = steps.const_value + + # Check if the values are scalar + if starts_const.numpy().size != 1: # type: ignore[union-attr] + logger.info("The value 'start' is not a scalar.") + return False + if ends_const.numpy().size != 1: # type: ignore[union-attr] + logger.info("The value 'end' is not a scalar.") + return False + if axes_const.numpy().size != 1: # type: ignore[union-attr] + logger.info("The value 'axis' is not a scalar.") + return False + if steps_const.numpy().size != 1: # type: ignore[union-attr] + logger.info("The value 'step' is not a scalar.") + return False + + if starts_const is None or ends_const is None or axes_const is None or steps_const is None: + logger.info("The value 'start', 'end', 'axis', 'step' is not statically known.") + return False + if steps_const.numpy().item() != 1: + logger.info("The value 'step' is not 1.") + return False + # starts is 0 + if starts_const.numpy().item() != 0: + logger.info("The value 'start' is not 0.") + return False + # In case data.shape is not statically known, we still can tell the slice is redundant if ends is sys.maxsize + if ends_const.numpy().item() == _INT64_MAX: + return True + if data.shape is None: + logger.info("The value 'data' shape is not statically known.") + return False + if ends_const.numpy().item() < data.shape[axes_const.numpy().item()]: + logger.info("The value 'end' is less than the shape of the specified axis.") + return False + + return True + + +def _identity_to_itself(op, data, **_): + """Return the input data as the output.""" + return op.Identity(data) + + +def _identity_to_updates(op, data, indices, updates, **_): + """Return the updates as the output. + + This is used when the ScatterND is redundant in terms of + updating the whole data with the updates. + + """ + return op.Identity(updates) + + +def _potential_redundant_slice(op, data, starts, ends, axes, steps): + """To identify a slice op""" + return op.Slice(data, starts, ends, axes, steps) + + +def _potential_redundant_scatternd(op, data, indices, updates): + """To identify a ScatterND op""" + return op.ScatterND(data, indices, updates) + + +def _check_if_redundant_scatternd( + context, + data: ir.Value, + indices: ir.Value, + updates: ir.Value, + **_, +): + """If the indices is the same length as the first dim of data, and the shape of updates is equal to data, we can simply swap the whole value.""" + del context # Reserved for future extensions + + # To validate data can be replaced directly by updates, we need to check the following: + # 1. they have the same shape + if data.shape is None: + logger.info("The value 'data' shape is not statically known.") + return False + if updates.shape is None: + logger.info("The value 'updates' shape is not statically known.") + return False + if data.shape != updates.shape: + logger.info("The shape of 'data' and 'updates' are different.") + return False + + # 2. the indices is referring to the whole data, which is from 0 to data.shape[0] + if indices.const_value is None: + logger.info("The value 'indices' is not statically known.") + return False + if indices.const_value.numpy().tolist() != [[i] for i in range(data.shape[0])]: # type: ignore[arg-type] + logger.info("The 'indices' is not referring to the whole data.") + return False + + return True + + +# Register the rewrite rules +remove_redundant_slice = pattern.RewriteRule( + _potential_redundant_slice, + _identity_to_itself, + _check_if_redundant_slice, +) + +remove_redundant_scatternd = pattern.RewriteRule( + _potential_redundant_scatternd, + _identity_to_updates, + _check_if_redundant_scatternd, +) + +# NOTE: The order of the rules is important. Larger pattern should be checked first. +rules = pattern.RewriteRuleSet([remove_redundant_slice, remove_redundant_scatternd]) diff --git a/onnxscript/rewriter/collapse_slices_test.py b/onnxscript/rewriter/collapse_slices_test.py new file mode 100644 index 000000000..22537934b --- /dev/null +++ b/onnxscript/rewriter/collapse_slices_test.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import numpy as np +import onnx.parser +import onnx.shape_inference + +from onnxscript import ir +from onnxscript.rewriter import collapse_slices, testing + +_INT64_MAX = 9223372036854775807 + + +class TwoReshapesMatMulReshapeTest(unittest.TestCase): + def test_slice_is_redundant_when_ends_is_greater_than_input_shape(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[512, 16, 112] data) => (float[512, 16, 112] output) + { + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = collapse_slices.rules.apply_to_model(model) + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 5) + self.assertIn("Identity", [node.op_type for node in model.graph]) + testing.assert_numerically_equal( + model_proto, + model, + (np.random.rand(512, 16, 112).astype(np.float32),), + ) + + def test_slice_is_redundant_when_ends_reaches_int64_max(self): + model_proto = onnx.parser.parse_model( + f""" + + agraph (float[512, 16, 112] data) => (float[512, 16, 112] output) + {{ + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + }} + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = collapse_slices.rules.apply_to_model(model) + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 5) + self.assertIn("Identity", [node.op_type for node in model.graph]) + testing.assert_numerically_equal( + model_proto, + model, + (np.random.rand(512, 16, 112).astype(np.float32),), + ) + + def test_scatternd_is_redundant_when_it_is_updating_the_whole_input_in_order(self): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[112, 16, 512] data, float[112, 16, 512] updates) => (float[112, 16, 512] output) + { + output = ScatterND (data, indices, updates) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + indices = np.arange(112).reshape(112, 1) + model = ir.serde.deserialize_model(model_proto) + # from numpy to ir.Tensor + indices_ir_tensor = ir.Tensor( + name="indices", + value=indices, + ) + # assign the tensor to a value + indices = model.graph[0].inputs[1] + indices.const_value = indices_ir_tensor + model.graph.initializers["indices"] = indices + original_model_proto = ir.serde.serialize_model(model) + + count = collapse_slices.rules.apply_to_model(model) + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 1) + self.assertIn("Identity", [node.op_type for node in model.graph]) + + input = np.random.rand(112, 16, 512).astype(np.float32) + testing.assert_numerically_equal(original_model_proto, model, (input, input)) diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/no_op.py index 21cee515d..6d25b0ed3 100644 --- a/onnxscript/rewriter/no_op.py +++ b/onnxscript/rewriter/no_op.py @@ -32,7 +32,7 @@ def dropout_inference(op, x): # Replacement -def identity(op, x): +def identity(op, x, **_): return op.Identity(x) diff --git a/onnxscript/rewriter/testing.py b/onnxscript/rewriter/testing.py new file mode 100644 index 000000000..95b815515 --- /dev/null +++ b/onnxscript/rewriter/testing.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Any + +import numpy as np +import onnx +import onnxruntime as ort + +from onnxscript import ir + + +def assert_numerically_equal( + original_model_proto: onnx.ModelProto | ir.Model, + rewritten_model_proto: onnx.ModelProto | ir.Model, + args: tuple[Any, ...], + rtol: float = 1, + atol: float = 1e-3, +): + """Assert that the two models are numerically equal. + + Args: + original_model_proto: The original model proto or ir.Model. + rewritten_model_proto: The rewritten by the rules model proto or ir.Model. + rtol: Relative tolerance. + atol: Absolute tolerance. + args: The positional arguments to pass to the model. + """ + + if isinstance(original_model_proto, ir.Model): + original_model_proto = ir.serde.serialize_model(original_model_proto) + if isinstance(rewritten_model_proto, ir.Model): + rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto) + + original_proto_ort_inputs = { + k.name: v for k, v in zip(original_model_proto.graph.input, args) + } + original_proto_ort_inference_session = _ort_session_initializer( + original_model_proto.SerializeToString() + ) + run_options = ort.RunOptions() + run_options.log_severity_level = 3 # 3: Error + original_outputs = original_proto_ort_inference_session.run( + None, original_proto_ort_inputs, run_options=run_options + ) + + the_rewritten_proto_ort_inputs = { + k.name: v for k, v in zip(rewritten_model_proto.graph.input, args) + } + the_rewritten_proto_ort_inference_session = _ort_session_initializer( + rewritten_model_proto.SerializeToString() + ) + the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run( + None, the_rewritten_proto_ort_inputs, run_options=run_options + ) + + np.testing.assert_allclose( + original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True + ) + + +def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession: + """Initialize an ONNX Runtime inference session with the specified model.""" + import onnxruntime as ort + + session_options = ort.SessionOptions() + session_options.log_severity_level = 3 # 3: Error + possible_providers = ( + "CUDAExecutionProvider", + "CPUExecutionProvider", + ) + available_providers = set(ort.get_available_providers()) + providers = [ + provider for provider in possible_providers if provider in available_providers + ] + return ort.InferenceSession(model, providers=providers, sess_options=session_options)