-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move index sample #39905
Move index sample #39905
Conversation
Thanks for your contribution! |
…ddle into move_index_sample
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. */ | |||
|
|||
#include "paddle/fluid/operators/index_sample_op.h" | |||
|
|||
//#include "paddle/fluid/operators/index_sample_op.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的注释可以删除
paddle/phi/infermeta/binary.cc
Outdated
// if (ctx->IsRuntime()) { | ||
PADDLE_ENFORCE_EQ(input_dims[0], | ||
index_dims[0], | ||
errors::InvalidArgument( | ||
"Inputs(X)'s value of dimension 0 must same with " | ||
"Inputs(Index)'s value of dimension 0, but " | ||
"got %d of Inputs(X), and got %d of Inputs(Index), " | ||
"please check Inputs shape.", | ||
input_dims[0], | ||
index_dims[0])); | ||
//} | ||
// ctx->SetOutputDim("Out", index_dims); | ||
out->set_dims(index_dims); | ||
// auto type = ctx->GetInputsVarType("Index")[0]; | ||
// if (type == framework::proto::VarType::LOD_TENSOR) { | ||
// ctx->ShareLoD("Index", /*->*/ "Out"); | ||
// } | ||
out->share_lod(y); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释部分可删除
如果有if (ctx->IsRuntime())
的逻辑建议传入MetaConfig
参数进行判断,可参考UnfoldInferMeta
,直接去掉CE可能会有问题
int v_i = b * value_length + static_cast<int>(index_vec[i]); | ||
x_grad_vec[v_i] += out_grad_vec[i]; | ||
} | ||
x_grad->mutable_data<T>(context.GetPlace()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用context.template Alloc<T>()
替换mutable_data<T>
const DenseTensor& x, | ||
const DenseTensor& index, | ||
DenseTensor* x_grad) { | ||
const auto& index_type = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不需要再转成ProtoVarType了,下面直接判断DataType是否相等就行
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
#include "paddle/phi/kernels/funcs/common_shape.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有些头文件看上去似乎没有使用?
using Tensor = paddle::framework::Tensor; | ||
using LoDTensor = paddle::framework::LoDTensor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可移除
T* input_grad_data = x_grad->mutable_data<T>(ctx.GetPlace()); | ||
|
||
const auto& index_type = | ||
paddle::framework::TransToProtoVarType(index.dtype()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,也不需要转换
using Tensor = paddle::framework::Tensor; | ||
using LoDTensor = paddle::framework::LoDTensor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可移除
// auto* input = ctx.Input<LoDTensor>("X"); | ||
// auto* index = ctx.Input<LoDTensor>("Index"); | ||
// auto* output = ctx.Output<LoDTensor>("Out"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释可删除
|
||
KernelSignature IndexSampleOpArgumentMapping( | ||
const ArgumentMappingContext& ctx) { | ||
return KernelSignature("index_sample", {"X", "Index"}, {}, {"Out"}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里前向的Mapping映射应该可以不写,使用默认op_proto映射即可
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
尝试了一下,这句删了,编译不能通过
input_dims[0], | ||
index_dims[0])); | ||
} | ||
out->set_dims(index_dims); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#include "paddle/phi/kernels/funcs/eigen/common.h" | ||
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" | ||
#include "paddle/phi/kernels/funcs/math_function.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
eigen和math_function的相关函数似乎没看到使用的地方,可以再确认下,如果没有用的话可以去掉
|
||
#include "paddle/fluid/framework/convert_utils.h" | ||
namespace phi { | ||
using DataType = paddle::experimental::DataType; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里应该可以直接用DataType,不需要using了
|
||
#include "paddle/fluid/framework/tensor_util.h" | ||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/backends/cpu/cpu_context.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
重复include了
#include "paddle/phi/kernels/funcs/diagonal.h" | ||
#include "paddle/phi/kernels/funcs/eigen/common.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个再check下是否有使用
|
||
#include "paddle/fluid/framework/convert_utils.h" | ||
namespace phi { | ||
using DataType = paddle::experimental::DataType; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using可以去掉,可以直接使用DataType
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/funcs/complex_functors.h" | ||
#include "paddle/phi/kernels/funcs/elementwise_base.h" | ||
#include "paddle/phi/kernels/index_sample_grad_kernel.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
index_sample_grad_kernel.h
放在最开头
#include "paddle/phi/kernels/funcs/complex_functors.h" | ||
#include "paddle/phi/kernels/funcs/elementwise_base.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个头文件check下是否有使用
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" | ||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/common/data_type.h" | ||
#include "paddle/phi/core/dense_tensor.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
data_type.h和dense_tensor.h应该不需要include了
#include "paddle/phi/common/data_type.h" | ||
#include "paddle/phi/core/dense_tensor.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/funcs/complex_functors.h" | ||
#include "paddle/phi/kernels/funcs/elementwise_base.h" | ||
#include "paddle/phi/kernels/index_sample_kernel.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for op benchmark
PR types
Others
PR changes
Others
Describe
move index sample op.