Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add print pten kernel tool #39371

Merged
merged 12 commits into from
Feb 11, 2022
2 changes: 2 additions & 0 deletions paddle/pten/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ add_subdirectory(kernels)
add_subdirectory(infermeta)
# pten operator definitions
add_subdirectory(ops)
# pten tools
add_subdirectory(tools)
# pten tests
add_subdirectory(tests)

Expand Down
72 changes: 64 additions & 8 deletions paddle/pten/core/kernel_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,25 +95,81 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
KernelKey(backend, layout, dtype));
}

// print kernel info with json format:
// {
// "(CPU, Undefined(AnyLayout), complex64)": {
// "input": ["CPU, NCHW, complex64", "CPU, NCHW, complex64"],
// "output": ["CPU, NCHW, complex64"],
// "attribute": ["i"]
// }
std::ostream& operator<<(std::ostream& os, const Kernel& kernel) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

能否在这里注释放一个打印格式示例供参考

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

os << "InputNum(" << kernel.args_def().input_defs().size() << "): [";
// input
os << "{\"input\":[";
bool need_comma = false;
for (auto& in_def : kernel.args_def().input_defs()) {
os << "<" << in_def.backend << ", " << in_def.layout << ", " << in_def.dtype
<< ">";
if (need_comma) os << ",";
os << "\"" << in_def.backend << ", " << in_def.layout << ", "
<< in_def.dtype << "\"";
need_comma = true;
}
os << "]), AttributeNum(" << kernel.args_def().attribute_defs().size()
<< "), OutputNum(" << kernel.args_def().output_defs().size() << ")";
os << "],";

// output
os << "\"output\":[";
need_comma = false;
for (auto& out_def : kernel.args_def().output_defs()) {
if (need_comma) os << ",";
os << "\"" << out_def.backend << ", " << out_def.layout << ", "
<< out_def.dtype << "\"";
need_comma = true;
}
os << "],";

// attr
os << "\"attribute\":[";
need_comma = false;
for (auto& arg_def : kernel.args_def().attribute_defs()) {
if (need_comma) os << ",";
os << "\"" << arg_def.type_index.name() << "\"";
need_comma = true;
}
os << "]}";

return os;
}

// print all kernels info with json format:
// {
// "kernel_name1":
// [
// {
// "(CPU, Undefined(AnyLayout), complex64)": {
// "input": ["CPU, NCHW, complex64", "CPU, NCHW, complex64"],
// "output": ["CPU, NCHW, complex64"],
// "attribute": ["i"]
// },
// ...
// ],
// "kernel_name2": []
// ...
// }
std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory) {
os << "{";
bool need_comma_kernels = false;
for (const auto& op_kernel_pair : kernel_factory.kernels()) {
os << "- kernel name: " << op_kernel_pair.first << "\n";
if (need_comma_kernels) os << ",";
os << "\"" << op_kernel_pair.first << "\":[";
bool need_comma_per_kernel = false;
for (const auto& kernel_pair : op_kernel_pair.second) {
os << "\t- kernel key: " << kernel_pair.first << " | "
<< "kernel: " << kernel_pair.second << "\n";
if (need_comma_per_kernel) os << ",";
os << "{\"" << kernel_pair.first << "\":" << kernel_pair.second << "}";
need_comma_per_kernel = true;
}
os << "]";
need_comma_kernels = true;
}
os << "}";

return os;
}

Expand Down
8 changes: 8 additions & 0 deletions paddle/pten/tools/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
add_executable(print_pten_kernels print_pten_kernels.cc)
target_link_libraries(print_pten_kernels pten pten_api_utils)
if(WIN32)
target_link_libraries(print_pten_kernels shlwapi.lib)
endif()
if(WITH_ROCM)
target_link_libraries(print_pten_kernels ${ROCM_HIPRTC_LIB})
endif()
24 changes: 24 additions & 0 deletions paddle/pten/tools/print_pten_kernels.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>

#include "paddle/pten/core/kernel_factory.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/declarations.h"

int main(int argc, char** argv) {
std::cout << pten::KernelFactory::Instance() << std::endl;
return 0;
}
28 changes: 28 additions & 0 deletions paddle/scripts/get_pten_kernel_function.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/usr/bin/env bash

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#=================================================
# Utils
#=================================================

set -e

EXIT_CODE=0;
tmp_dir=`mktemp -d`

PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../../" && pwd )"

unset GREP_OPTIONS && find ${PADDLE_ROOT}/paddle/pten/kernels -name "*.c*" | xargs sed -e '/PT_REGISTER_\(GENERAL_\)\?KERNEL(/,/)/!d' | awk 'BEGIN { RS="{" }{ gsub(/\n /,""); print $0 }' | grep PT_REGISTER | awk -F ",|\(" '{gsub(/ /,"");print $2, $3, $4, $5}' | sort -u | awk '{gsub(/pten::/,"");print $0}' | grep -v "_grad"
72 changes: 72 additions & 0 deletions paddle/scripts/get_pten_kernel_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/bin/python

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import yaml


def parse_args():
parser = argparse.ArgumentParser("gather pten kernel and infermate info")
parser.add_argument(
"--paddle_root_path",
type=str,
required=True,
help="root path of paddle src[WORK_PATH/Paddle] .")
parser.add_argument(
"--kernel_info_file",
type=str,
required=True,
help="kernel info file generated by get_pten_kernel_function.sh .")
args = parser.parse_args()
return args


def get_api_yaml_info(file_path):
f = open(file_path + "/python/paddle/utils/code_gen/api.yaml", "r")
cont = f.read()
return yaml.load(cont, Loader=yaml.FullLoader)


def get_kernel_info(file_path):
f = open(file_path, "r")
cont = f.readlines()
return [l.strip() for l in cont]


def merge(infer_meta_data, kernel_data):
meta_map = {}
for api in infer_meta_data:
if not api.has_key("kernel") or not api.has_key("infer_meta"):
continue
meta_map[api["kernel"]["func"]] = api["infer_meta"]["func"]
full_kernel_data = []
for l in kernel_data:
key = l.split()[0]
if meta_map.has_key(key):
full_kernel_data.append((l + " " + meta_map[key]).split())
else:
full_kernel_data.append((l + " unknown").split())

return full_kernel_data


if __name__ == "__main__":
args = parse_args()
infer_meta_data = get_api_yaml_info(args.paddle_root_path)
kernel_data = get_kernel_info(args.kernel_info_file)
out = merge(infer_meta_data, kernel_data)
print(json.dumps(out))