[Phi] move ops: maxout/take_along_axis/put_along_axis #39959
Conversation
Thanks for your contribution!
Is there a corresponding op_benchmark script for this? If not, one still needs to be added.
const auto& index_type =
    paddle::framework::TransToProtoVarType(index.dtype());
if (x_grad) {
  paddle::framework::TensorCopy(out_grad, dev_ctx.GetPlace(), x_grad);
You could try replacing this with the CopyKernel under phi.
Benchmarks need to be added for all three of these ops.
benchmark: PaddlePaddle/benchmark#1326
LGTM
ele = ele > x ? ele : x;
}
template <typename DeviceContext, typename T>
void MaxOutFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
I checked, and this MaxOutFunctor is only used by maxout_op; suggest migrating it to phi as well.
Will migrate it separately in the next PR.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Please include the corresponding header file here.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Same as above.
true,
errors::PreconditionNotMet("PutAlongAxisGradOpKernel only runs on CPU."));

const auto& index_type =
There is no need to convert to ProtoVarType here; you can do the check directly with the DataType type.
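With the `DataType` returned by `index.dtype()`, the check can branch on the enum directly. Below is a minimal standalone sketch of that pattern; the `DataType` enum and `DispatchIndexType` helper are hypothetical stand-ins for illustration, not phi's actual definitions.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-in for phi::DataType, only to illustrate the pattern.
enum class DataType { INT32, INT64, FLOAT32 };

// Dispatch on the index dtype directly, with no TransToProtoVarType round
// trip; unsupported index types raise an error, as the kernels do.
template <typename Fn32, typename Fn64>
void DispatchIndexType(DataType index_type, Fn32 on_int32, Fn64 on_int64) {
  if (index_type == DataType::INT32) {
    on_int32();
  } else if (index_type == DataType::INT64) {
    on_int64();
  } else {
    throw std::invalid_argument("index must be int32 or int64");
  }
}
```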
if (value_grad) {
  value_grad->Resize(index.dims());
  value_grad->mutable_data<T>(dev_ctx.GetPlace());
Suggest using dev_ctx.template Alloc<T>() instead.
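The `template` keyword in that call is required because `dev_ctx` has a dependent type inside the kernel template. The sketch below uses a mock context (not Paddle's real `DeviceContext`, whose `Alloc` signature may differ) purely to show the call shape:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Mock context with an Alloc<T>() member template, mimicking the shape of
// the suggested dev_ctx.template Alloc<T>() call; not Paddle's real API.
struct MockContext {
  std::vector<unsigned char> storage;
  template <typename T>
  T* Alloc(std::size_t n) {
    storage.assign(n * sizeof(T), 0);
    return reinterpret_cast<T*>(storage.data());
  }
};

// Inside a function template, Context is a dependent type, so calling its
// member template requires the `template` disambiguator.
template <typename Context, typename T>
T* AllocOutput(Context& dev_ctx, std::size_t n) {
  return dev_ctx.template Alloc<T>(n);
}
```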
}
if (value_grad) {
  value_grad->Resize(index.dims());
  value_grad->mutable_data<T>(dev_ctx.GetPlace());
Use dev_ctx.template Alloc<T>() here as well.
errors::PreconditionNotMet(
    "PutAlongAxisCUDAKernel only runs on GPU device."));

const auto& index_type =
Same as above.
KernelSignature MaxoutArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("maxout", {"X"}, {"groups", "axis"}, {"Out"});
}
It seems this forward ArgumentMapping doesn't need to be written; using the default op_proto should also work.
KernelSignature PutAlongAxisArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("put_along_axis",
                         {"Input", "Index", "Value"},
                         {"Axis", "Reduce"},
                         {"Result"});
}
Same as above.
KernelSignature TakeAlongAxisArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "take_along_axis", {"Input", "Index"}, {"Axis"}, {"Result"});
}
Same as above.
LGTM for PADDLE_ENFORCE
There are still some detail issues; please follow up and address them in a subsequent PR.
OK, will add them in the next PR.
PR types
Others
PR changes
Others
Describe
move the following ops: maxout, take_along_axis, put_along_axis
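For reference, the semantics of two of the migrated ops can be sketched as a standalone 2-D, axis == 1 version (plain vectors rather than Paddle's tensors; index bounds checking and the other reduce modes of put_along_axis are omitted):

```cpp
#include <cassert>
#include <vector>

// take_along_axis, 2-D, axis == 1: out[i][j] = x[i][ index[i][j] ].
// Row count is inferred from the index shape; x has `cols` columns.
std::vector<int> TakeAlongAxis1(const std::vector<int>& x, int cols,
                                const std::vector<int>& index, int idx_cols) {
  const int rows = static_cast<int>(index.size()) / idx_cols;
  std::vector<int> out(index.size());
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < idx_cols; ++j)
      out[i * idx_cols + j] = x[i * cols + index[i * idx_cols + j]];
  return out;
}

// put_along_axis with the assign reduce, 2-D, axis == 1:
// result[i][ index[i][j] ] = value[i][j], starting from a copy of x.
std::vector<int> PutAlongAxis1(std::vector<int> x, int cols,
                               const std::vector<int>& index, int idx_cols,
                               const std::vector<int>& value) {
  const int rows = static_cast<int>(index.size()) / idx_cols;
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < idx_cols; ++j)
      x[i * cols + index[i * idx_cols + j]] = value[i * idx_cols + j];
  return x;
}
```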