Handle zero-size tensors in the multiplex op.

Both the CPU and the GPU MultiplexKernel now return right after allocating
the output when the output has no elements, so the per-input
PADDLE_ENFORCE_GT check on numel() is not reached for empty inputs.
The test file switches test_errors to static mode explicitly and adds
TestMultiplexOp_ZeroSize, which feeds four (rows, 0) float64 inputs and
checks both the forward output and the gradients with check_pir=True.

diff --git a/paddle/phi/kernels/cpu/multiplex_kernel.cc b/paddle/phi/kernels/cpu/multiplex_kernel.cc
index 6e38a255a10513..f91879dd4569eb 100644
--- a/paddle/phi/kernels/cpu/multiplex_kernel.cc
+++ b/paddle/phi/kernels/cpu/multiplex_kernel.cc
@@ -26,6 +26,7 @@ void MultiplexKernel(const Context& dev_ctx,
                      const DenseTensor& ids,
                      DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
+  if (out->numel() == 0) return;
   for (size_t i = 0; i < ins.size(); ++i) {
     PADDLE_ENFORCE_GT(
         ins[i]->numel(),
diff --git a/paddle/phi/kernels/gpu/multiplex_kernel.cu b/paddle/phi/kernels/gpu/multiplex_kernel.cu
index 33fa3a74c527d0..b66cc4836bee90 100644
--- a/paddle/phi/kernels/gpu/multiplex_kernel.cu
+++ b/paddle/phi/kernels/gpu/multiplex_kernel.cu
@@ -27,6 +27,7 @@ void MultiplexKernel(const Context& dev_ctx,
                      const DenseTensor& ids,
                      DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
+  if (out->numel() == 0) return;
   for (size_t i = 0; i < ins.size(); ++i) {
     PADDLE_ENFORCE_GT(
         ins[i]->numel(),
diff --git a/test/legacy_test/test_multiplex_op.py b/test/legacy_test/test_multiplex_op.py
index 67d3b0bbf73a02..0c69efeed97f7d 100644
--- a/test/legacy_test/test_multiplex_op.py
+++ b/test/legacy_test/test_multiplex_op.py
@@ -97,6 +97,7 @@ def init_dtype(self):
 
 class TestMultiplexOpError(unittest.TestCase):
     def test_errors(self):
+        paddle.enable_static()
         with base.program_guard(base.Program(), base.Program()):
             x1 = paddle.static.data(name='x1', shape=[None, 2], dtype='int64')
             x2 = paddle.static.data(name='x2', shape=[None, 2], dtype='int64')
@@ -198,5 +199,40 @@ def init_dtype(self):
         self.dtype = np.complex128
 
 
+class TestMultiplexOp_ZeroSize(OpTest):
+    def setUp(self):
+        self.op_type = "multiplex"
+        self.init_dtype()
+        self.python_api = paddle.tensor.multiplex
+        rows = 4
+        index = np.array([0, 2, 2, 3]).astype('int32')
+        np.random.shuffle(index)
+        index = np.reshape(index, (rows, 1))
+        ins1 = np.random.random((rows, 0)).astype(self.dtype)
+        ins2 = np.random.random((rows, 0)).astype(self.dtype)
+        ins3 = np.random.random((rows, 0)).astype(self.dtype)
+        ins4 = np.random.random((rows, 0)).astype(self.dtype)
+        self.inputs = {
+            'Ids': index,
+            'X': [('x1', ins1), ('x2', ins2), ('x3', ins3), ('x4', ins4)],
+        }
+        # multiplex output
+        output = np.zeros_like(ins1)
+        for i in range(0, rows):
+            k = index[i][0]
+            if self.inputs['X'][k][1][i].size != 0:
+                output[i] = self.inputs['X'][k][1][i]
+        self.outputs = {'Out': output}
+
+    def init_dtype(self):
+        self.dtype = 'float64'
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad(self):
+        self.check_grad(['x1', 'x2', 'x3', 'x4'], 'Out', check_pir=True)
+
+
 if __name__ == '__main__':
     unittest.main()