Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions paddle/phi/kernels/impl/solve_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,36 @@ void SolveKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
if (x.numel() == 0 || y.numel() == 0) {
auto x_dims = x.dims();
auto y_dims = y.dims();
std::vector<int> out_dims;
if (y_dims.size() == 1) {
out_dims =
std::vector<int>(x_dims.Get(), x_dims.Get() + x_dims.size() - 2);
out_dims.push_back(y_dims[y_dims.size() - 1]);
Comment on lines +202 to +205
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个分支有测试到吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个分支有测试到吗?

这个好像测不到,python层solve函数有一个检查输入格式的函数了,如果y_dims.size() == 1会raise ValueError,已经添加单测了)

} else {
// broadcast
std::vector<int> x_shape(x_dims.Get(), x_dims.Get() + x_dims.size() - 2);
std::vector<int> y_shape(y_dims.Get(), y_dims.Get() + y_dims.size() - 2);
auto x_it = x_shape.rbegin();
auto y_it = y_shape.rbegin();
while (x_it != x_shape.rend() || y_it != y_shape.rend()) {
int x_dim = (x_it != x_shape.rend()) ? *x_it : 1;
int y_dim = (y_it != y_shape.rend()) ? *y_it : 1;
out_dims.push_back(std::max(x_dim, y_dim));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里有个问题,广播时如果 xdim=1,ydim=0,那么结果的outdim应该是0而不是1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if (x_it != x_shape.rend()) ++x_it;
if (y_it != y_shape.rend()) ++y_it;
}
std::reverse(out_dims.begin(), out_dims.end());
out_dims.insert(out_dims.end(),
y_dims.Get() + y_dims.size() - 2,
y_dims.Get() + y_dims.size());
}
out->Resize(phi::make_ddim(out_dims));
dev_ctx.template Alloc<T>(out);
return;
}
linalg_solve<Context, T>(dev_ctx, x, y, out);
}

Expand Down
68 changes: 68 additions & 0 deletions test/legacy_test/test_solve_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,5 +923,73 @@ def test_dygraph(self):
print("The mat is singular")


class TestSolveOpAPIZeroDimCase1(unittest.TestCase):
    """Tests paddle.linalg.solve with zero-size inputs (dims containing 0)."""

    def setUp(self):
        np.random.seed(2021)
        self.place = []
        self.dtype = "float32"
        # Test on CPU when the CI flag requests both devices or when CUDA
        # is not available; additionally test on GPU when compiled with CUDA.
        if (
            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
            in ['1', 'true', 'on']
            or not core.is_compiled_with_cuda()
        ):
            self.place.append(paddle.CPUPlace())
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

    def check_static_result(self, place):
        """Static-graph check: solve of zero-size inputs matches NumPy."""
        paddle.enable_static()
        with base.program_guard(base.Program(), base.Program()):
            # BUGFIX: the placeholder shapes must match the arrays fed below.
            # input_y was previously declared [6, 0, 0] while the feed
            # supplies [10, 0, 0], which fails static-mode shape checking.
            paddle_input_x = paddle.static.data(
                name="input_x", shape=[10, 0, 0], dtype=self.dtype
            )
            paddle_input_y = paddle.static.data(
                name="input_y", shape=[10, 0, 0], dtype=self.dtype
            )
            paddle_result = paddle.linalg.solve(
                paddle_input_x, paddle_input_y, left=False
            )

            np_input_x = np.random.random([10, 0, 0]).astype(self.dtype)
            np_input_y = np.random.random([10, 0, 0]).astype(self.dtype)

            np_result = np_solve_right(np_input_x, np_input_y)

            exe = base.Executor(place)
            fetches = exe.run(
                base.default_main_program(),
                feed={"input_x": np_input_x, "input_y": np_input_y},
                fetch_list=[paddle_result],
            )
            np.testing.assert_allclose(fetches[0], np_result, rtol=0.0001)

    def test_static(self):
        for place in self.place:
            self.check_static_result(place=place)

    def test_dygraph(self):
        """Dygraph check: solve of zero-size inputs matches NumPy."""

        def run(place):
            paddle.disable_static(place)
            try:
                np.random.seed(2021)
                input_x_np = np.random.random([10, 0, 0]).astype(self.dtype)
                input_y_np = np.random.random([10, 0, 0]).astype(self.dtype)

                tensor_input_x = paddle.to_tensor(input_x_np)
                tensor_input_y = paddle.to_tensor(input_y_np)

                numpy_output = np_solve_right(input_x_np, input_y_np)
                paddle_output = paddle.linalg.solve(
                    tensor_input_x, tensor_input_y, left=False
                )
                np.testing.assert_allclose(
                    numpy_output, paddle_output.numpy(), rtol=0.0001
                )
                self.assertEqual(
                    numpy_output.shape, paddle_output.numpy().shape
                )
            finally:
                # BUGFIX: restore static mode even when an assertion fails,
                # so a failure here cannot leak dygraph mode into later tests.
                paddle.enable_static()

        for place in self.place:
            run(place)


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main()
Loading