Skip to content
Open
252 changes: 181 additions & 71 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
LogicalResult
matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

auto memRefType = op.getType();
if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform alloca with dynamic shape");
Expand All @@ -80,12 +80,47 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
op.getLoc(), "cannot transform alloca with alignment requirement");
}

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
if (op.getType().getRank() == 0 ||
llvm::is_contained(memRefType.getShape(), 0)) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform alloca with rank 0 or zero-sized dim");
}

auto convertedTy = getTypeConverter()->convertType(memRefType);
if (!convertedTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert memref type");
}

auto arrayTy = emitc::ArrayType::get(memRefType.getShape(),
memRefType.getElementType());
auto elemTy = memRefType.getElementType();

auto noInit = emitc::OpaqueAttr::get(getContext(), "");
rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
auto arrayVar =
emitc::VariableOp::create(rewriter, op.getLoc(), arrayTy, noInit);

// Build zero indices for the base subscript.
SmallVector<Value> indices;
for (unsigned i = 0; i < memRefType.getRank(); ++i) {
auto zero = emitc::ConstantOp::create(rewriter,
op.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(0));
indices.push_back(zero);
}

auto current = emitc::SubscriptOp::create(rewriter,
op.getLoc(), emitc::LValueType::get(elemTy), arrayVar.getResult(),
indices);

auto ptrElemTy = emitc::PointerType::get(elemTy);
auto addrOf = emitc::AddressOfOp::create(rewriter, op.getLoc(), ptrElemTy,
current.getResult());

auto ptrArrayTy = emitc::PointerType::get(arrayTy);
auto casted = emitc::CastOp::create(rewriter,op.getLoc(), ptrArrayTy,
addrOf.getResult());

rewriter.replaceOp(op, casted.getResult());
return success();
}
};
Expand Down Expand Up @@ -122,24 +157,6 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
return totalSizeBytes.getResult();
}

static emitc::AddressOfOp
createPointerFromEmitcArray(Location loc, OpBuilder &builder,
TypedValue<emitc::ArrayType> arrayValue) {

emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
builder, loc, builder.getIndexType(), builder.getIndexAttr(0));

emitc::ArrayType arrayType = arrayValue.getType();
llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
emitc::SubscriptOp subPtr =
emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
emitc::AddressOfOp ptr = emitc::AddressOfOp::create(
builder, loc, emitc::PointerType::get(arrayType.getElementType()),
subPtr);

return ptr;
}

struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
Expand Down Expand Up @@ -194,8 +211,9 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
emitc::PointerType::get(
emitc::OpaqueType::get(rewriter.getContext(), "void")),
allocFunctionName, args);

emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
emitc::ArrayType arrayType =
emitc::ArrayType::get(memrefType.getShape(), elementType);
emitc::PointerType targetPointerType = emitc::PointerType::get(arrayType);
emitc::CastOp castOp = emitc::CastOp::create(
rewriter, loc, targetPointerType, allocCall.getResult(0));

Expand Down Expand Up @@ -223,20 +241,10 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
return rewriter.notifyMatchFailure(
loc, "incompatible target memref type for EmitC conversion");

auto srcArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getSource());
emitc::AddressOfOp srcPtr =
createPointerFromEmitcArray(loc, rewriter, srcArrayValue);

auto targetArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
emitc::AddressOfOp targetPtr =
createPointerFromEmitcArray(loc, rewriter, targetArrayValue);

emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
rewriter, loc, TypeRange{}, "memcpy",
ValueRange{
targetPtr.getResult(), srcPtr.getResult(),
operands.getTarget(), operands.getSource(),
calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});

rewriter.replaceOp(copyOp, memCpyCall.getResults());
Expand Down Expand Up @@ -264,11 +272,14 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
"currently not supported");
}

Type resultTy = convertMemRefType(opTy, getTypeConverter());

if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
Type elemTy = getTypeConverter()->convertType(opTy.getElementType());
Type globalType;
if (opTy.getRank() == 0) {
globalType = elemTy;
} else {
SmallVector<int64_t> shape(opTy.getShape().begin(),
opTy.getShape().end());
globalType = emitc::ArrayType::get(shape, elemTy);
}

SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
Expand All @@ -292,7 +303,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
initialValue = {};

rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
op, operands.getSymName(), resultTy, initialValue, externSpecifier,
op, operands.getSymName(), globalType, initialValue, externSpecifier,
staticSpecifier, operands.getConstant());
return success();
}
Expand All @@ -307,48 +318,147 @@ struct ConvertGetGlobal final
ConversionPatternRewriter &rewriter) const override {

MemRefType opTy = op.getType();
Location loc = op.getLoc();

Type elemTy = getTypeConverter()->convertType(opTy.getElementType());
if (!elemTy)
return rewriter.notifyMatchFailure(loc, "cannot convert element type");

Type resultTy = convertMemRefType(opTy, getTypeConverter());

if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}
Type globalType;
if (opTy.getRank() == 0) {
globalType = elemTy;
} else {
SmallVector<int64_t> shape(opTy.getShape().begin(),
opTy.getShape().end());
globalType = emitc::ArrayType::get(shape, elemTy);
}
Comment on lines +333 to +340
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is this formatted correctly? seems like the indentation is off...


