diff --git a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
index 5b02a5e88a2829..a08f2fe7d95910 100644
--- a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
@@ -241,6 +241,9 @@ void MatmulGradKernel(const Context& dev_ctx,
         dev_ctx, phi::IntArray(common::vectorize(x.dims())), 0, dx);
     return;
   }
+  if (!transpose_x && transpose_y && y.dims().size() < 2) {
+    transpose_y = false;
+  }
   // get dims
   std::vector<std::int64_t> x_dims = common::vectorize(x.dims());
   std::vector<std::int64_t> y_dims = common::vectorize(y.dims());
diff --git a/paddle/phi/kernels/xpu/matmul_grad_kernel.cc b/paddle/phi/kernels/xpu/matmul_grad_kernel.cc
index 33fcab2308ec41..349a72277e7aa0 100644
--- a/paddle/phi/kernels/xpu/matmul_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/matmul_grad_kernel.cc
@@ -37,6 +37,9 @@ void MatmulGradKernel(const Context& dev_ctx,
     dev_ctx.template Alloc<T>(dy);
   }
 
+  if (!transpose_x && transpose_y && y.dims().size() < 2) {
+    transpose_y = false;
+  }
   const XPUType* dout_ptr = reinterpret_cast<const XPUType*>(dout.data<T>());
   const XPUType* x_ptr = reinterpret_cast<const XPUType*>(x.data<T>());
   const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
diff --git a/test/legacy_test/test_matmul_v2_op.py b/test/legacy_test/test_matmul_v2_op.py
index b3c25f7cd7397a..2f0d9425000461 100644
--- a/test/legacy_test/test_matmul_v2_op.py
+++ b/test/legacy_test/test_matmul_v2_op.py
@@ -978,6 +978,50 @@ def init_input_output(self):
         self.y = np.random.random((1, 3, 3, 2))
 
 
+class TestMatMulOp_trans_y(TestMatMulV2Op):
+    # y is 1-D and trans_y is True
+    def config(self):
+        self.x_shape = (2, 100)
+        self.y_shape = (100,)
+        self.trans_x = False
+        self.trans_y = True
+
+    def init_kernel_type(self):
+        self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"
+
+    def setUp(self):
+        self.init_kernel_type()
+        self.config()
+        self.op_type = "matmul_v2"
+        self.python_api = paddle.tensor.matmul
+        self.public_python_api = paddle.tensor.matmul
+        x = np.random.random(self.x_shape).astype(self.dtype)
+        y = np.random.random(self.y_shape).astype(self.dtype)
+        # -0.1 ~ 0.1
+        x = -0.1 + 0.2 * x
+        y = -0.1 + 0.2 * y
+        result = reference_matmul(x, y, self.trans_x, self.trans_y)
+        result = result.astype(self.dtype)
+        self.inputs = {
+            'X': x,
+            'Y': y,
+        }
+        self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y}
+        self.outputs = {'Out': result}
+
+    def test_check_output(self):
+        self.check_output(
+            check_pir=True,
+        )
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()