Skip to content

Commit

Permalink
Browse files — browse the repository at this point in the history
Merge branch 'develop' into support_eager_hook2
  • Loading branch information
veyron95 committed Feb 17, 2022
2 parents ecefa5e + b72d4cb commit 5f3bff7
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 33 deletions.
57 changes: 32 additions & 25 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,32 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
subgraph.get());
AddOutputVar(output_vars, cluster, old_op2new_op, old_var2new_var,
subgraph.get());
// Store the input variables whose buffer are not needed as
// attribute of the graph.
// Save lists of input variables, internal variables and output variables
// of the cluster as attributes of the subgraph for convenience.
auto collect_names_fn = [](
const GraphNodeSet& nodes,
const std::unordered_set<std::string>& ignore_names) {
auto result = std::make_unique<std::vector<std::string>>();
for (auto* node : nodes) {
if (ignore_names.count(node->Name())) {
continue;
}
result->emplace_back(node->Name());
}
return result;
};
subgraph->Set<std::vector<std::string>>(
kInternalVars, collect_names_fn(cluster_internals, {}).release());
subgraph->Set<std::vector<std::string>>(
kOutputVars, collect_names_fn(cluster_outputs, {}).release());
// Divide input variables into two parts: one is common and will be used
// in execution, the other may be empty and it is those variables whose
// buffer are not needed and only be used in graph symbolization
auto no_need_buffer_feeds = std::make_unique<std::unordered_set<std::string>>(
ExtractNoNeedBufferFeeds(cluster, cluster_inputs));
subgraph->Set<std::vector<std::string>>(
kInputVars,
collect_names_fn(cluster_inputs, *no_need_buffer_feeds).release());
subgraph->Set<std::unordered_set<std::string>>(
kNoNeedBufferFeeds, no_need_buffer_feeds.release());
// initialize empty map for kMemOptVarInfoFromMainGraph attribute,
Expand Down Expand Up @@ -458,33 +480,18 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
framework::OpDesc cinn_op_desc;
cinn_op_desc.SetType(kCinnLaunchOp);

// Divide input variables as two parts:
// the ones that data buffer are not needed and remain ones
std::vector<std::string> op_kx_inputs, no_need_buffer_inputs;
const auto& subgraph =
CinnCompiler::GetInstance()->FindGraph(compilation_key);
auto& no_need_buffer_feeds =
const auto& no_need_buffer_feeds =
subgraph.Get<std::unordered_set<std::string>>(kNoNeedBufferFeeds);
for (const auto* n : cluster_inputs) {
const auto& var_name = n->Name();
if (no_need_buffer_feeds.count(var_name)) {
no_need_buffer_inputs.emplace_back(var_name);
} else {
op_kx_inputs.emplace_back(var_name);
}
}

cinn_op_desc.SetInput(operators::kX, op_kx_inputs);
cinn_op_desc.SetInput(operators::kNoNeedBufferX, no_need_buffer_inputs);

