Add linalg -> mlir::LLVM Dialect conversion pass #1559

Merged
merged 6 commits into from
Dec 19, 2024

Contributor

@vwellsTT vwellsTT commented Dec 10, 2024

Goal: Integrate an end-to-end path to compile and execute specific ops, or sets of ops, on the CPU.

Context:

The entire task will be split into (tentatively) 7 PRs, as follows:

  1. Hoist specific ops into isolated funcs in a separate module
  2. Convert TTIR ops to linalg ops within the module of hoisted funcs
  3. Build a pipeline to lower linalg to llvm from existing conversion passes
  4. Translate LLVM Dialect into a dynamic library for packing into flatbuffer
  5. Generate helper functions so that we can call all of our hoisted funcs with a common signature
  6. Insert TTNN instructions to move operands to host before executing hoisted func, then back to device afterwards
  7. Update the ttir-to-ttnn and ttnn-to-flatbuffer pipelines to use the new passes, generate dylibs, and embed them into output flatbuffers; update the runtime to consume dylibs from flatbuffers

This PR is the 3rd of the PRs listed above. It builds a pipeline from existing passes that lowers the linalg dialect to the LLVM dialect, so that later stages can compile the result into an executable .so.
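
As a rough illustration of what a pipeline built from existing passes can look like, the sketch below assembles an `mlir-opt` command line from upstream MLIR conversion passes. The pass selection and ordering are assumptions made for illustration and are not necessarily the pipeline this PR registers; `build_mlir_opt_command` and `LINALG_TO_LLVM_PASSES` are hypothetical names.

```python
# Hypothetical sketch: upstream MLIR passes that can take linalg-on-tensors
# down to the LLVM dialect. Selection and order are assumptions, not
# necessarily the pipeline added in this PR.
LINALG_TO_LLVM_PASSES = [
    "one-shot-bufferize=bufferize-function-boundaries=1",  # tensors -> memrefs
    "convert-linalg-to-loops",     # linalg ops -> scf loop nests
    "convert-scf-to-cf",           # structured loops -> branches between blocks
    "convert-cf-to-llvm",          # branches -> llvm.br / llvm.cond_br
    "convert-arith-to-llvm",       # arithmetic -> LLVM dialect ops
    "convert-func-to-llvm",        # func.func -> llvm.func
    "finalize-memref-to-llvm",     # memref ops -> pointer arithmetic
    "reconcile-unrealized-casts",  # drop leftover conversion casts
]


def build_mlir_opt_command(input_path, passes=LINALG_TO_LLVM_PASSES):
    """Return an argv list that runs the given passes in order with mlir-opt."""
    return ["mlir-opt", input_path] + ["--" + name for name in passes]


if __name__ == "__main__":
    # Print the command instead of running it, so no MLIR install is needed.
    print(" ".join(build_mlir_opt_command("add.mlir")))
```

The command is only printed here; running it would require an `mlir-opt` build on the PATH.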

Example

Input:

module {
  func.func @add(
    %arg0: tensor<32x32xf32>,  // First input tensor
    %arg1: tensor<32x32xf32>,  // Second input tensor
    %arg2: tensor<32x32xf32>   // Output tensor (result stored here)
  ) -> tensor<32x32xf32> {
    // Perform linalg.add and store the result in %arg2
    %1 = linalg.add ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>
    return %1 : tensor<32x32xf32>
  }
}

Output:

module {
  llvm.func @memrefCopy(i64, !llvm.ptr, !llvm.ptr)
  llvm.func @malloc(i64) -> !llvm.ptr
  llvm.func @add(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: i64, %arg14: !llvm.ptr, %arg15: !llvm.ptr, %arg16: i64, %arg17: i64, %arg18: i64, %arg19: i64, %arg20: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
    %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %1 = llvm.mlir.zero : !llvm.ptr
    %2 = llvm.mlir.constant(2 : i64) : i64
    %3 = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
    %4 = llvm.mlir.constant(1 : index) : i64
    %5 = llvm.mlir.constant(32 : index) : i64
    %6 = llvm.mlir.constant(0 : index) : i64
    %7 = llvm.insertvalue %arg14, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %8 = llvm.insertvalue %arg15, %7[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %9 = llvm.insertvalue %arg16, %8[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %10 = llvm.insertvalue %arg17, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %11 = llvm.insertvalue %arg19, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %12 = llvm.insertvalue %arg18, %11[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %13 = llvm.insertvalue %arg20, %12[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    llvm.br ^bb1(%6 : i64)
  ^bb1(%14: i64):  // 2 preds: ^bb0, ^bb5
    %15 = llvm.icmp "slt" %14, %5 : i64
    llvm.cond_br %15, ^bb2, ^bb6
  ^bb2:  // pred: ^bb1
    llvm.br ^bb3(%6 : i64)
  ^bb3(%16: i64):  // 2 preds: ^bb2, ^bb4
    %17 = llvm.icmp "slt" %16, %5 : i64
    llvm.cond_br %17, ^bb4, ^bb5
  ^bb4:  // pred: ^bb3
    %18 = llvm.getelementptr %arg1[%arg2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %19 = llvm.mul %14, %arg5 : i64
    %20 = llvm.mul %16, %arg6 : i64
    %21 = llvm.add %19, %20 : i64
    %22 = llvm.getelementptr %18[%21] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %23 = llvm.load %22 : !llvm.ptr -> f32
    %24 = llvm.getelementptr %arg8[%arg9] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %25 = llvm.mul %14, %arg12 : i64
    %26 = llvm.mul %16, %arg13 : i64
    %27 = llvm.add %25, %26 : i64
    %28 = llvm.getelementptr %24[%27] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %29 = llvm.load %28 : !llvm.ptr -> f32
    %30 = llvm.fadd %23, %29  : f32
    %31 = llvm.getelementptr %arg15[%arg16] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %32 = llvm.mul %14, %arg19 : i64
    %33 = llvm.mul %16, %arg20 : i64
    %34 = llvm.add %32, %33 : i64
    %35 = llvm.getelementptr %31[%34] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    llvm.store %30, %35 : f32, !llvm.ptr
    %36 = llvm.add %16, %4 : i64
    llvm.br ^bb3(%36 : i64)
  ^bb5:  // pred: ^bb3
    %37 = llvm.add %14, %4 : i64
    llvm.br ^bb1(%37 : i64)
  ^bb6:  // pred: ^bb1
    %38 = llvm.getelementptr %1[1024] : (!llvm.ptr) -> !llvm.ptr, f32
    %39 = llvm.ptrtoint %38 : !llvm.ptr to i64
    %40 = llvm.call @malloc(%39) : (i64) -> !llvm.ptr
    %41 = llvm.insertvalue %40, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %42 = llvm.insertvalue %40, %41[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %43 = llvm.insertvalue %6, %42[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %44 = llvm.insertvalue %5, %43[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %45 = llvm.insertvalue %5, %44[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %46 = llvm.insertvalue %5, %45[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %47 = llvm.insertvalue %4, %46[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %48 = llvm.intr.stacksave : !llvm.ptr
    %49 = llvm.alloca %4 x !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> : (i64) -> !llvm.ptr
    llvm.store %13, %49 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>, !llvm.ptr
    %50 = llvm.insertvalue %2, %3[0] : !llvm.struct<(i64, ptr)> 
    %51 = llvm.insertvalue %49, %50[1] : !llvm.struct<(i64, ptr)> 
    %52 = llvm.alloca %4 x !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> : (i64) -> !llvm.ptr
    llvm.store %47, %52 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>, !llvm.ptr
    %53 = llvm.insertvalue %52, %50[1] : !llvm.struct<(i64, ptr)> 
    %54 = llvm.alloca %4 x !llvm.struct<(i64, ptr)> : (i64) -> !llvm.ptr
    llvm.store %51, %54 : !llvm.struct<(i64, ptr)>, !llvm.ptr
    %55 = llvm.alloca %4 x !llvm.struct<(i64, ptr)> : (i64) -> !llvm.ptr
    llvm.store %53, %55 : !llvm.struct<(i64, ptr)>, !llvm.ptr
    %56 = llvm.getelementptr %1[1] : (!llvm.ptr) -> !llvm.ptr, f32
    %57 = llvm.ptrtoint %56 : !llvm.ptr to i64
    llvm.call @memrefCopy(%57, %54, %55) : (i64, !llvm.ptr, !llvm.ptr) -> ()
    llvm.intr.stackrestore %48 : !llvm.ptr
    llvm.return %47 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  }
}
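
To make the lowered output easier to read, here is a small Python model of what it computes. It is a sketch for illustration, not code from this PR: `MemRef2D`, `contiguous`, `flattened_arg_count`, and `result_alloc_bytes` are hypothetical names. Each 2D memref is passed as 7 scalar arguments (allocated pointer, aligned pointer, offset, two sizes, two strides), which is why `@add` takes `%arg0` through `%arg20`, and each element is addressed as `offset + i * stride0 + j * stride1`.

```python
from dataclasses import dataclass
from typing import List, Tuple

F32_BYTES = 4  # size of one f32 element


@dataclass
class MemRef2D:
    """Stand-in for struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>."""
    data: List[float]         # plays the role of the aligned pointer
    offset: int
    sizes: Tuple[int, int]
    strides: Tuple[int, int]

    def index(self, i: int, j: int) -> int:
        # Same arithmetic as the llvm.mul / llvm.add / getelementptr chain in ^bb4.
        return self.offset + i * self.strides[0] + j * self.strides[1]


def contiguous(rows: int, cols: int, fill: float) -> MemRef2D:
    """Row-major buffer with strides (cols, 1) and offset 0."""
    return MemRef2D([fill] * (rows * cols), 0, (rows, cols), (cols, 1))


def flattened_arg_count(rank: int) -> int:
    """Two pointers, one offset, then `rank` sizes and `rank` strides."""
    return 2 + 1 + 2 * rank


def result_alloc_bytes(sizes: Tuple[int, int]) -> int:
    """Byte count passed to malloc: element count times element size."""
    return sizes[0] * sizes[1] * F32_BYTES


def add(lhs: MemRef2D, rhs: MemRef2D, out: MemRef2D) -> MemRef2D:
    # The IR uses the constant bound 32; the model reads it from the sizes.
    for i in range(out.sizes[0]):        # ^bb1: outer loop
        for j in range(out.sizes[1]):    # ^bb3: inner loop
            out.data[out.index(i, j)] = (
                lhs.data[lhs.index(i, j)] + rhs.data[rhs.index(i, j)]
            )                            # ^bb4: load, load, fadd, store
    return out


if __name__ == "__main__":
    result = add(contiguous(32, 32, 1.0), contiguous(32, 32, 2.0), contiguous(32, 32, 0.0))
    print(result.data[:4], 3 * flattened_arg_count(2), result_alloc_bytes((32, 32)))
```

The model leaves out the final step in `^bb6`, where the lowered code allocates a fresh 32x32 buffer with `malloc` and calls `memrefCopy` before returning a descriptor.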

@vwellsTT vwellsTT force-pushed the vwells/ttir_to_linalg_conversion branch from f36bad4 to 5272da5 Compare December 10, 2024 20:48
@vwellsTT vwellsTT force-pushed the vwells/linalg_to_llvm_conversion branch from 23ec6c4 to 49eefb0 Compare December 10, 2024 20:53
@vwellsTT vwellsTT marked this pull request as ready for review December 10, 2024 20:53
@nsmithtt
Contributor

vwellsTT wants to merge 3 commits into vwells/ttir_to_linalg_conversion from vwells/linalg_to_llvm_conversion

@vwellsTT I think this PR needs to be opened against main if you want CI to run.

@vwellsTT
Contributor Author

@vwellsTT I think this PR needs to be opened against main if you want CI to run.

Hmm, imo that makes the diff confusing when I open 3 subsequent PRs like this. I guess I could cherry-pick this onto main, and in this case perhaps it doesn't actually depend on the previous branch... but personally I'd prefer to keep it as is and wait for the previous branch to merge before running CI.

@vwellsTT vwellsTT force-pushed the vwells/ttir_to_linalg_conversion branch from 5272da5 to faf1bb4 Compare December 10, 2024 22:42
@vwellsTT vwellsTT force-pushed the vwells/linalg_to_llvm_conversion branch from c204ee8 to 8550c24 Compare December 10, 2024 22:44
@nsmithtt
Contributor

nsmithtt commented Dec 10, 2024

So, we could do a preemptive soft review, but there are no processes/requirements for reviewing PRs that are not going into main. That is, we'll have to review this PR again (officially) once it is requested against main.

To limit review churn, it's probably best to open any PR that can be rebased on main directly against main. Changes that depend on each other will have to be opened serially.

@vwellsTT
Contributor Author

vwellsTT commented Dec 10, 2024 via email

@vwellsTT vwellsTT force-pushed the vwells/linalg_to_llvm_conversion branch from 43a8c7e to f38eccc Compare December 11, 2024 14:56
@vwellsTT vwellsTT changed the base branch from vwells/ttir_to_linalg_conversion to main December 11, 2024 14:56
@nsmithtt
Contributor

@vwellsTT, this looks great, couple of things we should add:

  • Let's add more context to the PR description. Let's call out that this is part 2 of ?? parts with some additional sentences describing that this is part of CPU fallback path. It's probably worth having a top level statement that outlines the whole plan broken into parts and then stating this PR tackles part 2 of above plan.
  • We should add a dialect conversion test that proves TTIR -> Linalg for the add op

@vwellsTT
Contributor Author

Sounds good; I'll make a quick template for all of these PRs like you describe. For the test, I may have some follow-up questions, but seems like a good idea

// TODO: we might want some more options to say lower through affine loops
// instead of scf directly, etc. which could be new options
Option<bool> cleanupOutputEnabled{
*this, "enable-remove-dead-values",
Contributor

Can we rename "enable-remove-dead-values" to something more meaningful? As I understand it, this option enables canonicalize, SCCP, CSE, and SymbolDCE, not just the remove-dead-values pass...

@vwellsTT vwellsTT force-pushed the vwells/linalg_to_llvm_conversion branch 3 times, most recently from 723f79a to 47f7be7 Compare December 16, 2024 22:24
@vwellsTT vwellsTT force-pushed the vwells/linalg_to_llvm_conversion branch from 47f7be7 to bc75536 Compare December 16, 2024 22:28
Contributor

@nsmithtt nsmithtt left a comment

Minor comments to address inline, otherwise looks great, thanks!

// TODO (vwells): we might want some more options to say lower through affine
// loops instead of scf directly, etc. which could be new options
Option<bool> cleanupOutputEnabled{
*this, "enable-post-pipeline-cleanup",
Contributor

I somewhat preferred the remove dead values name since that's effectively what most of these passes are doing. That said, perhaps the name enable-optimization-passes might be more suitable? This current name is a bit vague.

Contributor Author

Yeah, I personally agree that remove-dead-values seemed like a reasonable summary, but don't feel super strongly. I think enable-optimization-passes is fine too, I'll switch to that
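
For readers following the naming discussion, the sketch below shows the behavior being named: one boolean option that appends a group of cleanup passes after the conversion passes. It is a hypothetical Python model, not the C++ from this PR. The cleanup list uses upstream MLIR pass names that correspond to the passes mentioned in the review (reading "SCC" as `sccp` is an assumption, as is the order), and `with_optional_cleanup` is an invented helper name.

```python
# Hypothetical model of the "enable-optimization-passes" option.
# Pass names are upstream MLIR passes; their order here is an assumption.
CLEANUP_PASSES = [
    "canonicalize",
    "sccp",
    "cse",
    "symbol-dce",
    "remove-dead-values",
]


def with_optional_cleanup(conversion_passes, enable_optimization_passes):
    """Return the conversion passes, followed by cleanup passes if enabled."""
    passes = list(conversion_passes)  # copy so the caller's list is untouched
    if enable_optimization_passes:
        passes.extend(CLEANUP_PASSES)
    return passes


if __name__ == "__main__":
    base = ["convert-linalg-to-loops", "convert-func-to-llvm"]
    print(with_optional_cleanup(base, enable_optimization_passes=True))
    print(with_optional_cleanup(base, enable_optimization_passes=False))
```

Since the flag controls several passes at once, a name describing the group fits better than the name of any single pass, which is the point raised in the review.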

@vwellsTT vwellsTT force-pushed the vwells/linalg_to_llvm_conversion branch from eb8168d to 191fea2 Compare December 18, 2024 18:42
@vwellsTT vwellsTT force-pushed the vwells/linalg_to_llvm_conversion branch 2 times, most recently from 277aa4e to 0459bda Compare December 19, 2024 17:51
@vwellsTT vwellsTT force-pushed the vwells/linalg_to_llvm_conversion branch from 0459bda to 8e69e5a Compare December 19, 2024 17:53
@vwellsTT vwellsTT enabled auto-merge (squash) December 19, 2024 17:55
@vwellsTT vwellsTT merged commit e748e71 into main Dec 19, 2024
21 checks passed
@vwellsTT vwellsTT deleted the vwells/linalg_to_llvm_conversion branch December 19, 2024 18:39
jdesousa-TT added a commit that referenced this pull request Dec 20, 2024
4 participants