Skip to content

Commit 812d478

Browse files
committed
Simplify the classop translation
1 parent 0d0aeb3 commit 812d478

File tree

2 files changed

+23
-43
lines changed

2 files changed

+23
-43
lines changed

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1010
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1111
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/IR/Attributes.h"
1213
#include "mlir/IR/BuiltinOps.h"
1314
#include "mlir/IR/BuiltinTypes.h"
1415
#include "mlir/IR/Dialect.h"
@@ -1000,22 +1001,10 @@ static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
10001001
static LogicalResult printOperation(CppEmitter &emitter, ClassOp classOp) {
10011002
CppEmitter::Scope classScope(emitter);
10021003
raw_indented_ostream &os = emitter.ostream();
1003-
os << "class " << classOp.getSymName() << " final {\n";
1004-
os << "public:\n\n";
1005-
1004+
os << "class " << classOp.getSymName() << " {\n";
1005+
os << "public:\n";
10061006
os.indent();
1007-
os << "const std::map<std::string, char*> _buffer_map {\n";
1008-
for (Operation &op : classOp) {
1009-
if (auto fieldOp = dyn_cast<FieldOp>(op))
1010-
os << " { \"" << fieldOp.getSymName() << "\", reinterpret_cast<char*>(&"
1011-
<< fieldOp.getAttrs() << ") },\n";
1012-
}
1013-
os << "};\n";
1014-
1015-
os << "char* getBufferForName(const std::string& name) const {\n";
1016-
os << " auto it = _buffer_map.find(name);\n";
1017-
os << " return (it == _buffer_map.end()) ? nullptr : it->second;\n";
1018-
os << "}\n\n";
1007+
10191008
for (Operation &op : classOp) {
10201009
if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
10211010
return failure();
@@ -1660,15 +1649,16 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
16601649
emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
16611650
emitc::BitwiseNotOp, emitc::BitwiseOrOp,
16621651
emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
1663-
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1664-
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
1665-
emitc::DivOp, emitc::ExpressionOp, emitc::FileOp, emitc::ForOp,
1666-
emitc::FuncOp, emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp,
1667-
emitc::LoadOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
1668-
emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1669-
emitc::SubOp, emitc::SwitchOp, emitc::UnaryMinusOp,
1670-
emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp,
1671-
emitc::ClassOp, emitc::FieldOp, emitc::GetFieldOp>(
1652+
emitc::CallOpaqueOp, emitc::CastOp, emitc::ClassOp,
1653+
emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp,
1654+
emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp,
1655+
emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp,
1656+
emitc::GetFieldOp, emitc::GlobalOp, emitc::IfOp,
1657+
emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp,
1658+
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
1659+
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp,
1660+
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
1661+
emitc::VerbatimOp>(
16721662

16731663
[&](auto op) { return printOperation(*this, op); })
16741664
// Func ops.

mlir/test/mlir-translate/emitc_classops.mlir

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,25 @@
11
// RUN: mlir-translate --mlir-to-cpp %s | FileCheck %s
22

33
emitc.class @modelClass {
4-
emitc.field @input_tensor : !emitc.array<1xf32>
5-
emitc.field @some_feature : !emitc.array<1xf32> {emitc.opaque = ["some_feature"]}
4+
emitc.field @fieldName0 : !emitc.array<1xf32>
5+
emitc.field @fieldName1 : !emitc.array<1xf32>
66
emitc.func @execute() {
77
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
8-
%1 = get_field @input_tensor : !emitc.array<1xf32>
9-
%2 = get_field @some_feature : !emitc.array<1xf32>
8+
%1 = get_field @fieldName0 : !emitc.array<1xf32>
9+
%2 = get_field @fieldName1 : !emitc.array<1xf32>
1010
%3 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
1111
return
1212
}
1313
}
1414

15-
// CHECK: class modelClass final {
15+
// CHECK: class modelClass {
1616
// CHECK-NEXT: public:
17-
// CHECK-EMPTY:
18-
// CHECK-NEXT: const std::map<std::string, char*> _buffer_map {
19-
// CHECK-NEXT: { "input_tensor", reinterpret_cast<char*>(&None) },
20-
// CHECK-NEXT: { "some_feature", reinterpret_cast<char*>(&{emitc.opaque = ["some_feature"]}) },
21-
// CHECK-NEXT: };
22-
// CHECK-NEXT: char* getBufferForName(const std::string& name) const {
23-
// CHECK-NEXT: auto it = _buffer_map.find(name);
24-
// CHECK-NEXT: return (it == _buffer_map.end()) ? nullptr : it->second;
25-
// CHECK-NEXT: }
26-
// CHECK-EMPTY:
27-
// CHECK-NEXT: float[1] input_tensor;
28-
// CHECK-NEXT: float[1] some_feature;
17+
// CHECK-NEXT: float[1] fieldName0;
18+
// CHECK-NEXT: float[1] fieldName1;
2919
// CHECK-NEXT: void execute() {
3020
// CHECK-NEXT: size_t v1 = 0;
31-
// CHECK-NEXT: float[1] v2 = input_tensor;
32-
// CHECK-NEXT: float[1] v3 = some_feature;
21+
// CHECK-NEXT: float[1] v2 = fieldName0;
22+
// CHECK-NEXT: float[1] v3 = fieldName1;
3323
// CHECK-NEXT: return;
3424
// CHECK-NEXT: }
3525
// CHECK-EMPTY:

0 commit comments

Comments
 (0)