Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTen]Reshape Kernel Refactor #37164

Merged

Conversation

YuanRisheng
Copy link
Contributor

PR types

Others

PR changes

OPs

Describe

reshape kernel refactor

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -1883,6 +1883,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shape需要使用vector<int64_t>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reshape maker定义中使用的vector,使用vector<int64_t>会有兼容问题


// only can include the headers in paddle/top/api dirs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注释目录需要修改

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

// will throw error(can't realloc shared memory) in current DenseTensor
// design. So, codes
// below create a tmp densetensor for output.
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#36916 合入之后,会支持realloc,这里需要修改写法,建议记个TODO

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前这样的写法可能会对模型性能造成影响,后面需要尽快修改

general::SetXShape(x, xshape);
}

void ReshapeFromDT(const CPUContext& dev_ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DT表示什么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DenseTensor,代码中已加注释

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

函数命名后面要迭代一下,原则上是kernel注册name的驼峰式写法,如果不完备的话我们需要进一步考虑规则

Comment on lines 380 to 383
// and this
// will throw error(can't realloc shared memory) in current DenseTensor
// design. So, codes
// below create a tmp densetensor for output.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

短的注释合并到一行更整齐一些

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 227 to 228
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这两处的const & 是不是可以去掉了?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

去掉会有bug

"the element's shape must be [1]. But received the element's shape "
"is [%s]",
tensor.dims()));
vector_shape.push_back(static_cast<int32_t>(*tensor.data<int32_t>()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要static_cast<int32_t>转换吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已去除

}
for (size_t dtype = static_cast<size_t>(DataType::BOOL);
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
dtype++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

枚举类可以直接遍历,不需要转为size_t 再遍历

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

直接不行吧,需要重载++运算符才可以

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那我记错了

Copy link
Contributor

@MingMingShangTian MingMingShangTian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chenwhql chenwhql merged commit 895692e into PaddlePaddle:develop Nov 14, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants