Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Add Common Subexpression Elimination pass #63960

Merged
merged 30 commits into from
May 16, 2024

Conversation

SigureMo
Copy link
Member

@SigureMo SigureMo commented Apr 28, 2024

PR Category

Execute Infrastructure

PR Types

Performance

Description

添加 CSE(公共子表达式消除)pass

为部分 OP 添加 SideEffectTrait(包括 random OP,会改变全局状态,包括 print OP,会对 stdout 造成影响,这些都是 SideEffect,非 pure)

目前 CSE 在动转静前反向 Program 分别 apply,需使用 FLAGS_enable_cse_in_dy2st=1 启用,本 PR 不默认启用,之后 PR 会考虑改为默认

FLAGS_cse_max_count 用于 debug,快速二分哪个 OP 替换导致的精度问题,为实验性 flag

CSE 具体方案如下:

cse drawio

对于如图所示 case,两个 op b 之所以能替换,根本原因是它们的输出是相同的,而这个相同的保证是它们的输入相同,且 OP(类型、属性)是相同的

因此对于一个表达式树而言,是公共子表达式的判别条件即为「叶子结点相同」+「OP 拓扑相同」,为了能够计算这一点,对于两个 op b 而言,它会递归去计算其上游 OP(即 op a)是否相同,当然,op a 的输出的比较并不会用 Value Id,而是比较 OP 以及是该 OP 的第几个输出,当这两个相同的时候,就可以保证该 Value 是相同的

这里实现是基于哈希的,hash(op) = hash(inputs) ^ hash(op_name) ^ hash(op_attrs)hash(value) = hash(defining_op) ^ hash(idx)hash(terminate_value) = hash(value_id),实现这几条 hash 计算即可判断两个 OP 是否是公共子表达式

由于 OP 是基于拓扑序遍历的,因此下游 OP 的上游 OP 信息一定会提前计算好,因此只需要 O(N) 即可计算全部 OP hash 信息,只有在发生哈希冲突的时候才需要 compare(本 PR 尚未实现)

对于含子 block 的情况,子 block 会自动「继承」父 block 已经计算好的表达式信息,父子 block 类似于作用域的可见性,子 block 天然可见父 block 的表达式,那么就可以使用父 block 的表达式来替换子 block 的,反之则因为可见性而不支持

TODOs

本 PR

  • 清理 log,调整 log level
  • 叶子结点直接使用 Value.impl() 作为 id

之后 PR,非重要问题将会在下个 PR 修复

  • 优化代码,复用 unordered_map,以自动解决哈希冲突
  • 使用 trait 来标记支持交换律的 OP
  • 在动转静整图上跑 CSE,在 recompute 之前

PCard-66972

Copy link

paddle-bot bot commented Apr 28, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@SigureMo SigureMo changed the title [WIP][PIR] Add Common Subexpression Elimination pass [PIR] Add Common Subexpression Elimination pass May 14, 2024
@@ -83,6 +84,8 @@ class IR_API OpInfo {
return OpInfo(static_cast<OpInfoImpl *>(pointer));
}

const std::vector<std::string> GetAttributesName() const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const std::vector<std::string> GetAttributesName() const;
std::vector<std::string> GetAttributesName() const;

这里返回值用const修饰是否有意义呢?

Copy link
Member Author

@SigureMo SigureMo May 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

喔喔,没意义,忘记删了 😂,我下一个 PR 删掉吧~

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经和其它问题一起删掉了~

Comment on lines +179 to +215
size_t CalcOperationHash(pir::Operation* op) {
PADDLE_ENFORCE_EQ(
registered_ops_info_.count(reinterpret_cast<void*>(op)),
0,
common::errors::PreconditionNotMet(
"The operation %s is already registered in the table, don't call "
"CalcOperationHash twice.",
op->name()));
// hash(op) = hash(operands) ^ hash(name) ^ hash(attributes)
size_t hash = 0;
VLOG(7) << "[CalcOperationHash] op [" << op << "] " << op->name()
<< " start";

std::vector<size_t> values_hash;
for (auto& value : op->operands_source()) {
values_hash.push_back(CalcValueHash(value));
}
if (kCommutativeOps.count(op->name())) {
for (auto& commutative_indices : kCommutativeOps[op->name()]) {
values_hash = SortElementsAtIndices(values_hash, commutative_indices);
}
}
for (auto& value_hash : values_hash) {
hash = pir::detail::hash_combine(hash, value_hash);
}
hash =
pir::detail::hash_combine(hash, std::hash<std::string>{}(op->name()));
for (auto& attr_name : op->info().GetAttributesName()) {
hash =
pir::detail::hash_combine(hash, std::hash<std::string>{}(attr_name));
auto attr = op->attribute(attr_name);
hash = pir::detail::hash_combine(hash, attr.hash());
}
VLOG(7) << "[CalcOperationHash] op [" << op << "] " << op->name()
<< " hash: " << hash;
return hash;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hash值看上去无法保证唯一值,替换时如何确保两个op是等价的?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在 hash 确实是无法保证唯一的,是有概率哈希冲突的,下个 PR 会在发生冲突时使用 compare 来确保两个 OP 是相同的,需要递归遍历整棵表达式树以确保等价(这个在 PR 描述里说明了还没有实现,本地正在重构代码实现中)

const std::vector<size_t>& indices) {
std::vector<T> selected_elements;
for (auto& idx : indices) {
selected_elements.push_back(vec[idx]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里建议加上PADDLE_ENFORCE_LT(idx, vec.size(), .....),检查越界问题

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已添加~


pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
paddle::dialect::OpYamlInfoParser yaml_parser(
op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

op_info.GetInterfaceImpl 的实现里是有可能返回nullptr的,这里在进一步调用get_op_info之前,建议先PADDLE_ENFORCE_NOT_NULL,否则会有段错误的风险。有些算子可能是手写的,不一定继承了OpYamlInfoInterface

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已添加如果为 nullptr 的情况直接返回空的逻辑,如果 OP 确实没有这个 interface 的话应该返回空 map

GetOpResultId(value));
}
// hash(termiante_value) = terminate_value_id
return reinterpret_cast<size_t>(value.impl());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value有可能被重用,这里有极低的风险会出现同一个地址被两个value复用,比如前面的value析构了,被重用了

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在 OP 不析构的前提下是不是没有这个问题?CSE 会在最后一起替换 OP,在分析的过程中不会析构任何 OP,因此是不是没有这个问题?

这里之前是自己维护了一个 id,后来考虑到这里可以复用 impl,如果不析构的前提下仍有复用的风险,这里我会重新维护一个 id

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看起来在整个CSEAnanlysis过程不会有这个问题


size_t GetOpResultId(const pir::Value& value) {
size_t value_id = 0;
for (auto& result : value.defining_op()->results()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value.defining_op() 可能为nullptr?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是不可能为 nullptr 的,调用 GetOpResultId 的前提是 !IsTerminateValue(value),这里保证了不会是 nullptr

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

}
for (auto& value : op->operands_source()) {
if (!IsTerminateValue(value) &&
!GetOperationCanBeSafeToReplace(value.defining_op())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
!GetOperationCanBeSafeToReplace(value.defining_op())) {
value.defining_op() && !GetOperationCanBeSafeToReplace(value.defining_op())) {

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,!IsTerminateValue(value) 保证了这一点~

@SigureMo SigureMo merged commit 54b315a into PaddlePaddle:develop May 16, 2024
31 checks passed
@SigureMo SigureMo deleted the pir/add-cse-pass branch May 16, 2024 06:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants