-
Notifications
You must be signed in to change notification settings - Fork 5.7k
【BUPT】【Paddle Tensor 第二期 API支持 0-size Tensor】paddle.matrix_power 支持 0-size tensor #70098
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
bool has_zero_dim = false; | ||
for (int i = 0; i < x_ndim; i++) { | ||
if (x_dims[i] == 0) { | ||
has_zero_dim = true; | ||
break; | ||
} | ||
} | ||
if (has_zero_dim) { | ||
Out->Resize(X->dims()); | ||
ctx.template Alloc<T>(Out); | ||
return; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以化简下
bool has_zero_dim = false; | |
for (int i = 0; i < x_ndim; i++) { | |
if (x_dims[i] == 0) { | |
has_zero_dim = true; | |
break; | |
} | |
} | |
if (has_zero_dim) { | |
Out->Resize(X->dims()); | |
ctx.template Alloc<T>(Out); | |
return; | |
} | |
if (x->numel() == 0) { | |
Out->Resize(X->dims()); | |
ctx.template Alloc<T>(Out); | |
return; | |
} |
PADDLE_ENFORCE_EQ( | ||
x_dims[x_ndim - 2], | ||
x_dims[x_ndim - 1], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上面的逻辑应该放在203行的检查之后,0-size Tensor也应该符合方阵的要求
|
||
def _test_matrix_power_empty_dynamtic(self): | ||
with dygraph_guard(): | ||
x = paddle.full((0, 0), 1.0, dtype='float32') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以多测试几组,如[0,0]、[0,6]、[6,0]、[2,3,0,0]这种形状
assert len(y4.shape) == 4 and y4.shape[0] == 2 and y4.shape[1] == 3 | ||
assert len(y.shape) == 2 and y.shape[0] == 0 and y.shape[1] == 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 使用self.assertEqual而不是assert
- 对于形状的判断,直接使用:y4.shape == [2, 3, 0, 0]即可,没必要把每个元素都写出来,y的形状判断同理
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Category
User Experience
PR Types
Bug fixes
Description
解决x.dim=0出现的报错
