Skip to content

Commit

Permalink
fix styles, test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 committed Oct 21, 2019
1 parent 6150ae0 commit d2c20df
Show file tree
Hide file tree
Showing 14 changed files with 129 additions and 122 deletions.
3 changes: 1 addition & 2 deletions paddle/fluid/inference/analysis/argument.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,7 @@ struct Argument {

DECL_ARGUMENT_FIELD(lite_passes_filter, LitePassesFilter,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(lite_ops_filter, LiteOpsFilter,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(lite_ops_filter, LiteOpsFilter, std::vector<std::string>);
DECL_ARGUMENT_FIELD(lite_precision_mode, LitePrecisionMode,
AnalysisConfig::Precision);

Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/inference/analysis/ir_pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ void IRPassManager::CreatePasses(Argument *argument,
new framework::ProgramDesc *(&argument->main_program()));
}
if (pass_name == "lite_subgraph_pass") {
bool enable_int8 = argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8;
bool enable_int8 =
argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8;
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
pass->Set("lite_ops_filter",
Expand Down
151 changes: 80 additions & 71 deletions paddle/fluid/inference/analysis/ir_passes/lite_subgraph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
#include <unordered_set>
#include <vector>

#include <iostream>
#include <fstream>
#include <iostream>

#include "paddle/fluid/inference/lite/op_teller.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/lite/op_teller.h"
#include "paddle/fluid/inference/utils/singleton.h"

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/analysis/ir_passes/lite_subgraph_pass.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/string/pretty_log.h"

#include "paddle/fluid/inference/lite/engine.h"
Expand All @@ -43,8 +43,8 @@ using framework::ir::Node;

namespace lite {

std::string UniqueKey(const std::vector<std::string> &engine_inputs,
const std::vector<std::string> &engine_outputs,
std::string UniqueKey(const std::vector<std::string>& engine_inputs,
const std::vector<std::string>& engine_outputs,
const std::string& id) {
std::string engine_hash_key = "";
for (auto name : engine_inputs) {
Expand All @@ -60,8 +60,8 @@ std::string UniqueKey(const std::vector<std::string> &engine_inputs,

std::vector<std::string> IOVarsFilter(const std::vector<Node*>& nodes) {
std::set<std::string> names;
for (const auto& node: nodes) {
if (node->IsVar() && !node->Var()->Persistable()) {
for (const auto& node : nodes) {
if (node->IsVar() && !node->Var()->Persistable()) {
names.insert(node->Name());
}
}
Expand All @@ -75,19 +75,20 @@ void StrToBinaryFile(const std::string& path, const std::string& str) {
}

void ModifyHostProgram(framework::ProgramDesc* host_program,
framework::BlockDesc* host_sub_block,
const std::unordered_set<Node *>& io_var_nodes,
const std::vector<framework::OpDesc*>& subgraph_ops) {
for (auto *var_node: io_var_nodes) {
framework::BlockDesc* host_sub_block,
const std::unordered_set<Node*>& io_var_nodes,
const std::vector<framework::OpDesc*>& subgraph_ops) {
for (auto* var_node : io_var_nodes) {
auto* sub_block_var = host_sub_block->Var(var_node->Name());
sub_block_var->Proto()->CopyFrom(*var_node->Var()->Proto());
}
for (auto *op_desc : subgraph_ops) {
for (auto* op_desc : subgraph_ops) {
auto* sub_block_op = host_sub_block->AppendOp();
sub_block_op->CopyFrom(*op_desc);
if (op_desc->HasAttr("sub_block")) {
int32_t global_sub_id = host_sub_block->ID();
auto *op_sub_block = host_program->MutableBlock(op_desc->GetBlockAttrId("sub_block"));
auto* op_sub_block =
host_program->MutableBlock(op_desc->GetBlockAttrId("sub_block"));
op_sub_block->Proto()->set_parent_idx(global_sub_id);
}
}
Expand All @@ -97,21 +98,22 @@ void ModifyHostProgram(framework::ProgramDesc* host_program,
// (initial) -> proto::desc (flush) -> framework::desc (final).
// Ir::Graph is limited to changing the main block, so the sub block
// needs to be processed here.
void ModifyEngineProgram(Node *merged_node,
framework::ProgramDesc* host_program,
framework::ProgramDesc* engine_program,
framework::BlockDesc* host_sub_block,
const std::unordered_set<Node *>& io_var_nodes,
const std::vector<framework::OpDesc*>& subgraph_ops) {

void ModifyEngineProgram(Node* merged_node,
framework::ProgramDesc* host_program,
framework::ProgramDesc* engine_program,
framework::BlockDesc* host_sub_block,
const std::unordered_set<Node*>& io_var_nodes,
const std::vector<framework::OpDesc*>& subgraph_ops) {
// 1. Fill the main block of lite program.
framework::BlockDesc* engine_global_block = engine_program->MutableBlock(framework::kRootBlockIndex);
framework::BlockDesc* engine_global_block =
engine_program->MutableBlock(framework::kRootBlockIndex);
PrependFeedOps(engine_global_block, IOVarsFilter(merged_node->inputs));
for (auto *var_node: io_var_nodes) {
framework::VarDesc* sub_block_var = engine_global_block->Var(var_node->Name());
for (auto* var_node : io_var_nodes) {
framework::VarDesc* sub_block_var =
engine_global_block->Var(var_node->Name());
sub_block_var->Proto()->CopyFrom(*var_node->Var()->Proto());
}
for (auto *op_desc : subgraph_ops) {
for (auto* op_desc : subgraph_ops) {
auto* sub_block_op = engine_global_block->AppendOp();
sub_block_op->CopyFrom(*op_desc);
}
Expand All @@ -123,18 +125,19 @@ void ModifyEngineProgram(Node *merged_node,
sub_blocks_map[host_sub_block->ID()] = framework::kRootBlockIndex;
std::function<void(const std::vector<framework::OpDesc*>&)> append_sub_blocks;
append_sub_blocks = [&](const std::vector<framework::OpDesc*>& ops) {
for (auto *op_desc : ops) {
for (auto* op_desc : ops) {
if (op_desc->HasAttr("sub_block")) {
int32_t host_op_sub_id = op_desc->GetBlockAttrId("sub_block");
if (copied_host_ids.count(host_op_sub_id)) continue;
size_t engine_block_size = engine_program->Size();
auto* host_op_sub_block = host_program->MutableBlock(host_op_sub_id);
auto* engine_op_sub_block = engine_program->AppendBlock(*(op_desc->Block()));
for (auto* var: host_op_sub_block->AllVars()) {
auto* engine_op_sub_block =
engine_program->AppendBlock(*(op_desc->Block()));
for (auto* var : host_op_sub_block->AllVars()) {
auto* engine_var = engine_op_sub_block->Var(var->Name());
engine_var->Proto()->CopyFrom(*var->Proto());
}
for (auto* op: host_op_sub_block->AllOps()) {
for (auto* op : host_op_sub_block->AllOps()) {
auto* engine_op = engine_op_sub_block->AppendOp();
engine_op->Proto()->CopyFrom(*op->Proto());
}
Expand All @@ -145,7 +148,7 @@ void ModifyEngineProgram(Node *merged_node,
};
append_sub_blocks(subgraph_ops);
for (size_t i = 0; i < engine_program->Size(); i++) {
for (auto *op_desc : engine_program->Block(i).AllOps()) {
for (auto* op_desc : engine_program->Block(i).AllOps()) {
if (op_desc->HasAttr("sub_block")) {
int32_t id = op_desc->GetBlockAttrId("sub_block");
op_desc->SetAttr("sub_block", sub_blocks_map[id]);
Expand All @@ -154,60 +157,63 @@ void ModifyEngineProgram(Node *merged_node,
}
}

// Builds the host-side and engine-side program descriptions for one fused
// sub-graph node.
//
// NOTE(review): this listing is a rendered diff, so several lines below appear
// twice (once before and once after reformatting); the two versions differ
// only in whitespace.
//
// merged_node       - node whose Agent(...).subgraph() holds the fused nodes;
//                     the subgraph must be non-empty (enforced below).
// host_program      - program that receives a newly appended block describing
//                     the sub-graph; flushed before returning.
// engine_program    - program filled in by ModifyEngineProgram; flushed before
//                     returning.
// repetitive_params - output; overwritten with ExtractParameters(io_var_nodes).
void OrganizeProgram(Node *merged_node,
framework::ProgramDesc* host_program,
framework::ProgramDesc* engine_program,
std::vector<std::string> *repetitive_params) {
std::vector<framework::ir::Node *>& subgraph = *Agent(merged_node).subgraph();
void OrganizeProgram(Node* merged_node, framework::ProgramDesc* host_program,
framework::ProgramDesc* engine_program,
std::vector<std::string>* repetitive_params) {
std::vector<framework::ir::Node*>& subgraph = *Agent(merged_node).subgraph();
PADDLE_ENFORCE(!subgraph.empty());

// Append a new block to the host program, passing the root block to
// AppendBlock (presumably as the new block's parent -- confirm against
// ProgramDesc::AppendBlock).
const framework::BlockDesc &host_global_block = host_program->Block(framework::kRootBlockIndex);
framework::BlockDesc* host_sub_block = host_program->AppendBlock(host_global_block);
const framework::BlockDesc& host_global_block =
host_program->Block(framework::kRootBlockIndex);
framework::BlockDesc* host_sub_block =
host_program->AppendBlock(host_global_block);

string::PrettyLogDetail("--- detect a sub-graph with %d nodes",
subgraph.size());

// Variable nodes related to the sub-graph's inputs/outputs, as reported by
// GetRelatedIOVarNodes; each name is logged.
std::unordered_set<Node *> io_var_nodes = GetRelatedIOVarNodes(subgraph);
for (const auto* node: io_var_nodes) {
std::unordered_set<Node*> io_var_nodes = GetRelatedIOVarNodes(subgraph);
for (const auto* node : io_var_nodes) {
LOG(INFO) << "IO Variable Name: " << node->Name();
}

// Collect the OpDesc of every node in the sub-graph, in sub-graph order.
std::vector<framework::OpDesc*> subgraph_ops;
for (auto *op_node : subgraph) {
for (auto* op_node : subgraph) {
subgraph_ops.push_back(op_node->Op());
}

// Populate the new host block first, then the engine program; the engine
// step also receives host_program and host_sub_block.
ModifyHostProgram(host_program, host_sub_block, io_var_nodes, subgraph_ops);
ModifyEngineProgram(merged_node, host_program, engine_program, host_sub_block,
io_var_nodes, subgraph_ops);
// Overwrites (does not append to) the caller's vector.
*repetitive_params = ExtractParameters(io_var_nodes);
for (const auto& param: *repetitive_params) {
for (const auto& param : *repetitive_params) {
LOG(INFO) << "Repetitive param: " << param;
}

// Flush both programs after all modifications have been made.
host_program->Flush();
engine_program->Flush();
}
} // namespace lite
} // namespace lite

void LiteSubgraphPass::SetUpEngine(framework::ProgramDesc* program,
const std::vector<std::string>& repetitive_params,
const std::string& unique_key, bool dump_model) const {
void LiteSubgraphPass::SetUpEngine(
framework::ProgramDesc* program,
const std::vector<std::string>& repetitive_params,
const std::string& unique_key, bool dump_model) const {
inference::lite::EngineConfig config;
auto *scope = param_scope();
auto* scope = param_scope();

// When the pass is started, only the persistent variables of the
// main block are read. Fluid seems to allow persistent variables
// in the sub block, but they are controlled by context, so they
// are not supported here for now.
auto serialize_params = [] (std::string* str, framework::Scope* scope,
const std::vector<std::string>& params) {
auto serialize_params = [](std::string* str, framework::Scope* scope,
const std::vector<std::string>& params) {
std::ostringstream os;
platform::CPUDeviceContext ctx;
for (const auto& param: params) {
for (const auto& param : params) {
LOG(INFO) << "Serialize param: " << param;
PADDLE_ENFORCE_NOT_NULL(scope->FindVar(param), "Block should already have a '%s' variable",
param);
PADDLE_ENFORCE_NOT_NULL(scope->FindVar(param),
"Block should already have a '%s' variable",
param);
auto* tensor = scope->FindVar(param)->GetMutable<framework::LoDTensor>();
framework::SerializeToStream(os, *tensor, ctx);
}
Expand All @@ -217,7 +223,8 @@ void LiteSubgraphPass::SetUpEngine(framework::ProgramDesc* program,
bool use_gpu = Get<bool>("use_gpu");
bool enable_int8 = Get<bool>("enable_int8");
lite_api::TargetType target_type = use_gpu ? TARGET(kCUDA) : TARGET(kHost);
paddle::lite_api::PrecisionType precision_type = enable_int8 ? PRECISION(kInt8) : PRECISION(kFloat);
paddle::lite_api::PrecisionType precision_type =
enable_int8 ? PRECISION(kInt8) : PRECISION(kFloat);
paddle::lite::Place prefer_place = {target_type, precision_type};

serialize_params(&config.param, scope, repetitive_params);
Expand All @@ -233,65 +240,67 @@ void LiteSubgraphPass::SetUpEngine(framework::ProgramDesc* program,
lite::StrToBinaryFile("./model.bin", config.model);
lite::StrToBinaryFile("./param.bin", config.param);
}
inference::Singleton<inference::lite::EngineManager>::Global()
.Create(unique_key, config);
inference::Singleton<inference::lite::EngineManager>::Global().Create(
unique_key, config);
}

// Turns a fused sub-graph node into a "lite_engine" operator.
//
// NOTE(review): this listing is a rendered diff, so several lines below appear
// twice (once before and once after reformatting); the two versions differ
// only in whitespace.
//
// merged_node       - fused node; its op description is rewritten in place.
// global_program    - host program, forwarded to lite::OrganizeProgram.
// repetitive_params - output filled by lite::OrganizeProgram and then passed
//                     on to SetUpEngine.
void LiteSubgraphPass::BuildOperator(
Node *merged_node, framework::ProgramDesc* global_program,
std::vector<std::string> *repetitive_params) const {

Node* merged_node, framework::ProgramDesc* global_program,
std::vector<std::string>* repetitive_params) const {
// Engine-side program; it is a local and is handed to SetUpEngine by pointer.
framework::ProgramDesc engine_program;

// The engine key is derived from the "predictor_id" value plus the filtered
// input and output variable names of the merged node.
const std::string id = std::to_string(Get<int>("predictor_id"));
const std::vector<std::string> input_names = lite::IOVarsFilter(merged_node->inputs);
const std::vector<std::string> output_names = lite::IOVarsFilter(merged_node->outputs);
const std::vector<std::string> input_names =
lite::IOVarsFilter(merged_node->inputs);
const std::vector<std::string> output_names =
lite::IOVarsFilter(merged_node->outputs);
const std::string unique_key = lite::UniqueKey(input_names, output_names, id);

// Fill the host and engine programs, then set up the engine under
// unique_key.
lite::OrganizeProgram(merged_node, global_program, &engine_program, repetitive_params);
lite::OrganizeProgram(merged_node, global_program, &engine_program,
repetitive_params);
SetUpEngine(&engine_program, *repetitive_params, unique_key);

// Rewrite the merged node's op: inputs go to "Xs", outputs to "Ys", and the
// "engine_key" attribute carries the same key that was passed to SetUpEngine.
auto *op_desc = merged_node->Op();
auto* op_desc = merged_node->Op();
op_desc->SetInput("Xs", input_names);
op_desc->SetOutput("Ys", output_names);
op_desc->SetType("lite_engine");
op_desc->SetAttr("engine_key", unique_key);
}

void LiteSubgraphPass::ApplyImpl(
framework::ir::Graph *graph) const {

void LiteSubgraphPass::ApplyImpl(framework::ir::Graph* graph) const {
framework::ir::FusePassBase::Init("lite_subgraph_pass", graph);
framework::ProgramDesc* global_program = Get<framework::ProgramDesc *>("program");
framework::ProgramDesc* global_program =
Get<framework::ProgramDesc*>("program");

auto &lite_ops_filter = Get<std::vector<std::string>>("lite_ops_filter");
auto& lite_ops_filter = Get<std::vector<std::string>>("lite_ops_filter");

auto teller = [&lite_ops_filter](const Node *node) {
auto teller = [&lite_ops_filter](const Node* node) {
if (!node->IsOp() || !node->Op())
return false;
else if (node->Op()->Type() == "feed" || node->Op()->Type() == "fetch")
return false;
else if (std::find(lite_ops_filter.begin(), lite_ops_filter.end(),
node->Op()->Type()) != lite_ops_filter.end())
return false;
return inference::lite::OpTeller::Global().Tell(node->Op()->Type(), *node->Op());
return inference::lite::OpTeller::Global().Tell(node->Op()->Type(),
*node->Op());
};

SubGraphFuser fuser(graph, teller, 0 /* min_subgraph_size */, "lite_engine");
fuser();

std::vector<std::string> repetitive_params;
for (auto *node : graph->Nodes()) {
for (auto* node : graph->Nodes()) {
if (node->IsOp() && !Agent(node).subgraph()->empty()) {
BuildOperator(node, global_program, &repetitive_params);
std::unordered_set<const Node *> nodes2remove(
std::unordered_set<const Node*> nodes2remove(
Agent(node).subgraph()->begin(), Agent(node).subgraph()->end());
framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
}
}

std::unordered_set<const Node *> nodes2remove;
for (auto *node : graph->Nodes()) {
std::unordered_set<const Node*> nodes2remove;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && Agent(node).deleted()) {
nodes2remove.insert(node);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ namespace analysis {

class LiteSubgraphPass : public framework::ir::FusePassBase {
public:
void ApplyImpl(framework::ir::Graph *graph) const override;
void ApplyImpl(framework::ir::Graph* graph) const override;

private:
void BuildOperator(framework::ir::Node *merged_node, framework::ProgramDesc* global_program,
std::vector<std::string> *repetitive_params) const;
void BuildOperator(framework::ir::Node* merged_node,
framework::ProgramDesc* global_program,
std::vector<std::string>* repetitive_params) const;

void SetUpEngine(framework::ProgramDesc* program,
const std::vector<std::string>& repetitive_params,
Expand Down
18 changes: 9 additions & 9 deletions paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,29 +68,29 @@ std::unordered_set<Node *> GetRelatedIOVarNodes(
return io_nodes;
}

// Declares the feed holder variable in `global_block` and adds one "feed" op
// per entry of `feed_target_names`.
//
// NOTE(review): this listing is a rendered diff, so several lines below appear
// twice (once before and once after reformatting); the two versions differ
// only in whitespace.
//
// global_block      - block that receives the holder variable and the ops.
// feed_target_names - output variable name of each feed op; the position in
//                     this vector becomes the op's "col" attribute.
// feed_holder_name  - name of the holder variable every feed op reads from.
//
// NOTE(review): despite the "Prepend" name, ops are added with AppendOp(), so
// they precede other ops only if this runs before those ops are appended.
void PrependFeedOps(framework::BlockDesc* global_block,
const std::vector<std::string>& feed_target_names,
void PrependFeedOps(framework::BlockDesc *global_block,
const std::vector<std::string> &feed_target_names,
std::string feed_holder_name) {
// The holder variable is typed FEED_MINIBATCH and marked persistable.
framework::VarDesc* feed_var = global_block->Var(feed_holder_name);
framework::VarDesc *feed_var = global_block->Var(feed_holder_name);
feed_var->SetType(paddle::framework::proto::VarType::FEED_MINIBATCH);
feed_var->SetPersistable(true);
for (size_t i = 0; i < feed_target_names.size(); i++) {
// One feed op per target: X = holder, Out = target name, col = index i.
framework::OpDesc* feed_op = global_block->AppendOp();
framework::OpDesc *feed_op = global_block->AppendOp();
feed_op->SetType("feed");
feed_op->SetInput("X", {feed_holder_name});
feed_op->SetOutput("Out", {feed_target_names[i]});
feed_op->SetAttr("col", static_cast<int>(i));
}
}

void PrependFetchOps(framework::BlockDesc* global_block,
const std::vector<std::string>& fetch_target_names,
std::string fetch_holder_name) {
framework::VarDesc* fetch_var = global_block->Var(fetch_holder_name);
void PrependFetchOps(framework::BlockDesc *global_block,
const std::vector<std::string> &fetch_target_names,
std::string fetch_holder_name) {
framework::VarDesc *fetch_var = global_block->Var(fetch_holder_name);
fetch_var->SetType(paddle::framework::proto::VarType::FETCH_LIST);
fetch_var->SetPersistable(true);
for (size_t i = 0; i < fetch_target_names.size(); i++) {
framework::OpDesc* fetch_op = global_block->AppendOp();
framework::OpDesc *fetch_op = global_block->AppendOp();
fetch_op->SetType("fetch");
fetch_op->SetInput("X", {fetch_target_names[i]});
fetch_op->SetOutput("Out", {fetch_holder_name});
Expand Down
Loading

0 comments on commit d2c20df

Please sign in to comment.