Skip to content

Commit 1ac850c

Browse files
committed
Fix
1 parent 98204ab commit 1ac850c

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed

paddle/phi/kernels/cpu/multiplex_kernel.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ void MultiplexKernel(const Context& dev_ctx,
2626
const DenseTensor& ids,
2727
DenseTensor* out) {
2828
dev_ctx.template Alloc<T>(out);
29+
if (out->numel() == 0) return;
2930
for (size_t i = 0; i < ins.size(); ++i) {
3031
PADDLE_ENFORCE_GT(
3132
ins[i]->numel(),

paddle/phi/kernels/gpu/multiplex_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ void MultiplexKernel(const Context& dev_ctx,
2727
const DenseTensor& ids,
2828
DenseTensor* out) {
2929
dev_ctx.template Alloc<T>(out);
30+
if (out->numel() == 0) return;
3031
for (size_t i = 0; i < ins.size(); ++i) {
3132
PADDLE_ENFORCE_GT(
3233
ins[i]->numel(),

test/legacy_test/test_multiplex_op.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def init_dtype(self):
9797
class TestMultiplexOpError(unittest.TestCase):
9898

9999
def test_errors(self):
100+
paddle.enable_static()
100101
with base.program_guard(base.Program(), base.Program()):
101102
x1 = paddle.static.data(name='x1', shape=[None, 2], dtype='int64')
102103
x2 = paddle.static.data(name='x2', shape=[None, 2], dtype='int64')
@@ -198,5 +199,40 @@ def init_dtype(self):
198199
self.dtype = np.complex128
199200

200201

202+
class TestMultiplexOp_ZeroSize(OpTest):
203+
def setUp(self):
204+
self.op_type = "multiplex"
205+
self.init_dtype()
206+
self.python_api = paddle.tensor.multiplex
207+
rows = 4
208+
index = np.array([0, 2, 2, 3]).astype('int32')
209+
np.random.shuffle(index)
210+
index = np.reshape(index, (rows, 1))
211+
ins1 = np.random.random((rows, 0)).astype(self.dtype)
212+
ins2 = np.random.random((rows, 0)).astype(self.dtype)
213+
ins3 = np.random.random((rows, 0)).astype(self.dtype)
214+
ins4 = np.random.random((rows, 0)).astype(self.dtype)
215+
self.inputs = {
216+
'Ids': index,
217+
'X': [('x1', ins1), ('x2', ins2), ('x3', ins3), ('x4', ins4)],
218+
}
219+
# multiplex output
220+
output = np.zeros_like(ins1)
221+
for i in range(0, rows):
222+
k = index[i][0]
223+
if self.inputs['X'][k][1][i].size != 0:
224+
output[i] = self.inputs['X'][k][1][i]
225+
self.outputs = {'Out': output}
226+
227+
def init_dtype(self):
228+
self.dtype = 'float64'
229+
230+
def test_check_output(self):
231+
self.check_output(check_pir=True)
232+
233+
def test_check_grad(self):
234+
self.check_grad(['x1', 'x2', 'x3', 'x4'], 'Out', check_pir=True)
235+
236+
201237
if __name__ == '__main__':
202238
unittest.main()

0 commit comments

Comments
 (0)