Skip to content

Commit

Permalink
Refine interface of hlir::framework::Instruction (#56123)
Browse files Browse the repository at this point in the history
* Refine interface of hlir::framework::Instruction

* fix client usage
  • Loading branch information
Candy2Tang authored Aug 10, 2023
1 parent 4569ae1 commit dbd9743
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ void Instruction::Run(
// }
}

std::string Instruction::DumpInstruction() {
std::string Instruction::DumpInstruction() const {
std::stringstream ss;
ss << "Instruction {" << std::endl;
for (size_t i = 0; i < fn_names_.size(); ++i) {
Expand Down
12 changes: 8 additions & 4 deletions paddle/cinn/hlir/framework/instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,17 @@ class Instruction {

int size() { return fn_ptrs_.size(); }

std::string DumpInstruction();
std::string DumpInstruction() const;

std::vector<std::vector<std::string>> GetInArgs() { return in_args_; }
std::vector<std::vector<std::string>> GetOutArgs() { return out_args_; }
const std::vector<std::vector<std::string>>& GetInArgs() const {
return in_args_;
}
const std::vector<std::vector<std::string>>& GetOutArgs() const {
return out_args_;
}
void ClearInArgs() { in_args_.clear(); }
void ClearOutArgs() { out_args_.clear(); }
std::vector<std::string> GetFnNames() { return fn_names_; }
const std::vector<std::string>& GetFnNames() const { return fn_names_; }
void AddInArgs(const std::vector<std::string>& in_args) {
in_args_.push_back(in_args);
}
Expand Down
6 changes: 3 additions & 3 deletions paddle/cinn/hlir/framework/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,9 @@ void Program::Export(const std::vector<std::string>& persistent_vars,
int instplaceholder = writeplaceholder(4 * 3, insnum, f);
int findex = 0;
for (auto& ins : instrs_) {
auto in_args = ins->GetInArgs();
auto out_args = ins->GetOutArgs();
auto fn_names = ins->GetFnNames();
auto& in_args = ins->GetInArgs();
auto& out_args = ins->GetOutArgs();
auto& fn_names = ins->GetFnNames();
for (int i = 0; i < fn_names.size(); i++, findex++) {
std::vector<std::string> all_args(in_args[i].begin(), in_args[i].end());
all_args.insert(
Expand Down

0 comments on commit dbd9743

Please sign in to comment.