diff --git a/src/relax/transform/expand_matmul_of_sum.cc b/src/relax/transform/expand_matmul_of_sum.cc
index 76ebae94982d..906620563450 100644
--- a/src/relax/transform/expand_matmul_of_sum.cc
+++ b/src/relax/transform/expand_matmul_of_sum.cc
@@ -34,6 +34,7 @@
 
 #include "../op/tensor/binary.h"
 #include "../op/tensor/linear_algebra.h"
+#include "../op/tensor/manipulate.h"
 
 namespace tvm {
 namespace relax {
@@ -49,7 +50,11 @@ std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreateP
 
   auto pat_rhs_a = WildcardPattern();
   auto pat_rhs_b = WildcardPattern();
-  auto pat_rhs = IsOp("relax.add")(pat_rhs_a, pat_rhs_b);
+  auto pat_rhs_sum = IsOp("relax.add")(pat_rhs_a, pat_rhs_b);
+
+  auto pat_rhs_permute_dims = IsOp("relax.permute_dims")(pat_rhs_sum);
+
+  auto pat_rhs = pat_rhs_sum | pat_rhs_permute_dims;
 
   auto pat_matmul = IsOp("relax.matmul")(pat_lhs, pat_rhs);
 
@@ -72,6 +77,17 @@ std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreateP
       return expr;
     }
 
+    if (matches.count(pat_rhs_permute_dims)) {
+      auto call_permute = Downcast<Call>(matches[pat_rhs_permute_dims]);
+      auto attrs = call_permute->attrs.as<PermuteDimsAttrs>();
+      ICHECK(attrs) << "Operator permute_dims should have PermuteDimsAttrs, "
+                    << "but " << call_permute << " has attributes " << call_permute->attrs;
+      auto axes = attrs->axes;
+
+      rhs_a = permute_dims(rhs_a, axes);
+      rhs_b = permute_dims(rhs_b, axes);
+    }
+
     return add(matmul(lhs, rhs_a, DataType::Void()), matmul(lhs, rhs_b, DataType::Void()));
   };
 
diff --git a/tests/python/relax/test_transform_expand_matmul_of_sum.py b/tests/python/relax/test_transform_expand_matmul_of_sum.py
index 67e59225c5ed..b380d1584229 100644
--- a/tests/python/relax/test_transform_expand_matmul_of_sum.py
+++ b/tests/python/relax/test_transform_expand_matmul_of_sum.py
@@ -123,5 +123,35 @@ def main(
             return out
 
 
+class TestRHSPermuteDims(Base):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([16], "float32"),
+            A: R.Tensor([32, 16], "float32"),
+            B: R.Tensor([32, 16], "float32"),
+        ) -> R.Tensor([32], "float32"):
+            linear_weight = R.add(A, B)
+            matmul_weight = R.permute_dims(linear_weight)
+            out = R.matmul(x, matmul_weight)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([16], "float32"),
+            A: R.Tensor([32, 16], "float32"),
+            B: R.Tensor([32, 16], "float32"),
+        ) -> R.Tensor([32], "float32"):
+            A_transpose = R.permute_dims(A)
+            lhs = R.matmul(x, A_transpose)
+            B_transpose = R.permute_dims(B)
+            rhs = R.matmul(x, B_transpose)
+            out = R.add(lhs, rhs)
+            return out
+
+
 if __name__ == "__main__":
     tvm.testing.main()