Commit

[PIR]Support custom op in PIR (#59790)

* support custom op in pir

* fix compile bugs

* fix bugs

* delete code

* fix windows bugs

* fix windows bugs

* add symbol to paddle lib

* fix windows bugs

* revert code

* fix bugs

* fix bugs

* perfect code according comment

* fix py3

* revert third party

* fix bugs

* fix bug

* fix compile bugs

* fix windows

YuanRisheng authored Jan 2, 2024
1 parent 7c7446f commit cfad7d2
Showing 32 changed files with 1,624 additions and 258 deletions.
42 changes: 42 additions & 0 deletions paddle/common/hash_funcs.h
@@ -0,0 +1,42 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstddef>
#include <functional>
#include <vector>

inline void HashCombine(std::size_t* seed) {}

// combine hash value
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
template <typename T, typename... Rest>
inline void HashCombine(std::size_t* seed, const T& v, Rest... rest) {
  std::hash<T> hasher;
  *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
  *seed *= 0x00000100000001B3;
  HashCombine(seed, rest...);
}

// custom specialization of std::hash can be injected in namespace std
// ref: https://en.cppreference.com/w/cpp/utility/hash
namespace std {
template <typename T>
struct hash<std::vector<T>> {
  std::size_t operator()(std::vector<T> const& vec) const noexcept {
    std::size_t seed = 0xcbf29ce484222325;
    for (auto val : vec) {
      HashCombine(&seed, val);
    }
    return seed;
  }
};
} // namespace std
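
A brief usage sketch (illustrative only, not part of this commit) of the helpers above: HashCombine folds an arbitrary list of values into one seed, and the std::hash<std::vector<T>> specialization lets a vector serve as the key of an unordered container. The file name and the values below are made up for the example.

// hash_funcs_example.cc -- illustrative sketch, not part of the commit
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/common/hash_funcs.h"

int main() {
  // Fold a heterogeneous list of values into a single hash value, e.g. to
  // key a cache by (op name, attribute) tuples.
  std::size_t seed = 0;
  HashCombine(&seed, std::string("custom_relu"), 3, 2.5f);

  // The std::hash specialization makes std::vector usable directly as an
  // unordered_map key, e.g. to memoize something per shape.
  std::vector<int64_t> shape = {2, 3, 4};
  std::unordered_map<std::vector<int64_t>, int> cache;
  cache[shape] = 42;

  std::cout << "combined hash: " << seed
            << ", cached value: " << cache[shape] << std::endl;
  return 0;
}
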
28 changes: 27 additions & 1 deletion paddle/fluid/framework/custom_operator_utils.h
@@ -19,10 +19,11 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/phi/core/enforce.h"

namespace paddle {
namespace framework {

constexpr char kCustomDialectPrefix[] = "custom_op."; // NOLINT
namespace detail {

// dynamic lib load func
@@ -81,6 +82,31 @@ inline static bool IsMemberOf(const std::vector<std::string>& vec,
  return std::find(vec.cbegin(), vec.cend(), name) != vec.cend();
}

// Map a PIR custom op name ("custom_op.xxx", "custom_op.xxx_grad" or
// "custom_op.xxx_grad_grad") to the forward, grad or double-grad
// OpMetaInfo registered for that custom op.
inline static const OpMetaInfo& GetOpInfoByPirName(
    const std::string& pir_op_name) {
  auto custom_name = pir_op_name.substr(strlen(kCustomDialectPrefix));
  int pos = custom_name.length();
  if (custom_name.find("_grad_grad") != custom_name.npos) {
    pos = custom_name.find("_grad_grad") + 1;
  } else if (custom_name.find("_grad") != custom_name.npos) {
    pos = custom_name.find("_grad") + 1;
  }
  auto custom_name_prefix = custom_name.substr(0, pos);
  auto map_iter =
      paddle::OpMetaInfoMap::Instance().GetMap().find(custom_name_prefix);
  if (map_iter == paddle::OpMetaInfoMap::Instance().GetMap().end()) {
    PADDLE_THROW("The info of custom op : " + custom_name + " is not exists!");
  }
  const auto& vec_op_meta = map_iter->second;
  if (custom_name.find("_grad_grad") != custom_name.npos) {
    return vec_op_meta[2];
  } else if (custom_name.find("_grad") != custom_name.npos) {
    return vec_op_meta[1];
  } else {
    return vec_op_meta[0];
  }
}

} // namespace detail
} // namespace framework
} // namespace paddle
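
A minimal sketch (illustrative only, not part of this commit) of how GetOpInfoByPirName is meant to be consumed: given a PIR op name carrying the "custom_op." dialect prefix, it returns the OpMetaInfo registered for the corresponding custom op. The op name "custom_relu" and its prior registration (for example via the PD_BUILD_OP machinery) are assumptions made only for this example.

// Illustrative sketch, not part of the commit. Assumes a custom op named
// "custom_relu" has already been registered, so OpMetaInfoMap holds an
// entry for it.
#include "paddle/fluid/framework/custom_operator_utils.h"

void ResolveCustomOpMeta() {
  // "custom_op." is kCustomDialectPrefix; the remainder is the custom op
  // name. A "_grad" / "_grad_grad" suffix selects the grad / double-grad
  // entry of the same OpMetaInfo record.
  const auto& fwd_info =
      paddle::framework::detail::GetOpInfoByPirName("custom_op.custom_relu");
  (void)fwd_info;  // the op's inputs/outputs/attrs can be inspected from here
}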