This repository has been archived by the owner on Apr 23, 2021. It is now read-only.

Simple lowering to LLVM #205

Open
doru1004 opened this issue Oct 21, 2019 · 6 comments

@doru1004

doru1004 commented Oct 21, 2019

In a previous version of the Toy example, the PrintOp was lowered to the LLVM dialect differently. In that version, the inner loop nest was generated as:

LoopBuilder(&i, zero, M, 1)([&]{
  LoopBuilder(&j, zero, N, 1)([&]{
    llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc),
             {extractvalue(i8PtrTy, fmtCst, rewriter.getIndexArrayAttr(0)), iOp(i, j)});
  });
  llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc),
           {extractvalue(i8PtrTy, fmtEol, rewriter.getIndexArrayAttr(0))});
});

This led to the following code being emitted:

  affine.for %arg0 = 0 to 2 {
    affine.for %arg1 = 0 to 2 {
      %11 = llvm.extractvalue %7[0 : index] : memref<4xi8>
      %12 = affine.load %6[%arg0, %arg1] : memref<2x2xf64>
      %13 = llvm.call @printf(%11, %12) : (!llvm<"i8*">, f64) -> !llvm.i32
    }
    %9 = llvm.extractvalue %8[0 : index] : memref<2xi8>
    %10 = llvm.call @printf(%9) : (!llvm<"i8*">) -> !llvm.i32
  }

This code does not appear to be fully correct. When the run() method is invoked, the verifier complains:

loc("example.toy":11:3): error: 'llvm.extractvalue' op operand #0 must be LLVM dialect type, but got 'memref<4xi8>'

I realize this is not a full lowering to the LLVM dialect, but I would expect the code to be correct at every intermediate stage, including this one (similar to LLVM IR code gen principles).

My questions are:

  1. Was passing the memref to extractvalue intentional, or is it a bug?
  2. Should the memref type have been lowered to an LLVM type, or is there an error in the code inside PrintOpConversion::matchAndRewrite() instead?
@joker-eph
Contributor

> I would expect the code to be correct at every intermediate stage, including this one (similar to LLVM IR code gen principles).

The LLVM IR principle (which we follow) is that the IR must be consistent between passes. The IR can't necessarily be in a "correct" state at every point inside a pass (every LLVM pass leaves it temporarily inconsistent while it works). This is why the lowering pattern you are referring to was defined together with the pass that uses it, and that pass also did more work to get all the way down to the LLVM dialect. I don't think any MemRef was left after the pass completes?
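
The rule above can be modeled without any MLIR dependency. The following is a minimal sketch in plain C++, not code from the tutorial: `Op`, `IR`, `verify`, `runPass`, `partialLowering`, and `fullLowering` are hypothetical names, and the boolean flag stands in for real type checking. It only illustrates that verification happens after a pass returns, so a pass may go through invalid states as long as it repairs them before finishing.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical, heavily simplified IR: each op records its name and whether
// its operand types currently satisfy the op's own constraints.
struct Op {
  std::string name;
  bool operandTypesValid;
};

using IR = std::vector<Op>;

// Whole-IR check, analogous in spirit to a verifier.
bool verify(const IR &ir) {
  for (const Op &op : ir)
    if (!op.operandTypesValid)
      return false;
  return true;
}

// Runs a pass and verifies only after it returns, so the pass body is free
// to leave the IR temporarily inconsistent while it works.
bool runPass(IR &ir, const std::function<void(IR &)> &pass) {
  pass(ir);
  return verify(ir);
}

// A lowering that stops halfway: it creates an LLVM op whose operand is
// still a memref. Verification at the pass boundary rejects the result.
void partialLowering(IR &ir) {
  ir.push_back({"llvm.extractvalue", /*operandTypesValid=*/false});
}

// The same lowering followed by a step that repairs the operand types
// before the pass returns, so the boundary check succeeds.
void fullLowering(IR &ir) {
  partialLowering(ir);
  for (Op &op : ir)
    op.operandTypesValid = true;
}
```

In this model, `runPass(ir, partialLowering)` reports failure while `runPass(ir, fullLowering)` succeeds, even though both pass through the same invalid intermediate state.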

To answer your questions:

  1. This was an intentional bug ;-)
    At that time we needed to wrap up the tutorial quickly before the conference, and the conversion framework was still in its early stages: we didn't yet have a good solution for handling type conversions.
    Problems like this one in the tutorial are one of the reasons we rewrote it over the last two months.

  2. For partial lowering we need to explicitly handle type conversions, by inserting explicit cast for instance.

@doru1004
Author

doru1004 commented Oct 23, 2019

Thanks for the answer! :)

My confusion stemmed from the fact that I was expecting each instruction that is emitted to be emitted with a valid type. In this case I was hoping that the llvm.extractvalue instruction will trigger the conversion of its input Memref type to an LLVM type.

> I don't think any MemRef was left after the pass completes?

That's right. The pass was indeed lowering too much. When following the tutorial I kind of took it apart a bit to see what was going on at each intermediate step (even within a pass). I wanted to transform the pass to lower just the print ops to LLVM whilst keeping everything else in place. I tried various combinations but was not able to get valid code at the end because I kept running into type errors like the one above.

> For partial lowering we need to explicitly handle type conversions, by inserting explicit cast for instance

Agreed.

Update:

So I tried the new tutorial with the aim to modify it to only lower the printf function to LLVM and none of the other code in the module. To this effect I removed all other lowering patterns except the PrintOpLowering one. I also modified the conversion from Full to Partial:

void ToyToLLVMLoweringPass::runOnModule() {
  ConversionTarget target(getContext());
  target.addLegalDialect<AffineOpsDialect, StandardOpsDialect, LLVM::LLVMDialect>();
  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
  LLVMTypeConverter typeConverter(&getContext());
  OwningRewritePatternList patterns;
  patterns.insert<PrintOpLowering>(&getContext());
  if (failed(applyPartialConversion(module, target, patterns, &typeConverter)))
    signalPassFailure();
}

I also made some modifications to the matchAndRewrite() function of the PrintOpLowering class to account for this change.

In the end, although the matchAndRewrite function completes successfully I am seeing no changes happening to the module code. The code looks exactly like the code I started with as if the pattern was not applied. No error is thrown in the process so I have no idea what it is that I'm missing.

@River707
Contributor

River707 commented Oct 23, 2019

> Thanks for the answer! :)
>
> My confusion stemmed from the fact that I was expecting each instruction that is emitted to be emitted with a valid type. In this case I was hoping that the llvm.extractvalue instruction will trigger the conversion of its input Memref type to an LLVM type.

The conversion operates on operations topologically, meaning that the inputs to an operation should have already been converted before the operation is encountered.

> > I don't think any MemRef was left after the pass completes?
>
> That's right. The pass was indeed lowering too much. When following the tutorial I kind of took it apart a bit to see what was going on at each intermediate step (even within a pass). I wanted to transform the pass to lower just the print ops to LLVM whilst keeping everything else in place. I tried various combinations but was not able to get valid code at the end because I kept running into type errors like the one above.

Yes, if the input operations haven't been converted, you will need to handle them explicitly yourself.

> > For partial lowering we need to explicitly handle type conversions, by inserting explicit cast for instance
>
> Agreed.
>
> Update:
>
> So I tried the new tutorial with the aim to modify it to only lower the printf function to LLVM and none of the other code in the module. To this effect I removed all other lowering patterns except the PrintOpLowering one. I also modified the conversion from Full to Partial:
>
> void ToyToLLVMLoweringPass::runOnModule() {
>   ConversionTarget target(getContext());
>   target.addLegalDialect<AffineOpsDialect, StandardOpsDialect, LLVM::LLVMDialect>();
>   target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
>   LLVMTypeConverter typeConverter(&getContext());
>   OwningRewritePatternList patterns;
>   patterns.insert<PrintOpLowering>(&getContext());
>   if (failed(applyPartialConversion(module, target, patterns, &typeConverter)))
>     signalPassFailure();
> }
>
> I also made some modifications to the matchAndRewrite() function of the PrintOpLowering class to account for this change.
>
> In the end, although the matchAndRewrite function completes successfully I am seeing no changes happening to the module code. The code looks exactly like the code I started with as if the pattern was not applied. No error is thrown in the process so I have no idea what it is that I'm missing.

Without seeing the modifications, it is difficult to know what changes you have made and why it specifically isn't working. For example, are you inserting a cast to convert from MemRef to the corresponding LLVMType? Is that operation marked legal? The debugging experience definitely needs to be improved, and we have some ideas for that, but you can mark the PrintOp as explicitly illegal to get the conversion to fail (addIllegalOp<PrintOp>()). You can also use LLVM's debug infrastructure, via `-debug-only=dialect-conversion`, to get more logging on what is going on.
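
The difference between a pattern that silently does not apply and one that produces an error can be sketched with a toy model. This is plain C++ under stated assumptions, not the MLIR API: `Target` and `partialConversionSucceeds` are hypothetical, ops are reduced to their names, and the model assumes that a partial conversion tolerates ops it could not legalize unless they were explicitly marked illegal.

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a conversion target: an op name is explicitly
// legal, explicitly illegal, or unknown (present in neither set).
struct Target {
  std::set<std::string> legal;
  std::set<std::string> illegal;
};

// Toy model of a partial conversion over a list of op names.
// `converted` holds the names of ops for which a pattern actually applied.
// Returns false only if an explicitly illegal op is left over.
bool partialConversionSucceeds(const std::vector<std::string> &ops,
                               const std::set<std::string> &converted,
                               const Target &target) {
  for (const std::string &op : ops) {
    if (converted.count(op))
      continue;  // Rewritten by a pattern, so it is no longer in the IR.
    if (target.illegal.count(op))
      return false;  // Explicitly illegal and still present: hard failure.
    // Explicitly legal and unknown ops are both allowed to remain.
  }
  return true;
}
```

With `toy.print` in neither set, the function returns true even when no pattern applied, which matches the "no error, no change" symptom. Adding `toy.print` to `illegal` turns the same situation into a failure that can then be investigated.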

@jpienaar
Member

@doru1004 is there perhaps more information you could add here? Perhaps a branch in your fork that you've been making changes on to demonstrate the failure?

@doru1004
Author

@jpienaar I managed to fix the error by converting my tensors into memrefs. The way I did this was by using the toy::CastOp.

Is there a way to do this cast without the toy::CastOp? I used a piece of code which was in the old tutorial:

Value *memRefTypeCast(PatternRewriter &builder, Value *val) {
  if (val->getType().isa<MemRefType>())
    return val;
  auto tensorType = val->getType().dyn_cast<TensorType>();
  if (!tensorType)
    return val;
  return builder.create<toy::CastOp>(val->getLoc(), convertTensorToMemRef(tensorType), val)
      .getResult();
}

to achieve this but now there seems to be a new way to do this type of conversion. How would I go about controlling/triggering this conversion? Converting tensors to memrefs seems like something that I should be able to just invoke since it occurs between existing MLIR types.

@jpienaar
Member

jpienaar commented Dec 2, 2019

Hey, have you tried tensor_load/tensor_store ops here? That seems to do what you want.
