Skip to content

Commit bdba7c6

Browse files
authored
Fix (#73108)
1 parent 6d3db1f commit bdba7c6

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

paddle/phi/kernels/impl/bmm_grad_kernel_impl.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ void BmmGradKernel(const Context& dev_ctx,
5454
const DenseTensor& out_grad,
5555
DenseTensor* x_grad,
5656
DenseTensor* y_grad) {
57+
if (x_grad && x_grad->numel() == 0) {
58+
dev_ctx.template Alloc<T>(x_grad);
59+
phi::Full<T, Context>(
60+
dev_ctx, phi::IntArray(common::vectorize(y.dims())), 0, y_grad);
61+
return;
62+
}
63+
if (y_grad && y_grad->numel() == 0) {
64+
dev_ctx.template Alloc<T>(y_grad);
65+
phi::Full<T, Context>(
66+
dev_ctx, phi::IntArray(common::vectorize(x.dims())), 0, x_grad);
67+
return;
68+
}
5769
DenseTensor x_help = x;
5870
DenseTensor y_help = y;
5971
DenseTensor out_grad_help = out_grad;

paddle/phi/kernels/xpu/bmm_grad_kernel.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/phi/kernels/bmm_grad_kernel.h"
1616

17+
#include "paddle/phi/kernels/full_kernel.h"
1718
#include "paddle/phi/kernels/xpu/bmm_xpu_utils.h"
1819

1920
namespace phi {
@@ -60,6 +61,18 @@ void BmmGradKernel(const Context& dev_ctx,
6061
const DenseTensor& out_grad,
6162
DenseTensor* x_grad,
6263
DenseTensor* y_grad) {
64+
if (x_grad && x_grad->numel() == 0) {
65+
dev_ctx.template Alloc<T>(x_grad);
66+
phi::Full<T, Context>(
67+
dev_ctx, phi::IntArray(common::vectorize(y.dims())), 0, y_grad);
68+
return;
69+
}
70+
if (y_grad && y_grad->numel() == 0) {
71+
dev_ctx.template Alloc<T>(y_grad);
72+
phi::Full<T, Context>(
73+
dev_ctx, phi::IntArray(common::vectorize(x.dims())), 0, x_grad);
74+
return;
75+
}
6376
DenseTensor x_help = x;
6477
DenseTensor y_help = y;
6578
DenseTensor out_grad_help = out_grad;

test/legacy_test/test_bmm_op.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,5 +156,23 @@ def test_api_error(self):
156156
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong3)
157157

158158

159+
class TestBmmOp_ZeroSize(OpTest):
160+
def setUp(self):
161+
self.op_type = "bmm"
162+
self.python_api = paddle.bmm
163+
self.public_python_api = paddle.bmm
164+
X = np.random.random((10, 0, 4)).astype("float64")
165+
Y = np.random.random((10, 4, 5)).astype("float64")
166+
self.inputs = {'X': X, 'Y': Y}
167+
Out = np.matmul(X, Y)
168+
self.outputs = {'Out': Out}
169+
170+
def test_check_output(self):
171+
self.check_output(check_pir=True)
172+
173+
def test_checkout_grad(self):
174+
self.check_grad(['X', 'Y'], 'Out', check_pir=True)
175+
176+
159177
if __name__ == "__main__":
160178
unittest.main()

0 commit comments

Comments
 (0)