-
Notifications
You must be signed in to change notification settings - Fork 5.9k
【SCU】【Paddle Tensor 第二期 API 支持 0-size TensorNo.46】paddle.linalg.solve 支持 0-size Tensor #70575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
6fa5dfe
f50b146
24e34eb
792ddc7
ee94358
2091048
e66fe1a
266cc62
4bd9ecd
c65dd6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -195,6 +195,36 @@ void SolveKernel(const Context& dev_ctx, | |
| const DenseTensor& x, | ||
| const DenseTensor& y, | ||
| DenseTensor* out) { | ||
| if (x.numel() == 0 || y.numel() == 0) { | ||
| auto x_dims = x.dims(); | ||
| auto y_dims = y.dims(); | ||
| std::vector<int> out_dims; | ||
| if (y_dims.size() == 1) { | ||
| out_dims = | ||
| std::vector<int>(x_dims.Get(), x_dims.Get() + x_dims.size() - 2); | ||
| out_dims.push_back(y_dims[y_dims.size() - 1]); | ||
| } else { | ||
| // broadcast | ||
| std::vector<int> x_shape(x_dims.Get(), x_dims.Get() + x_dims.size() - 2); | ||
| std::vector<int> y_shape(y_dims.Get(), y_dims.Get() + y_dims.size() - 2); | ||
| auto x_it = x_shape.rbegin(); | ||
| auto y_it = y_shape.rbegin(); | ||
| while (x_it != x_shape.rend() || y_it != y_shape.rend()) { | ||
| int x_dim = (x_it != x_shape.rend()) ? *x_it : 1; | ||
| int y_dim = (y_it != y_shape.rend()) ? *y_it : 1; | ||
| out_dims.push_back(std::max(x_dim, y_dim)); | ||
|
||
| if (x_it != x_shape.rend()) ++x_it; | ||
| if (y_it != y_shape.rend()) ++y_it; | ||
| } | ||
| std::reverse(out_dims.begin(), out_dims.end()); | ||
| out_dims.insert(out_dims.end(), | ||
| y_dims.Get() + y_dims.size() - 2, | ||
| y_dims.Get() + y_dims.size()); | ||
| } | ||
| out->Resize(phi::make_ddim(out_dims)); | ||
| dev_ctx.template Alloc<T>(out); | ||
| return; | ||
| } | ||
| linalg_solve<Context, T>(dev_ctx, x, y, out); | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -923,5 +923,73 @@ def test_dygraph(self): | |
| print("The mat is singular") | ||
|
|
||
|
|
||
| class TestSolveOpAPIZeroDimCase1(unittest.TestCase): | ||
| def setUp(self): | ||
| np.random.seed(2021) | ||
| self.place = [] | ||
| self.dtype = "float32" | ||
| if ( | ||
| os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower() | ||
| in ['1', 'true', 'on'] | ||
| or not core.is_compiled_with_cuda() | ||
| ): | ||
| self.place.append(paddle.CPUPlace()) | ||
| if core.is_compiled_with_cuda(): | ||
| self.place.append(paddle.CUDAPlace(0)) | ||
|
|
||
| def check_static_result(self, place): | ||
| paddle.enable_static() | ||
| with base.program_guard(base.Program(), base.Program()): | ||
| paddle_input_x = paddle.static.data( | ||
| name="input_x", shape=[10, 0, 0], dtype=self.dtype | ||
| ) | ||
| paddle_input_y = paddle.static.data( | ||
| name="input_y", shape=[6, 0, 0], dtype=self.dtype | ||
| ) # broadcast | ||
| paddle_result = paddle.linalg.solve( | ||
| paddle_input_x, paddle_input_y, left=False | ||
| ) | ||
|
|
||
| np_input_x = np.random.random([10, 0, 0]).astype(self.dtype) | ||
| np_input_y = np.random.random([10, 0, 0]).astype(self.dtype) | ||
|
|
||
| np_result = np_solve_right(np_input_x, np_input_y) | ||
|
|
||
| exe = base.Executor(place) | ||
| fetches = exe.run( | ||
| base.default_main_program(), | ||
| feed={"input_x": np_input_x, "input_y": np_input_y}, | ||
| fetch_list=[paddle_result], | ||
| ) | ||
| np.testing.assert_allclose(fetches[0], np_result, rtol=0.0001) | ||
|
|
||
| def test_static(self): | ||
| for place in self.place: | ||
| self.check_static_result(place=place) | ||
|
|
||
| def test_dygraph(self): | ||
| def run(place): | ||
| paddle.disable_static(place) | ||
|
||
| np.random.seed(2021) | ||
| input_x_np = np.random.random([10, 0, 0]).astype(self.dtype) | ||
| input_y_np = np.random.random([10, 0, 0]).astype(self.dtype) | ||
|
|
||
| tensor_input_x = paddle.to_tensor(input_x_np) | ||
| tensor_input_y = paddle.to_tensor(input_y_np) | ||
|
|
||
| numpy_output = np_solve_right(input_x_np, input_y_np) | ||
| paddle_output = paddle.linalg.solve( | ||
| tensor_input_x, tensor_input_y, left=False | ||
| ) | ||
| np.testing.assert_allclose( | ||
| numpy_output, paddle_output.numpy(), rtol=0.0001 | ||
| ) | ||
| self.assertEqual(numpy_output.shape, paddle_output.numpy().shape) | ||
| paddle.enable_static() | ||
|
|
||
| for place in self.place: | ||
| run(place) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个分支有测试到吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个好像测不到,python层solve函数有一个检查输入格式的函数了,如果y_dims.size() == 1会raise ValueError,已经添加单测了)