if (opTy.getRank() == 0) {
emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
emitc::LValueType lvalueType = emitc::LValueType::get(globalType);
emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
rewriter.replaceOpWithNewOp<emitc::AddressOfOp>(op, pointerType,
globalLValue);
emitc::PointerType pointerType = emitc::PointerType::get(globalType);
auto addrOf = emitc::AddressOfOp::create(rewriter,
loc, pointerType, globalLValue.getResult());

auto arrayTy = emitc::ArrayType::get({1}, globalType);
auto ptrArrayTy = emitc::PointerType::get(arrayTy);
auto casted =
emitc::CastOp::create(rewriter,loc, ptrArrayTy, addrOf.getResult());
rewriter.replaceOp(op, casted.getResult());
return success();
}
rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
operands.getNameAttr());

auto getGlobal = emitc::GetGlobalOp::create(rewriter,
loc, globalType, operands.getNameAttr());

SmallVector<Value> indices;
for (unsigned i = 0; i < opTy.getRank(); ++i) {
auto zero = emitc::ConstantOp::create(rewriter,
loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
indices.push_back(zero);
}

auto current = emitc::SubscriptOp::create(rewriter,
loc, emitc::LValueType::get(elemTy), getGlobal.getResult(), indices);

auto ptrElemTy = emitc::PointerType::get(opTy.getElementType());
auto addrOf = emitc::AddressOfOp::create(rewriter,
loc, ptrElemTy, current.getResult());

auto casted =
emitc::CastOp::create(rewriter, loc, resultTy, addrOf.getResult());

rewriter.replaceOp(op, casted.getResult());
return success();
}
};


// Helper to compute a flattened linear index for multi-dimensional memrefs
// and generate a single subscript access in EmitC.

static Value getFlattenedSubscript(ConversionPatternRewriter &rewriter,
Location loc,
Value memrefVal,
ValueRange indices,
Type elementTy) {
auto module = memrefVal.getDefiningOp() ? memrefVal.getDefiningOp()->getParentOfType<ModuleOp>()
: rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();

// Inject mt_index template once per module to compute flattened indices.
if (module && !module->getAttr("emitc.macros_inserted")) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
// The template is used to avoid emitting repeated
// index arithmetic and keeps the generated C/C++ code readable and reusable.
std::string templateDef =
"\n/* Generalized Indexing Template */\n"
"template <typename T> constexpr T mt_index(T i_last) { return i_last; }\n"
"template <typename T, typename... Args>\n"
"constexpr T mt_index(T idx, T stride, Args... rest) {\n"
" return (idx * stride) + mt_index(rest...);\n"
"}\n";

emitc::VerbatimOp::create(rewriter, loc, rewriter.getStringAttr(templateDef));
module->setAttr("emitc.macros_inserted", rewriter.getUnitAttr());
}

auto ptrTy = cast<emitc::PointerType>(memrefVal.getType());
auto arrayTy = cast<emitc::ArrayType>(ptrTy.getPointee());
ArrayRef<int64_t> shape = arrayTy.getShape();
unsigned rank = indices.size();

// Compute static row-major strides from the array shape.
SmallVector<int64_t> strideValues(rank, 1);
for (int i = (int)rank - 2; i >= 0; --i) {
strideValues[i] = strideValues[i + 1] * shape[i + 1];
}
// build the argument list (index, stride, …) used to invoke it for a given
// memref access.
SmallVector<Value> macroArgs;
for (unsigned i = 0; i < rank; ++i) {
macroArgs.push_back(indices[i]);
if (i < rank - 1) {
auto sVal = emitc::ConstantOp::create(rewriter, loc,
rewriter.getIndexType(),
rewriter.getIndexAttr(strideValues[i]));
macroArgs.push_back(sVal.getResult());
}
}

auto flatIndex = emitc::CallOpaqueOp::create(rewriter, loc,
rewriter.getIndexType(),
"mt_index", macroArgs);

auto elemPtrTy = emitc::PointerType::get(elementTy);
auto flatPtr = emitc::CastOp::create(rewriter, loc, elemPtrTy, memrefVal);
auto lvalueTy = emitc::LValueType::get(elementTy);
auto subscript = emitc::SubscriptOp::create(rewriter, loc,
lvalueTy, flatPtr.getResult(),
flatIndex.getResult(0));

return subscript.getResult();
}
struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
}
if (!resultTy) return failure();

auto arrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto subscript = emitc::SubscriptOp::create(
rewriter, op.getLoc(), arrayValue, operands.getIndices());
Value subscript = getFlattenedSubscript(rewriter, op.getLoc(),
operands.getMemref(),
operands.getIndices(),
resultTy);

rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
return success();
Expand All @@ -361,16 +471,14 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto subscript = emitc::SubscriptOp::create(
rewriter, op.getLoc(), arrayValue, operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
Value valueToStore = operands.getValue();
Type elementTy = valueToStore.getType();
Value subscript = getFlattenedSubscript(rewriter, op.getLoc(),
operands.getMemref(),
operands.getIndices(),
elementTy);
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript, valueToStore);
return success();
}
};
Expand All @@ -386,8 +494,10 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.convertType(memRefType.getElementType());
if (!convertedElementType)
return {};
return emitc::ArrayType::get(memRefType.getShape(),
convertedElementType);
Type innerArrayType =
emitc::ArrayType::get(memRefType.getShape(), convertedElementType);
return emitc::PointerType::get(innerArrayType);

});

auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
Expand Down
Loading