
Commit 7dddaf3

Fix
1 parent f728f15 commit 7dddaf3

File tree

6 files changed: +82 additions, -25 deletions

paddle/phi/infermeta/binary.cc

Lines changed: 6 additions & 1 deletion
@@ -2992,7 +2992,12 @@ void MatmulInferMeta(const MetaTensor& x,
   } else {
     new_dims.reserve(ndims_x);
     for (size_t i = 0; i < ndims_x - 2; ++i) {
-      new_dims.push_back(std::max(dims_x[i], dims_y[i]));
+      // If one of them is 0, choose 0.
+      if (dims_x[i] == 0 || dims_y[i] == 0) {
+        new_dims.push_back(0);
+      } else {
+        new_dims.push_back(std::max(dims_x[i], dims_y[i]));
+      }
     }
   }
   if (!x_broadcasted) {
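The new branch makes batch-dimension broadcasting zero-aware: a 0-extent dimension now produces a 0-extent output dimension instead of losing to std::max, which would pick the other operand's extent. A minimal numpy sketch of the rule (numpy's matmul already broadcasts this way; the shapes are illustrative, not from the commit):

import numpy as np

# Old code: std::max(0, 1) == 1, wrongly inferring a non-empty batch dim.
# New rule: a 0-extent batch dim broadcast against any extent yields 0.
x = np.random.random((0, 5, 2, 3))
y = np.random.random((1, 5, 3, 4))
print(np.matmul(x, y).shape)  # (0, 5, 2, 4)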

paddle/phi/kernels/impl/matmul_grad_kernel_impl.h

Lines changed: 9 additions & 5 deletions
@@ -230,11 +230,15 @@ void MatmulGradKernel(const Context& dev_ctx,
                       DenseTensor* dx,
                       DenseTensor* dy) {
   if (x.numel() == 0) {
-    if (dy != nullptr) {
-      dev_ctx.template Alloc<T>(dx);
-      phi::FullKernel<T>(
-          dev_ctx, common::vectorize(y.dims()), 0.0, y.dtype(), dy);
-    }
+    dev_ctx.template Alloc<T>(dx);
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(y.dims())), 0, dy);
+    return;
+  }
+  if (y.numel() == 0) {
+    dev_ctx.template Alloc<T>(dy);
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(x.dims())), 0, dx);
     return;
   }
   // get dims
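With this change, a zero-size operand short-circuits the gradient kernel: the empty input's gradient is merely allocated (it has no elements), while the other gradient is materialized as an all-zero tensor of its input's shape, rather than being skipped. A rough dygraph repro of the intended behavior (shapes are illustrative, not from the commit):

import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(2, 0, 3), stop_gradient=False)
y = paddle.to_tensor(np.random.rand(2, 3, 4), stop_gradient=False)
out = paddle.matmul(x, y)   # shape [2, 0, 4]: zero elements
paddle.sum(out).backward()
# Both grads exist; y.grad is zero-filled, x.grad is empty.
print(x.grad.shape, y.grad.shape)  # [2, 0, 3] [2, 3, 4]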

paddle/phi/kernels/impl/matmul_kernel_impl.h

Lines changed: 4 additions & 17 deletions
@@ -38,6 +38,7 @@ limitations under the License. */
 #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
 #include "paddle/phi/kernels/autotune/auto_tune_base.h"
 #endif
+#include "paddle/phi/kernels/full_kernel.h"
 
 COMMON_DECLARE_bool(cuda_core_int8_gemm);
 
@@ -2007,23 +2008,9 @@ void MatmulKernel(const Context& dev_ctx,
                   bool transpose_y,
                   DenseTensor* out) {
   if (x.numel() == 0 || y.numel() == 0) {
-    auto x_dims = x.dims();
-    auto y_dims = y.dims();
-    if (transpose_x) {
-      std::swap(x_dims[x_dims.size() - 1], x_dims[x_dims.size() - 2]);
-    }
-    if (transpose_y) {
-      std::swap(y_dims[y_dims.size() - 1], y_dims[y_dims.size() - 2]);
-    }
-    std::vector<std::int64_t> out_dims(x_dims.size() - 1 + y_dims.size() - 1);
-    for (int64_t i = 0; i < x_dims.size() - 1; ++i) {
-      out_dims[i] = x_dims[i];
-    }
-    for (int64_t i = 1; i < y_dims.size(); ++i) {
-      out_dims[x_dims.size() - 1 + i - 1] = y_dims[i];
-    }
-    out->Resize(phi::make_ddim(out_dims));
-    dev_ctx.template Alloc<T>(out);
+    // input shape [1, 1, 5, 0], [1, 1, 0, 5], result shape is [1, 1, 5, 5]
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
     return;
   }
   PADDLE_ENFORCE_GE(
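The rewritten early-return drops the ad-hoc shape computation and relies on the output shape already inferred by MatmulInferMeta (out->dims()), then zero-fills it: a matmul that contracts over an empty dimension is an empty sum, so every entry is exactly 0. A numpy sketch of the case named in the comment:

import numpy as np

# (5, 0) x (0, 5) contracts over an empty axis, so each output entry
# is a sum of zero terms, i.e. exactly 0 -- hence Full(..., 0, out).
x = np.ones((1, 1, 5, 0))
y = np.ones((1, 1, 0, 5))
out = np.matmul(x, y)
print(out.shape)  # (1, 1, 5, 5)
print(out.any())  # False: all zeros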

python/paddle/tensor/math.py

Lines changed: 10 additions & 2 deletions
@@ -2942,8 +2942,16 @@ def outer(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
 
 
     """
-    nx = x.reshape((-1, 1))
-    ny = y.reshape((1, -1))
+    xshape = x.shape
+    yshape = y.shape
+    if math.prod(xshape) == 0:  # If the size is 0
+        nx = x.reshape((0, 0))
+    else:
+        nx = x.reshape((-1, 1))
+    if math.prod(yshape) == 0:  # If the size is 0
+        ny = y.reshape((0, 0))
+    else:
+        ny = y.reshape((1, -1))
 
     if in_dynamic_mode():
         return _C_ops.matmul(nx, ny, False, False)
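np.outer flattens both arguments, so a zero-size operand yields an empty result rather than an error; reshaping to (0, 0) keeps paddle.outer aligned with that instead of asking reshape((-1, 1)) to infer a dimension from a zero-size tensor. A small repro of the fixed behavior (assuming this patch is applied):

import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(5, 10, 0))
y = paddle.to_tensor(np.random.rand(0, 10))
print(paddle.outer(x, y).shape)              # [0, 0]
print(np.outer(x.numpy(), y.numpy()).shape)  # (0, 0)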

test/legacy_test/test_matmul_v2_op.py

Lines changed: 35 additions & 0 deletions
@@ -943,6 +943,41 @@ def func_dygraph_matmul(self):
         paddle.enable_static()
 
 
+class TestMatMulOp_ZeroSize(OpTest):
+    def setUp(self):
+        self.op_type = "matmul_v2"
+        self.python_api = paddle.matmul
+        self.init_input_output()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_base_dtype(self.x),
+            'Y': OpTest.np_dtype_to_base_dtype(self.y),
+        }
+        self.out = np.matmul(self.x, self.y)
+        self.attrs = {'axis': -1, 'use_mkldnn': False}
+        self.outputs = {'Out': self.out}
+
+    def init_input_output(self):
+        self.x = np.random.random((1, 1, 2, 3))
+        self.y = np.random.random((1, 0, 3, 2))
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+        )
+
+
+class TestMatMulOp_ZeroSize2(TestMatMulOp_ZeroSize):
+    def init_input_output(self):
+        self.x = np.random.random((0, 3, 2, 3))
+        self.y = np.random.random((1, 3, 3, 2))
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
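For reference, the output shapes the two test cases should produce under zero-aware broadcasting, checked against numpy (which the test uses as its oracle via np.matmul):

import numpy as np

# TestMatMulOp_ZeroSize: batch dims (1, 1) x (1, 0) broadcast to (1, 0).
print(np.matmul(np.ones((1, 1, 2, 3)), np.ones((1, 0, 3, 2))).shape)  # (1, 0, 2, 2)
# TestMatMulOp_ZeroSize2: batch dims (0, 3) x (1, 3) broadcast to (0, 3).
print(np.matmul(np.ones((0, 3, 2, 3)), np.ones((1, 3, 3, 2))).shape)  # (0, 3, 2, 2)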

test/legacy_test/test_outer.py

Lines changed: 18 additions & 0 deletions
@@ -172,5 +172,23 @@ def test_errors_dynamic(self):
         self.assertRaises(Exception, paddle.outer, x_data, y_data)
 
 
+class TestMultiplyApi_ZeroSize(unittest.TestCase):
+    def test_multiply_dynamic(self):
+        x_data = np.random.rand(5, 10, 0).astype(np.float64)
+        y_data = np.random.rand(0, 10).astype(np.float64)
+        paddle.disable_static()
+        x = paddle.to_tensor(x_data)
+        y = paddle.to_tensor(y_data)
+        x.stop_gradient = False
+        y.stop_gradient = False
+        res = paddle.outer(x, y)
+        np.testing.assert_allclose(
+            res.numpy(), np.outer(x_data, y_data), rtol=1e-05
+        )
+        loss = paddle.sum(res)
+        loss.backward()
+        np.testing.assert_allclose(x.grad.shape, x.shape)
+
+
 if __name__ == '__main__':
     unittest.main()
