Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] fix onednn dialect name #60665

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ paddle/fluid/pir/dialect/operator/ir/pd_api.*
paddle/fluid/pir/dialect/operator/ir/op_decomp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.*
paddle/fluid/pir/dialect/operator/ir/onednn_op.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.*
paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.*
paddle/fluid/pir/dialect/operator/ir/pd_op_fused.*
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/new_executor/pir_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ void PirInterpreter::BuildInstruction() {
CREATE_INSTR(PhiKernelInstruction);
}
#ifdef PADDLE_WITH_DNNL
} else if (op.dialect()->name() == "pd_onednn_kernel") {
} else if (op.dialect()->name() == "onednn_kernel") {
auto op_name = op.attributes()
.at("op_name")
.dyn_cast<::pir::StrAttribute>()
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
#include "paddle/pir/core/value.h"

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#endif
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/pir/dialect/CMakeLists.txt.
Expand Down Expand Up @@ -81,7 +81,7 @@ using AttributeHandlerFn = std::function<pir::Attribute(
using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage;
constexpr char kTargetDialectPrefix[] = "pd_op."; // NOLINT
#ifdef PADDLE_WITH_DNNL
constexpr char kOneDNNTargetDialectPrefix[] = "pd_onednn_op."; // NOLINT
constexpr char kOneDNNTargetDialectPrefix[] = "onednn_op."; // NOLINT
#endif
constexpr char kCustomOpDialectPrefix[] = "custom_op.";
constexpr char kEmptyVarName[] = "@EMPTY@"; // NOLINT
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/pir/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ if(WITH_MKLDNN)
set(op_onednn_info_file_tmp ${op_onednn_info_file}.tmp)

set(onednn_op_namespace paddle,onednn,dialect)
set(onednn_dialect_name pd_onednn_op)
set(onednn_op_header_file ${PD_DIALECT_SOURCE_DIR}/pd_onednn_op.h)
set(onednn_op_source_file ${PD_DIALECT_SOURCE_DIR}/pd_onednn_op.cc)
set(onednn_dialect_name onednn_op)
set(onednn_op_header_file ${PD_DIALECT_SOURCE_DIR}/onednn_op.h)
set(onednn_op_source_file ${PD_DIALECT_SOURCE_DIR}/onednn_op.cc)
set(onednn_op_header_file_tmp ${onednn_op_header_file}.tmp)
set(onednn_op_source_file_tmp ${onednn_op_source_file}.tmp)

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class OneDNNKernelDialect : public pir::Dialect {
public:
explicit OneDNNKernelDialect(pir::IrContext* context);

static const char* name() { return "pd_onednn_kernel"; }
static const char* name() { return "onednn_kernel"; }

void PrintType(pir::Type type, std::ostream& os) const override;

Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class CustomKernelOp : public pir::Op<CustomKernelOp> {
class OneDNNPhiKernelOp : public pir::Op<OneDNNPhiKernelOp> {
public:
using Op::Op;
static const char *name() { return "pd_onednn_kernel.phi_kernel"; }
static const char *name() { return "onednn_kernel.phi_kernel"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
std::string op_name();
Expand All @@ -72,7 +72,7 @@ class OneDNNPhiKernelOp : public pir::Op<OneDNNPhiKernelOp> {
class OneDNNMixedPhiKernelOp : public pir::Op<OneDNNMixedPhiKernelOp> {
public:
using Op::Op;
static const char *name() { return "pd_onednn_kernel.phi_mixed_kernel"; }
static const char *name() { return "onednn_kernel.phi_mixed_kernel"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
std::string op_name();
Expand All @@ -84,7 +84,7 @@ class OneDNNMixedPhiKernelOp : public pir::Op<OneDNNMixedPhiKernelOp> {
class OneDNNLegacyKernelOp : public pir::Op<OneDNNLegacyKernelOp> {
public:
using Op::Op;
static const char *name() { return "pd_onednn_kernel.legacy_kernel"; }
static const char *name() { return "onednn_kernel.legacy_kernel"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
std::string op_name();
Expand Down
32 changes: 16 additions & 16 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,20 +1149,20 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
if (
op_info.backward_name
and op_info.op_phi_name[0] not in vjp_interface_black_list
and dialect_name != "pd_onednn_op"
and dialect_name != "onednn_op"
):
op_interfaces += ["paddle::dialect::VjpInterface"]
exclusive_interface_str = gen_exclusive_interface_str(
op_info, op_info_items
)

if dialect_name == "pd_op" or dialect_name == "pd_onednn_op":
if dialect_name == "pd_op" or dialect_name == "onednn_op":
op_interfaces += ["paddle::dialect::GetKernelTypeForVarInterface"]

# if op has custom vjp rule, then append a CustomVjpTrait to it
if (
op_info.op_phi_name[0] in custom_vjp_op_name_list
and dialect_name != "pd_onednn_op"
and dialect_name != "onednn_op"
):
op_traits += ["paddle::dialect::CustomVjpTrait"]

Expand All @@ -1184,7 +1184,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
if op_name[-1] == "_":
op_traits += ["paddle::dialect::InplaceTrait"]

if dialect_name == "pd_onednn_op":
if dialect_name == "onednn_op":
op_traits += ["paddle::dialect::OneDNNTrait"]

if op_info.is_onednn_only:
Expand All @@ -1208,7 +1208,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
if (
op_name in decomp_interface_declare_gen_op_list
and kernel_func_name in decomp_interface_declare_gen_op_list
and dialect_name != "pd_onednn_op"
and dialect_name != "onednn_op"
):
op_interfaces = op_interfaces + [
"paddle::dialect::DecompInterface"
Expand Down Expand Up @@ -1272,7 +1272,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
build_func_with_muta_attr_is_input = ""

get_kernel_type_for_var_declare_str = ""
if dialect_name == "pd_op" or dialect_name == "pd_onednn_op":
if dialect_name == "pd_op" or dialect_name == "onednn_op":
get_kernel_type_for_var_declare_str = (
get_kernel_type_for_var_declare_template
)
Expand Down Expand Up @@ -1607,7 +1607,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
origin_op_name=op_info.op_yaml_item['name'],
)

if dialect_name == "pd_onednn_op":
if dialect_name == "onednn_op":
if len(op_info.onednn_extra_args) > 0:
args_name = []
for arg in op_info.onednn_extra_args:
Expand Down Expand Up @@ -1698,7 +1698,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):

# generate op GetKernelKeyForVar function str
op_get_kernel_type_for_var_str = ''
if dialect_name == "pd_op" or dialect_name == "pd_onednn_op":
if dialect_name == "pd_op" or dialect_name == "onednn_op":
op_get_kernel_type_for_var_str = (
gen_kernel_type_for_var_str(
op_class_name,
Expand Down Expand Up @@ -1727,7 +1727,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
op_info.backward_name
and op_info.op_phi_name[0]
not in vjp_interface_black_list
and dialect_name != "pd_onednn_op"
and dialect_name != "onednn_op"
):
op_vjp_str = gen_op_vjp_str(
op_class_name,
Expand Down Expand Up @@ -1758,7 +1758,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
ops_defined_list.append(infer_symbolic_shape_define_str)

# NOTE(chenxi67)skip if dialect_name==cinn
if dialect_name == "cinn" or dialect_name == "pd_onednn_op":
if dialect_name == "cinn" or dialect_name == "onednn_op":
pass
else:
ops_vjp_defined_list.append(op_vjp_str)
Expand Down Expand Up @@ -1855,7 +1855,7 @@ def OpGenerator(
# (2) parse yaml files
op_compat_parser = OpCompatParser(op_compat_yaml_file)

if dialect_name == "pd_onednn_op":
if dialect_name == "onednn_op":
with open(ops_onednn_extra_yaml_file, "r") as f:
ops_onednn_extra = yaml.safe_load(f)
ops_onednn_extra_map = {}
Expand Down Expand Up @@ -1890,7 +1890,7 @@ def OpGenerator(
op_info_items = {}
for op in op_yaml_items:
op_compat_item = None
if dialect_name == "pd_op" or dialect_name == "pd_onednn_op":
if dialect_name == "pd_op" or dialect_name == "onednn_op":
op_compat_item = op_compat_parser.get_compat(op['name'])

if (
Expand All @@ -1916,7 +1916,7 @@ def OpGenerator(
) = op_compat_parser.parse_support_tensor(op)
op_compat_item['scalar'] = scalar_item
op_compat_item['int_array'] = int_array_item
if dialect_name == "pd_onednn_op":
if dialect_name == "onednn_op":
if first_file:
first_file = False
op["is_onednn_only"] = True
Expand All @@ -1934,7 +1934,7 @@ def OpGenerator(
all_op_info_items[op['name']] = item

op_infos.append(op_info_items)
if dialect_name == "pd_onednn_op":
if dialect_name == "onednn_op":
op_infos = [all_op_info_items]

# (3) auto code gen
Expand Down Expand Up @@ -2047,7 +2047,7 @@ def OpGenerator(
namespace=name, input=source_file_str
) # Add namespaces

if dialect_name == "pd_onednn_op":
if dialect_name == "onednn_op":
op_def_h_file_tmp = (
"paddle/fluid/pir/dialect/operator/ir/pd_op.h\"\n#include \""
+ op_def_h_file
Expand All @@ -2070,7 +2070,7 @@ def OpGenerator(
vjp_source_file_str = VJP_CC_FILE_TEMPLATE.format(input=vjp_source_file_str)
if (
dialect_name != 'cinn'
and dialect_name != 'pd_onednn_op'
and dialect_name != 'onednn_op'
and op_vjp_cc_file
):
with open(op_vjp_cc_file, 'w') as f:
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#endif

namespace paddle {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class OneDNNOperatorDialect : public pir::Dialect {
public:
explicit OneDNNOperatorDialect(pir::IrContext* context);

static const char* name() { return "pd_onednn_op"; }
static const char* name() { return "onednn_op"; }

pir::Type ParseType(pir::IrParser& parser) override; // NOLINT
pir::Attribute ParseAttribute(pir::IrParser& parser) override; // NOLINT
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include "paddle/utils/string/string_helper.h"

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#endif

namespace paddle {
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
#include "paddle/utils/flags.h"

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/trait/onednn.h"
#endif

Expand Down Expand Up @@ -2219,7 +2219,7 @@ void ProcessBlock(
}
}
std::string target_op_name = op_item->name();
target_op_name.replace(0, 12, "pd_op");
target_op_name.replace(0, 9, "pd_op");
auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) {
IR_THROW("Ctx should have corresponding OpInfo %s", target_op_name);
Expand Down