Commit 8e557e2

Fix
1 parent 5172626 commit 8e557e2
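
Summary from the diff: the multi_dot forward and gradient kernels gain a zero-size shortcut. When any input has numel() == 0, each non-empty output is zero-filled via phi::Full and the kernel returns early instead of calling into the BLAS chain; matching zero-size cases are added to test_multi_dot_op.py.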

File tree

2 files changed: 65 additions & 17 deletions


paddle/phi/kernels/impl/multi_dot_kernel_impl.h

Lines changed: 28 additions & 15 deletions
@@ -12,25 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
 #pragma once
 
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
-
 namespace phi {
 
 template <typename Context, typename T>
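
The dropped block is a duplicate of the license header already present at the top of the file. The new include pulls in phi::Full, which the zero-size shortcuts added below use to fill tensors with zeros.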
@@ -195,6 +181,19 @@ void MultiDotKernel(const Context& ctx,
   std::vector<phi::DDim> ins_dims(n);
   GetDims<Context, T>(ins, &ins_dims);
 
+  // If any numel is 0, then return.
+  bool size_0 = false;
+  for (size_t i = 0; i < n; i++) {
+    if (x[i]->numel() == 0) size_0 = true;
+  }
+  if (size_0) {
+    // For example: [2, 0], [0, 4] -> [2, 4]
+    if (out && out->numel() > 0) {
+      phi::Full<T, Context>(
+          ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
+    }
+    return;
+  }
   const T scale = static_cast<T>(1.0);
   if (n == 2) {
     auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(ins_dims[0], 0, false);
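
The shortcut matches standard matrix-product semantics: with a zero inner dimension, every output element is an empty sum, so the result is a well-defined all-zero matrix. A minimal NumPy sketch of the expected forward behavior (NumPy stands in here only to illustrate the math, not the kernel itself):

import numpy as np

# A (2, 0) times B (0, 4): the contraction axis is empty,
# so each output element is an empty sum, i.e. exactly 0.0.
A = np.random.random((2, 0))
B = np.random.random((0, 4))

out = np.linalg.multi_dot([A, B])
print(out.shape)  # (2, 4)
print(out.any())  # False: every entry is zero
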
@@ -347,9 +346,23 @@ void MultiDotGradKernel(const Context& ctx,
 
   auto blas = phi::funcs::GetBlas<Context, T>(ctx);
 
+  bool size_0 = false;
   const auto n = ins.size();
   for (size_t i = 0; i < n; i++) {
     ctx.template Alloc<T>(dx[i]);
+
+    if (dx[i]->numel() == 0) {
+      size_0 = true;
+    }
+  }
+  if (size_0) {
+    for (size_t i = 0; i < n; i++) {
+      if (dx[i]->numel() > 0) {
+        phi::Full<T, Context>(
+            ctx, phi::IntArray(common::vectorize(dx[i]->dims())), 0, dx[i]);
+      }
+    }
+    return;
   }
 
   std::vector<phi::DDim> ins_dims(n);
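
The gradient path follows the same reasoning: every product contributing to a gradient contains an all-zero factor, so each non-empty dx[i] is exactly zero. A hand-derived NumPy check for the three-matrix chain used in the new tests:

import numpy as np

A = np.random.random((2, 10))
B = np.random.random((10, 0))
C = np.random.random((0, 3))

out = np.linalg.multi_dot([A, B, C])  # (2, 3), all zeros
grad_out = np.ones_like(out)          # some upstream gradient

# For L = sum(grad_out * (A @ B @ C)), dL/dA = grad_out @ (B @ C).T.
# B @ C is a (10, 3) all-zero matrix, so dL/dA is all zeros too,
# which is exactly what the kernel's zero-fill produces.
dA = grad_out @ (B @ C).T
print(dA.shape, dA.any())             # (2, 10) False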

test/legacy_test/test_multi_dot_op.py

Lines changed: 37 additions & 2 deletions
@@ -31,14 +31,19 @@ def setUp(self):
         self.op_type = "multi_dot"
         self.python_api = paddle.linalg.multi_dot
         self.dtype = self.get_dtype()
+        self.init_shape()
         self.get_inputs_and_outputs()
 
+    def init_shape(self):
+        self.A_shape = (2, 8)
+        self.B_shape = (8, 4)
+
     def get_dtype(self):
         return "float64"
 
     def get_inputs_and_outputs(self):
-        self.A = np.random.random((2, 8)).astype(self.dtype)
-        self.B = np.random.random((8, 4)).astype(self.dtype)
+        self.A = np.random.random(self.A_shape).astype(self.dtype)
+        self.B = np.random.random(self.B_shape).astype(self.dtype)
         self.inputs = {'X': [('x0', self.A), ('x1', self.B)]}
         self.outputs = {'Out': multi_dot([self.A, self.B])}
 
@@ -55,6 +60,36 @@ def get_dtype(self):
         return "float16"
 
 
+class TestMultiDotOp_ZeroSize1(TestMultiDotOp):
+    def get_inputs_and_outputs(self):
+        # result shape: [2, 3]
+        self.A = np.random.random((2, 10)).astype(self.dtype)
+        self.B = np.random.random((10, 0)).astype(self.dtype)
+        self.C = np.random.random((0, 3)).astype(self.dtype)
+        self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
+        self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}
+
+    def test_check_grad(self):
+        self.check_grad(['x0'], 'Out', check_pir=True)
+        self.check_grad(['x1'], 'Out', check_pir=True)
+        self.check_grad(['x2'], 'Out', check_pir=True)
+
+
+class TestMultiDotOp_ZeroSize2(TestMultiDotOp):
+    def get_inputs_and_outputs(self):
+        # result shape: [0, 3]
+        self.A = np.random.random((0, 10)).astype(self.dtype)
+        self.B = np.random.random((10, 4)).astype(self.dtype)
+        self.C = np.random.random((4, 3)).astype(self.dtype)
+        self.inputs = {'X': [('x0', self.A), ('x1', self.B), ('x2', self.C)]}
+        self.outputs = {'Out': multi_dot([self.A, self.B, self.C])}
+
+    def test_check_grad(self):
+        self.check_grad(['x0'], 'Out', check_pir=True)
+        self.check_grad(['x1'], 'Out', check_pir=True)
+        self.check_grad(['x2'], 'Out', check_pir=True)
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
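
For reference, an end-to-end call that exercises the new path; a sketch assuming a Paddle build that includes this commit:

import paddle

a = paddle.randn([2, 10])
b = paddle.randn([10, 0])
c = paddle.randn([0, 3])

# With the zero-size shortcut, the kernel zero-fills the [2, 3]
# result instead of dispatching the empty matrices to BLAS.
out = paddle.linalg.multi_dot([a, b, c])
print(out.shape)         # [2, 3]
print(float(out.sum()))  # 0.0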
