Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions paddle/phi/kernels/impl/matmul_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2007,16 +2007,36 @@ void MatmulKernel(const Context& ctx,
bool transpose_x,
bool transpose_y,
DenseTensor* out) {
PADDLE_ENFORCE_NE(common::product(x.dims()),
0,
common::errors::InvalidArgument(
"The Input(X) dims size must not be equal "
"0, but received dims size is 0."));
PADDLE_ENFORCE_NE(common::product(y.dims()),
0,
common::errors::InvalidArgument(
"The Input(Y) dims size must not be equal "
"0, but received dims size is 0."));
if (x.numel() == 0 || y.numel() == 0) {
auto x_dims = x.dims();
auto y_dims = y.dims();
if (transpose_x) {
std::swap(x_dims[x_dims.size() - 1], x_dims[x_dims.size() - 2]);
}
if (transpose_y) {
std::swap(y_dims[y_dims.size() - 1], y_dims[y_dims.size() - 2]);
}
std::vector<std::int64_t> out_dims(x_dims.size() - 1 + y_dims.size() - 1);
for (int64_t i = 0; i < x_dims.size() - 1; ++i) {
out_dims[i] = x_dims[i];
}
for (int64_t i = 1; i < y_dims.size(); ++i) {
out_dims[x_dims.size() - 1 + i - 1] = y_dims[i];
}
out->Resize(phi::make_ddim(out_dims));
ctx.template Alloc<T>(out);
return;
}
Comment on lines +2010 to +2029
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的输出计算逻辑,可以验证下batch广播的情况是否正确?
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已补充,也可以过单测的

PADDLE_ENFORCE_GE(
common::product(x.dims()),
0,
common::errors::InvalidArgument(
"The dims of Input(X) should be greater than or equal to 0."));
PADDLE_ENFORCE_GE(
common::product(y.dims()),
0,
common::errors::InvalidArgument(
"The dims of Input(Y) should be greater than or equal to 0."));
Comment on lines +2010 to +2039
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

matmul_kernel如果仅修改PADDLE_ENFORCE_NE,是否能支持0-size Tensor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好像不太行欸,会提示:** On entry to DGEMM parameter number 8 had an illegal value

const std::vector<std::int64_t> x_dims = common::vectorize(x.dims());
const std::vector<std::int64_t> y_dims = common::vectorize(y.dims());
MatmulJudgeDtypeKernel<Context, T>(
Expand Down
111 changes: 111 additions & 0 deletions test/legacy_test/test_tensordot.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,5 +371,116 @@ def set_dtype(self):
self.dtype = np.float64


class TestTensordotAPIZeroSize(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [0, 5, 5, 5]
self.y_shape = [0, 5, 5, 5]
Comment on lines +376 to +377
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是否有其他形状组合?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已补充


def set_input_data(self):
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.y = np.random.random(self.y_shape).astype(self.dtype)

def set_test_axes(self):
self.all_axes = [
[[], []],
]


class TestTensordotAPIFloat64ZeroSize(TestTensordotAPIZeroSize):
def set_dtype(self):
self.dtype = np.float64


class TestTensordotAPIZeroSize(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [0, 5, 5, 5]
self.y_shape = [0, 5, 5, 5]

def set_input_data(self):
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.y = np.random.random(self.y_shape).astype(self.dtype)

def set_test_axes(self):
self.all_axes = [
[[], []],
]

def set_dtype(self):
self.dtype = np.float64


class TestTensordotAPIZeroSizeMultipleDims1(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [0, 0, 5, 5]
self.y_shape = [0, 0, 5, 5]

def set_test_axes(self):
self.all_axes = [
[[], []],
]


class TestTensordotAPIZeroSizeMultipleDims2(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [5, 0, 5, 0]
self.y_shape = [5, 0, 5, 0]

def set_test_axes(self):
self.all_axes = [
[[], []],
]


class TestTensordotAPIZeroSizeDifferentDims1(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [5, 5, 0, 5]
self.y_shape = [5, 5, 0, 5]

def set_test_axes(self):
self.all_axes = [
[[], []],
]


class TestTensordotAPIZeroSizeDifferentDims2(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [5, 5, 5, 0]
self.y_shape = [5, 5, 5, 0]

def set_test_axes(self):
self.all_axes = [
[[], []],
]


class TestTensordotAPISingleElementAndZeroSize(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [1, 5, 5, 5]
self.y_shape = [0, 5, 5, 5]

def set_test_axes(self):
self.all_axes = [
[[], []],
]


class TestBroadcastWithZeroSize1(unittest.TestCase):
def setUp(self):
self.x_shape = [5, 0, 3]
self.y_shape = [3, 4, 0]

def set_test_axes(self):
self.all_axes = [[], []]


class TestBroadcastWithZeroSize2(unittest.TestCase):
def setUp(self):
self.x_shape = [5, 0, 3]
self.y_shape = [3, 0]

def set_test_axes(self):
self.all_axes = [[], []]


if __name__ == "__main__":
unittest.main()