Skip to content

Conversation

@co63oc
Copy link
Contributor

@co63oc co63oc commented Jul 24, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

输入shape一致都为0-size,判断out numel返回

PaddleAPITest中为自定义规则torch.stack
https://github.com/PFCCLab/PaddleAPITest/blob/81e9f10d78ec53b8a56684b75bcdcf108f38a43f/tester/paddle_to_torch/rules.py#L3795
image

增加单测
PaddleAPITest测试通过,错误为 torch error
image

@paddle-bot
Copy link

paddle-bot bot commented Jul 24, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Jul 24, 2025
@luotao1 luotao1 added the HappyOpenSource Pro 进阶版快乐开源活动,更具挑战性的任务 label Jul 24, 2025
@co63oc
Copy link
Contributor Author

co63oc commented Jul 24, 2025

/re-run all-failed

@codecov-commenter
Copy link

codecov-commenter commented Jul 24, 2025

Codecov Report

❌ Patch coverage is 0% with 1 line in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@98204ab). Learn more about missing BASE report.

Files with missing lines Patch % Lines
paddle/phi/kernels/cpu/multiplex_kernel.cc 0.00% 1 Missing ⚠️

❌ Your patch status has failed because the patch coverage (0.00%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #74212   +/-   ##
==========================================
  Coverage           ?    0.00%           
==========================================
  Files              ?        1           
  Lines              ?        1           
  Branches           ?        0           
==========================================
  Hits               ?        0           
  Misses             ?        1           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@co63oc
Copy link
Contributor Author

co63oc commented Jul 24, 2025

/re-run all-failed

@co63oc
Copy link
Contributor Author

co63oc commented Jul 25, 2025

@DanielSun11 CI已完成需要review

ins.size(),
errors::PreconditionNotMet(
"index exceeds the number of candidate tensors."));
if (ins[k]->numel() == 0) continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

所有的tensor都不为0-size或者都为0-size,这里处理的是所有的tensor numel都为0,都会跳过

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

也就是说修改后的paddle.multiplex只能支持输入全部为0 size或者输入全部为非0 size的情况吗?如果是这样的,infermeta和符号推导就没有必要改了。https://www.paddlepaddle.org.cn/documentation/docs/zh/3.0-beta/api/paddle/multiplex_cn.html#multiplex 规定了所有输入的shape必须相同,如果输入为0 size则其所有输入的shape也要求相同。只需要在kernel中判断如果out的numel为0,就直接return,就可以了吧?shape不同的为非法case

Copy link
Contributor Author

@co63oc co63oc Jul 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

也就是说修改后的paddle.multiplex只能支持输入全部为0 size或者输入全部为非0 size的情况吗?如果是这样的,infermeta和符号推导就没有必要改了。https://www.paddlepaddle.org.cn/documentation/docs/zh/3.0-beta/api/paddle/multiplex_cn.html#multiplex 规定了所有输入的shape必须相同,如果输入为0 size则其所有输入的shape也要求相同。只需要在kernel中判断如果out的numel为0,就直接return,就可以了吧?shape不同的为非法case

输入可以包含0-size和非0-size,输出取非0-size的shape,但是index选择的时候,选择的是非0-size项,如果都是选择0-size项,kernel结尾要Resize,这里看下torch的规则大概可以明白,构造temp时都是非0-size或者都是0-size,但是输入可以是不同的
image

文档描述shape是一致,但是0-size特例也可能出现,所以这里加了处理

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import paddle
x1 = paddle.randn([3, 3],dtype=paddle.float32)
x2 = paddle.randn([3, 0],dtype=paddle.float32)
print(x1)
print(x2)
index = paddle.to_tensor([[0],[1]],dtype=paddle.int32)
z = paddle.multiplex([x1, x2], index)
print(z)

请验证下这个代码的输出

Copy link
Contributor Author

@co63oc co63oc Jul 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那还是按简单的处理,规则复杂也不容易维护,输入shape都一致 测试用例修改PR PFCCLab/PaddleAPITest#444

@co63oc
Copy link
Contributor Author

co63oc commented Jul 26, 2025

/re-run all-failed

Copy link
Contributor

@DanielSun11 DanielSun11 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

ins.size(),
errors::PreconditionNotMet(
"index exceeds the number of candidate tensors."));
if (ins[k]->numel() == 0) continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import paddle
x1 = paddle.randn([3, 3],dtype=paddle.float32)
x2 = paddle.randn([3, 0],dtype=paddle.float32)
print(x1)
print(x2)
index = paddle.to_tensor([[0],[1]],dtype=paddle.int32)
z = paddle.multiplex([x1, x2], index)
print(z)

请验证下这个代码的输出

@DanielSun11 DanielSun11 merged commit 3e59330 into PaddlePaddle:develop Jul 28, 2025
68 of 72 checks passed
@co63oc co63oc deleted the h52 branch July 29, 2025 23:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

contributor External developers HappyOpenSource Pro 进阶版快乐开源活动,更具挑战性的任务

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants