-
Notifications
You must be signed in to change notification settings - Fork 5.9k
【SCU】【Paddle Tensor 第二期 API 支持 0-size TensorNo.46】paddle.linalg.solve 支持 0-size Tensor #70575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| if (y_dims.size() == 1) { | ||
| out_dims = | ||
| std::vector<int>(x_dims.Get(), x_dims.Get() + x_dims.size() - 2); | ||
| out_dims.push_back(y_dims[y_dims.size() - 1]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个分支有测试到吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个分支有测试到吗?
这个好像测不到,python层solve函数有一个检查输入格式的函数了,如果y_dims.size() == 1会raise ValueError,已经添加单测了)
| int x_dim = (x_it != x_shape.rend()) ? *x_it : 1; | ||
| int y_dim = (y_it != y_shape.rend()) ? *y_it : 1; | ||
| out_dims.push_back(std::max(x_dim, y_dim)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里有个问题,广播时如果 xdim=1,ydim=0,那么结果的outdim应该是0而不是1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
test/legacy_test/test_solve_op.py
Outdated
|
|
||
| def test_dygraph(self): | ||
| def run(place, x_shape, y_shape): | ||
| paddle.disable_static(place) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改用with xxguard代替disable/enable static
| if (x.numel() == 0 || y.numel() == 0) { | ||
| auto x_dims = x.dims(); | ||
| auto y_dims = y.dims(); | ||
| std::vector<int> out_dims; | ||
| if (y_dims.size() == 1) { | ||
| out_dims = | ||
| std::vector<int>(x_dims.Get(), x_dims.Get() + x_dims.size() - 2); | ||
| out_dims.push_back(y_dims[y_dims.size() - 1]); | ||
| } else { | ||
| // broadcast | ||
| std::vector<int> x_shape(x_dims.Get(), x_dims.Get() + x_dims.size() - 2); | ||
| std::vector<int> y_shape(y_dims.Get(), y_dims.Get() + y_dims.size() - 2); | ||
| auto x_it = x_shape.rbegin(); | ||
| auto y_it = y_shape.rbegin(); | ||
| while (x_it != x_shape.rend() || y_it != y_shape.rend()) { | ||
| int x_dim = (x_it != x_shape.rend()) ? *x_it : 1; | ||
| int y_dim = (y_it != y_shape.rend()) ? *y_it : 1; | ||
| if (x_dim == 0 || y_dim == 0) { | ||
| out_dims.push_back(0); | ||
| } else { | ||
| out_dims.push_back(std::max(x_dim, y_dim)); | ||
| } | ||
| if (x_it != x_shape.rend()) ++x_it; | ||
| if (y_it != y_shape.rend()) ++y_it; | ||
| } | ||
| std::reverse(out_dims.begin(), out_dims.end()); | ||
| out_dims.insert(out_dims.end(), | ||
| y_dims.Get() + y_dims.size() - 2, | ||
| y_dims.Get() + y_dims.size()); | ||
| } | ||
| out->Resize(phi::make_ddim(out_dims)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这段代码能移动到solve的infermeta里面吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这段代码能移动到solve的infermeta里面吗?
好像不太行诶,移到binary.cc里面输入0-size tensor就报错了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这段代码能移动到solve的infermeta里面吗?
好像不太行诶,移到binary.cc里面输入0-size tensor就报错了
好的
HydrogenSulfate
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Category
User Experience
PR Types
Others
Description
单测可以通过0-size的输入,但是

array_api_tests可能存在配置问题(在@pytest.mark.xp_extension('linalg')语句报错),先跑个CI看看。