std::vector<std::string> output_names;
std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
[&output_names, &deny_var_set](Node* n) {
if (n->Var() != nullptr && !deny_var_set.count(n->Name())) {
output_names.emplace_back(n->Name());
}
});
cinn_op_desc.SetOutput(operators::kOutputs, output_names);
cinn_op_desc.SetInput(operators::kX,
subgraph.Get<std::vector<std::string>>(kInputVars));
cinn_op_desc.SetInput(operators::kNoNeedBufferX,
std::vector<std::string>(no_need_buffer_feeds.begin(),
no_need_buffer_feeds.end()));
cinn_op_desc.SetOutput(operators::kOutputs,
subgraph.Get<std::vector<std::string>>(kOutputVars));
cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key);
cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(cluster));
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/framework/paddle2cinn/build_cinn_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ namespace framework {
namespace paddle2cinn {

constexpr char kCinnLaunchOp[] = "cinn_launch";
constexpr char kNoNeedBufferFeeds[] = "no_need_buffer_feeds";
constexpr char kInputVars[] = "InputVars";
constexpr char kNoNeedBufferFeeds[] = "NoNeedBufferFeeds";
constexpr char kInternalVars[] = "InternalVars";
constexpr char kOutputVars[] = "OutputVars";
constexpr char kMemOptVarInfoFromMainGraph[] =
"mem_opt_var_info_from_main_graph";

Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,19 @@ TEST(BuildCinnPassTest, NoNeedBufferInput) {
ASSERT_EQ(no_need_buffer_feeds.size(), 2);
ASSERT_EQ(no_need_buffer_feeds,
std::unordered_set<std::string>({"var2", "var3"}));

// check the attributes of variable lists are saved correctly
ASSERT_TRUE(subgraph.Has(kInputVars));
EXPECT_EQ(subgraph.Get<std::vector<std::string>>(kInputVars),
std::vector<std::string>({"var1"}));
ASSERT_TRUE(subgraph.Has(kInternalVars));
EXPECT_EQ(subgraph.Get<std::vector<std::string>>(kInternalVars),
std::vector<std::string>({"var4"}));
ASSERT_TRUE(subgraph.Has(kOutputVars));
const auto& output_vars = subgraph.Get<std::vector<std::string>>(kOutputVars);
EXPECT_EQ(
std::unordered_set<std::string>(output_vars.begin(), output_vars.end()),
std::unordered_set<std::string>({"var5", "var6"}));
}

} // namespace paddle2cinn
Expand Down
28 changes: 23 additions & 5 deletions paddle/pten/api/lib/kernel_dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ enum class KernelType {

// TODO(chenweihang): support DataLayout and DataType selected
struct KernelKeySet {
KernelType kernel_type{KernelType::DENSE_TENSOR_KENREL};

BackendSet backend_set{Backend::UNDEFINED};
DataLayout layout{DataLayout::UNDEFINED};
DataType dtype{DataType::UNDEFINED};
Expand Down Expand Up @@ -97,9 +95,6 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
void operator()(const Tensor& x) {
key_set.backend_set = key_set.backend_set | detail::GetTensorBackendSet(x);
// TODO(chenweihang): selecte multi layout and dtype
if (pten::SelectedRows::classof(x.impl().get())) {
key_set.kernel_type = KernelType::SELECTED_ROWS_KENREL;
}
key_set.layout = x.layout();
key_set.dtype = x.type();
dtype_set = dtype_set | DataTypeSet(x.dtype());
Expand All @@ -124,13 +119,36 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
}
};

// Scans the input arguments to pick the kernel family: the choice starts at
// the dense-tensor default and is upgraded to SELECTED_ROWS_KENREL as soon as
// any Tensor argument wraps a pten::SelectedRows implementation.
struct KernelTypeParser : ArgsIterator<KernelTypeParser> {
  KernelType kernel_type{KernelType::DENSE_TENSOR_KENREL};

  // TODO(chenweihang): deal with multiple diff input Tensors
  // TODO(chenweihang): add global device guard method to set backend
  void operator()(const Tensor& x) {
    const bool holds_selected_rows =
        pten::SelectedRows::classof(x.impl().get());
    if (holds_selected_rows) {
      kernel_type = KernelType::SELECTED_ROWS_KENREL;
    }
  }

  // Non-tensor arguments carry no kernel-type information, so they are
  // deliberately ignored during kernel selection.
  template <typename T>
  void operator()(const T&) {
    // intentionally empty
  }
};

} // namespace detail

// Derives the kernel key (backend / layout / dtype set) by iterating over all
// input arguments with a KernelKeyParser and returning its accumulated state.
template <typename... Args>
KernelKeySet ParseKernelKeyByInputArgs(const Args&... args) {
  detail::KernelKeyParser parser;
  return parser.apply(args...).key_set;
}

// Selects between the dense-tensor and selected-rows kernel families by
// running a KernelTypeParser over every input argument.
template <typename... Args>
KernelType ParseKernelTypeByInputArgs(const Args&... args) {
  detail::KernelTypeParser parser;
  return parser.apply(args...).kernel_type;
}

DataType ParseDataType(DataType dtype);
DataType ParseDataType(const Tensor& tensor);
DataType ParseDataType(const std::vector<Tensor>& tensors);
Expand Down
3 changes: 1 addition & 2 deletions python/paddle/utils/code_gen/api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def gene_kernel_select(self) -> str:
if len(input_names) > 0:
if self.support_selected_rows_kernel:
kernel_select_code = kernel_select_code + f"""
KernelType kernel_type;
KernelType kernel_type = ParseKernelTypeByInputArgs({", ".join(input_names)});
"""

kernel_select_code = kernel_select_code + f"""
Expand All @@ -354,7 +354,6 @@ def gene_kernel_select(self) -> str:
|| kernel_data_type == DataType::UNDEFINED ) {{
auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
{'kernel_type = kernel_key_set.kernel_type;' if self.support_selected_rows_kernel else ''}
if (kernel_backend == Backend::UNDEFINED) {{
kernel_backend = kernel_key.backend();
}}
Expand Down

0 comments on commit 5f3bff7

Please sign in to comment.