-
Notifications
You must be signed in to change notification settings - Fork 5.9k
【SCU】【Paddle Tensor 第二期 API 支持 0-size TensorNo.25、47】paddle.tensordot 支持 0-size Tensor #70238
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8c894c2
323ea6c
c7b9802
d331a47
0855d5d
6e3ac40
28fdf59
b957eaa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2007,16 +2007,36 @@ void MatmulKernel(const Context& ctx, | |
| bool transpose_x, | ||
| bool transpose_y, | ||
| DenseTensor* out) { | ||
| PADDLE_ENFORCE_NE(common::product(x.dims()), | ||
| 0, | ||
| common::errors::InvalidArgument( | ||
| "The Input(X) dims size must not be equal " | ||
| "0, but received dims size is 0.")); | ||
| PADDLE_ENFORCE_NE(common::product(y.dims()), | ||
| 0, | ||
| common::errors::InvalidArgument( | ||
| "The Input(Y) dims size must not be equal " | ||
| "0, but received dims size is 0.")); | ||
| if (x.numel() == 0 || y.numel() == 0) { | ||
| auto x_dims = x.dims(); | ||
| auto y_dims = y.dims(); | ||
| if (transpose_x) { | ||
| std::swap(x_dims[x_dims.size() - 1], x_dims[x_dims.size() - 2]); | ||
| } | ||
| if (transpose_y) { | ||
| std::swap(y_dims[y_dims.size() - 1], y_dims[y_dims.size() - 2]); | ||
| } | ||
| std::vector<std::int64_t> out_dims(x_dims.size() - 1 + y_dims.size() - 1); | ||
| for (int64_t i = 0; i < x_dims.size() - 1; ++i) { | ||
| out_dims[i] = x_dims[i]; | ||
| } | ||
| for (int64_t i = 1; i < y_dims.size(); ++i) { | ||
| out_dims[x_dims.size() - 1 + i - 1] = y_dims[i]; | ||
| } | ||
| out->Resize(phi::make_ddim(out_dims)); | ||
| ctx.template Alloc<T>(out); | ||
| return; | ||
| } | ||
| PADDLE_ENFORCE_GE( | ||
| common::product(x.dims()), | ||
| 0, | ||
| common::errors::InvalidArgument( | ||
| "The dims of Input(X) should be greater than or equal to 0.")); | ||
| PADDLE_ENFORCE_GE( | ||
| common::product(y.dims()), | ||
| 0, | ||
| common::errors::InvalidArgument( | ||
| "The dims of Input(Y) should be greater than or equal to 0.")); | ||
|
Comment on lines
+2010
to
+2039
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. matmul_kernel如果仅修改PADDLE_ENFORCE_NE,是否能支持0-size Tensor?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好像不太行欸,会提示:** On entry to DGEMM parameter number 8 had an illegal value |
||
| const std::vector<std::int64_t> x_dims = common::vectorize(x.dims()); | ||
| const std::vector<std::int64_t> y_dims = common::vectorize(y.dims()); | ||
| MatmulJudgeDtypeKernel<Context, T>( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -371,5 +371,116 @@ def set_dtype(self): | |
| self.dtype = np.float64 | ||
|
|
||
|
|
||
| class TestTensordotAPIZeroSize(TestTensordotAPI): | ||
| def set_input_shape(self): | ||
| self.x_shape = [0, 5, 5, 5] | ||
| self.y_shape = [0, 5, 5, 5] | ||
|
Comment on lines
+376
to
+377
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是否有其他形状组合?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已补充 |
||
|
|
||
| def set_input_data(self): | ||
| self.x = np.random.random(self.x_shape).astype(self.dtype) | ||
| self.y = np.random.random(self.y_shape).astype(self.dtype) | ||
|
|
||
| def set_test_axes(self): | ||
| self.all_axes = [ | ||
| [[], []], | ||
| ] | ||
|
|
||
|
|
||
| class TestTensordotAPIFloat64ZeroSize(TestTensordotAPIZeroSize): | ||
| def set_dtype(self): | ||
| self.dtype = np.float64 | ||
|
|
||
|
|
||
| class TestTensordotAPIZeroSize(TestTensordotAPI): | ||
| def set_input_shape(self): | ||
| self.x_shape = [0, 5, 5, 5] | ||
| self.y_shape = [0, 5, 5, 5] | ||
|
|
||
| def set_input_data(self): | ||
| self.x = np.random.random(self.x_shape).astype(self.dtype) | ||
| self.y = np.random.random(self.y_shape).astype(self.dtype) | ||
|
|
||
| def set_test_axes(self): | ||
| self.all_axes = [ | ||
| [[], []], | ||
| ] | ||
|
|
||
| def set_dtype(self): | ||
| self.dtype = np.float64 | ||
|
|
||
|
|
||
| class TestTensordotAPIZeroSizeMultipleDims1(TestTensordotAPI): | ||
| def set_input_shape(self): | ||
| self.x_shape = [0, 0, 5, 5] | ||
| self.y_shape = [0, 0, 5, 5] | ||
|
|
||
| def set_test_axes(self): | ||
| self.all_axes = [ | ||
| [[], []], | ||
| ] | ||
|
|
||
|
|
||
| class TestTensordotAPIZeroSizeMultipleDims2(TestTensordotAPI): | ||
| def set_input_shape(self): | ||
| self.x_shape = [5, 0, 5, 0] | ||
| self.y_shape = [5, 0, 5, 0] | ||
|
|
||
| def set_test_axes(self): | ||
| self.all_axes = [ | ||
| [[], []], | ||
| ] | ||
|
|
||
|
|
||
| class TestTensordotAPIZeroSizeDifferentDims1(TestTensordotAPI): | ||
| def set_input_shape(self): | ||
| self.x_shape = [5, 5, 0, 5] | ||
| self.y_shape = [5, 5, 0, 5] | ||
|
|
||
| def set_test_axes(self): | ||
| self.all_axes = [ | ||
| [[], []], | ||
| ] | ||
|
|
||
|
|
||
| class TestTensordotAPIZeroSizeDifferentDims2(TestTensordotAPI): | ||
| def set_input_shape(self): | ||
| self.x_shape = [5, 5, 5, 0] | ||
| self.y_shape = [5, 5, 5, 0] | ||
|
|
||
| def set_test_axes(self): | ||
| self.all_axes = [ | ||
| [[], []], | ||
| ] | ||
|
|
||
|
|
||
| class TestTensordotAPISingleElementAndZeroSize(TestTensordotAPI): | ||
| def set_input_shape(self): | ||
| self.x_shape = [1, 5, 5, 5] | ||
| self.y_shape = [0, 5, 5, 5] | ||
|
|
||
| def set_test_axes(self): | ||
| self.all_axes = [ | ||
| [[], []], | ||
| ] | ||
|
|
||
|
|
||
| class TestBroadcastWithZeroSize1(unittest.TestCase): | ||
| def setUp(self): | ||
| self.x_shape = [5, 0, 3] | ||
| self.y_shape = [3, 4, 0] | ||
|
|
||
| def set_test_axes(self): | ||
| self.all_axes = [[], []] | ||
|
|
||
|
|
||
| class TestBroadcastWithZeroSize2(unittest.TestCase): | ||
| def setUp(self): | ||
| self.x_shape = [5, 0, 3] | ||
| self.y_shape = [3, 0] | ||
|
|
||
| def set_test_axes(self): | ||
| self.all_axes = [[], []] | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的输出计算逻辑,可以验证下batch广播的情况是否正确?

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已补充,也可以过单测的