Skip to content

Commit 1cb0114

Browse files
authored
【SCU】【Paddle Tensor 第二期 API 支持 0-size TensorNo.46】paddle.linalg.solve 支持 0-size Tensor (#70575)
* support_0size * fix codestyle * Update solve_kernel_impl.h * update * fix codestyle * Update test_solve_op.py * Update test_solve_op.py * Update test_solve_op.py
1 parent 29c3d91 commit 1cb0114

File tree

2 files changed

+116
-0
lines changed

2 files changed

+116
-0
lines changed

paddle/phi/kernels/impl/solve_kernel_impl.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,40 @@ void SolveKernel(const Context& dev_ctx,
195195
const DenseTensor& x,
196196
const DenseTensor& y,
197197
DenseTensor* out) {
198+
if (x.numel() == 0 || y.numel() == 0) {
199+
auto x_dims = x.dims();
200+
auto y_dims = y.dims();
201+
std::vector<int> out_dims;
202+
if (y_dims.size() == 1) {
203+
out_dims =
204+
std::vector<int>(x_dims.Get(), x_dims.Get() + x_dims.size() - 2);
205+
out_dims.push_back(y_dims[y_dims.size() - 1]);
206+
} else {
207+
// broadcast
208+
std::vector<int> x_shape(x_dims.Get(), x_dims.Get() + x_dims.size() - 2);
209+
std::vector<int> y_shape(y_dims.Get(), y_dims.Get() + y_dims.size() - 2);
210+
auto x_it = x_shape.rbegin();
211+
auto y_it = y_shape.rbegin();
212+
while (x_it != x_shape.rend() || y_it != y_shape.rend()) {
213+
int x_dim = (x_it != x_shape.rend()) ? *x_it : 1;
214+
int y_dim = (y_it != y_shape.rend()) ? *y_it : 1;
215+
if (x_dim == 0 || y_dim == 0) {
216+
out_dims.push_back(0);
217+
} else {
218+
out_dims.push_back(std::max(x_dim, y_dim));
219+
}
220+
if (x_it != x_shape.rend()) ++x_it;
221+
if (y_it != y_shape.rend()) ++y_it;
222+
}
223+
std::reverse(out_dims.begin(), out_dims.end());
224+
out_dims.insert(out_dims.end(),
225+
y_dims.Get() + y_dims.size() - 2,
226+
y_dims.Get() + y_dims.size());
227+
}
228+
out->Resize(phi::make_ddim(out_dims));
229+
dev_ctx.template Alloc<T>(out);
230+
return;
231+
}
198232
linalg_solve<Context, T>(dev_ctx, x, y, out);
199233
}
200234

test/legacy_test/test_solve_op.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -923,5 +923,87 @@ def test_dygraph(self):
923923
print("The mat is singular")
924924

925925

926+
class TestSolveOpAPIZeroDimCase(unittest.TestCase):
927+
def setUp(self):
928+
np.random.seed(2021)
929+
self.place = []
930+
self.dtype = "float32"
931+
if (
932+
os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
933+
in ['1', 'true', 'on']
934+
or not core.is_compiled_with_cuda()
935+
):
936+
self.place.append(paddle.CPUPlace())
937+
if core.is_compiled_with_cuda():
938+
self.place.append(paddle.CUDAPlace(0))
939+
940+
def check_static_result(self, place, x_shape, y_shape, np_y_shape):
941+
paddle.enable_static()
942+
with base.program_guard(base.Program(), base.Program()):
943+
paddle_input_x = paddle.static.data(
944+
name="input_x", shape=x_shape, dtype=self.dtype
945+
)
946+
paddle_input_y = paddle.static.data(
947+
name="input_y", shape=y_shape, dtype=self.dtype
948+
)
949+
paddle_result = paddle.linalg.solve(
950+
paddle_input_x, paddle_input_y, left=False
951+
)
952+
953+
np_input_x = np.random.random(x_shape).astype(self.dtype)
954+
np_input_y = np.random.random(np_y_shape).astype(self.dtype)
955+
956+
np_result = np.linalg.solve(np_input_x, np_input_y)
957+
958+
exe = base.Executor(place)
959+
fetches = exe.run(
960+
base.default_main_program(),
961+
feed={"input_x": np_input_x, "input_y": np_input_y},
962+
fetch_list=[paddle_result],
963+
)
964+
np.testing.assert_allclose(fetches[0], np_result, rtol=0.0001)
965+
966+
def test_static(self):
967+
for place in self.place:
968+
self.check_static_result(
969+
place=place,
970+
x_shape=[10, 0, 0],
971+
y_shape=[6, 0, 0],
972+
np_y_shape=[10, 0, 0],
973+
)
974+
with self.assertRaises(ValueError) as context:
975+
self.check_static_result(
976+
place=place,
977+
x_shape=[10, 0, 0],
978+
y_shape=[10],
979+
np_y_shape=[10],
980+
)
981+
982+
def test_dygraph(self):
983+
def run(place, x_shape, y_shape):
984+
with base.dygraph.guard(place):
985+
input_x_np = np.random.random(x_shape).astype(self.dtype)
986+
input_y_np = np.random.random(y_shape).astype(self.dtype)
987+
988+
tensor_input_x = paddle.to_tensor(input_x_np)
989+
tensor_input_y = paddle.to_tensor(input_y_np)
990+
991+
numpy_output = np.linalg.solve(input_x_np, input_y_np)
992+
paddle_output = paddle.linalg.solve(
993+
tensor_input_x, tensor_input_y, left=False
994+
)
995+
np.testing.assert_allclose(
996+
numpy_output, paddle_output.numpy(), rtol=0.0001
997+
)
998+
self.assertEqual(
999+
numpy_output.shape, paddle_output.numpy().shape
1000+
)
1001+
1002+
for place in self.place:
1003+
run(place, x_shape=[10, 0, 0], y_shape=[10, 0, 0])
1004+
with self.assertRaises(ValueError) as context:
1005+
run(place, x_shape=[10, 0, 0], y_shape=[10])
1006+
1007+
9261008
if __name__ == "__main__":
9271009
unittest.main()

0 commit comments

Comments
 (0)