Add basic conversion between ttir and linalg #1558

Status: Open. vwellsTT wants to merge 23 commits into main from vwells/ttir_to_linalg_conversion.
Conversation

@vwellsTT (Contributor) commented Dec 10, 2024

Goal: integrate an end-to-end path to compile and execute specific ops, or sets of ops, on the CPU.

Context:

The entire task will be split into (tentatively) 7 PRs, as follows:

  1. Hoist specific ops into isolated funcs in a separate module
  2. Convert TTIR ops to linalg ops within the module of hoisted funcs
  3. Build a pipeline to lower linalg to llvm from existing conversion passes
  4. Translate LLVM Dialect into a dynamic library for packing into flatbuffer
  5. Generate helper functions so that we can call all of our hoisted funcs with a common signature
  6. Insert TTNN instructions to move operands to host before executing hoisted func, then back to device afterwards
  7. Update the ttir-to-ttnn and ttnn-to-flatbuffer pipelines to use the new passes, generate dylibs and embed them into output flatbuffers, and update the runtime to consume dylibs from flatbuffers

This PR implements the 2nd subtask above: it converts TTIR ops in the "cpu" module into their linalg equivalents. Only a few basic operations are enabled for now, the aim being to get a basic end-to-end conversion working before fleshing out the full conversion. Definitely open to adding other specific ops before merging; I'm not sure which ops would make the most sense.

Example

Input:

module {
  func.func @add(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
    %1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
    return %1 : tensor<32x32xf32>
  }
}

Output:

module {
  func.func @add(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
    %0 = linalg.add ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>
    return %0 : tensor<32x32xf32>
  }
}
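(Note: assuming the new pass follows the existing convert-*-to-* naming convention in Passes.td, this example should be reproducible via something like ttmlir-opt --convert-ttir-to-linalg; the exact flag name is an assumption here, not confirmed in this thread.)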

@vwellsTT vwellsTT marked this pull request as ready for review December 10, 2024 20:46
@ctodTT (Contributor) commented Dec 12, 2024

Could be wrong, but I don't see any CMakeLists.txt changed here, do these new Conversion/ files get compiled a different way?

@vwellsTT (Contributor, Author) commented Dec 12, 2024

> Could be wrong, but I don't see any CMakeLists.txt changed here, do these new Conversion/ files get compiled a different way?

Let me double-check, there were quite a few interleaved commits I tried to manually split out, so wouldn't surprise me if I missed one. Thanks!

Update: yup, I was missing several commits... hopefully fixed now.

@nsmithtt (Contributor) left a comment:

Minor comments inline, otherwise looks great!

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

// #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@nsmithtt (Contributor) commented:

Stray comment

// Helper func to check which dims need to be broadcast and which need to be
// collapsed. Assumes that inputShape is broadcast-able to targetShape.
static void getDimsToBroadcastAndCollapse(
@vwellsTT (Contributor, Author) commented:
This logic is much more complicated than I originally expected: linalg.broadcast asserts something like "input rank + len(dimsToBroadcastAlong) == output rank", but we want to support torch-like broadcasting, where dims of size 1 are implicitly broadcast. That means we need a tensor.collapse_shape in those cases, which requires hairy logic.

I would appreciate careful review/re-review of this logic, especially this function 🙂
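To make the intended analysis easier to review, here's a minimal sketch of the kind of computation this helper performs. This is an illustration only (hypothetical name; assumes the input shape is broadcastable to the target shape and is not all ones), not the actual code in this PR:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

using namespace llvm;

// Right-align inputShape against targetShape (torch-style broadcasting) and
// produce:
//  - broadcastDims: the target dims linalg.broadcast must create, so that
//    collapsed input rank + broadcastDims.size() == target rank;
//  - collapseGroups: reassociation indices for a tensor.collapse_shape that
//    folds each implicitly-broadcast size-1 input dim into a kept neighbor.
static void getDimsToBroadcastAndCollapseSketch(
    ArrayRef<int64_t> inputShape, ArrayRef<int64_t> targetShape,
    SmallVector<int64_t> &broadcastDims,
    SmallVector<SmallVector<int64_t>> &collapseGroups) {
  const int64_t rankDiff = targetShape.size() - inputShape.size();

  // Missing leading dims are always created by the broadcast.
  for (int64_t i = 0; i < rankDiff; ++i)
    broadcastDims.push_back(i);

  SmallVector<int64_t> keptInputDims; // input dims that survive the collapse
  for (int64_t i = 0, e = inputShape.size(); i < e; ++i) {
    if (inputShape[i] == 1 && targetShape[rankDiff + i] != 1)
      // Implicitly-broadcast size-1 dim: collapse it away, then let
      // linalg.broadcast re-create the full-size target dimension.
      broadcastDims.push_back(rankDiff + i);
    else
      keptInputDims.push_back(i);
  }

  // Fold each dropped size-1 dim into the group of the next kept dim;
  // trailing dropped dims join the last group.
  int64_t next = 0;
  for (int64_t kept : keptInputDims) {
    SmallVector<int64_t> group;
    while (next <= kept)
      group.push_back(next++);
    collapseGroups.push_back(std::move(group));
  }
  while (next < static_cast<int64_t>(inputShape.size()))
    collapseGroups.back().push_back(next++);
}

For example, input tensor<1x16x1x32> against target tensor<8x16x4x32> yields collapseGroups [[0,1],[2,3]] (collapse to tensor<16x32>) and broadcastDims [0,2], so rank 2 + 2 broadcast dims == target rank 4, satisfying the linalg.broadcast assertion.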


using TensorRanks = SmallVector<int64_t, 2>;

static LogicalResult computeBroadcastedShape(SmallVector<Value, 3> inputs,
@mtopalovicTT (Contributor) commented Dec 21, 2024:

Use the OpTrait::util::getBroadcastedShape util function to calculate the broadcasted shape; it will save you all the work you did here.
Also, you don't need to check whether the operands are broadcastable, since we already check this via traits for all eltwise operations (see verifyBroadcastable).

@vwellsTT (Contributor, Author) replied:

Makes sense, I've switched to this func. Let me know if you think I should replace this if-check with some sort of assert instead, given that the checks must already have happened at the TTIR level.
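For reference, a minimal sketch of what the switch could look like, leaning on the upstream helper; the signature and casting details here are assumptions for illustration, not the PR's actual code:

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Fold each operand's shape into a running result using the upstream
// broadcasting util instead of hand-rolled shape logic. Assumes all inputs
// are ranked tensors.
static LogicalResult
computeBroadcastedShapeSketch(llvm::ArrayRef<Value> inputs,
                              llvm::SmallVector<int64_t> &resultShape) {
  for (Value input : inputs) {
    auto type = cast<RankedTensorType>(input.getType());
    llvm::SmallVector<int64_t> next;
    // getBroadcastedShape returns false if the shapes are incompatible;
    // should be unreachable if TTIR verifiers already enforce
    // broadcastability for eltwise ops.
    if (!OpTrait::util::getBroadcastedShape(resultShape, type.getShape(),
                                            next))
      return failure();
    resultShape = std::move(next);
  }
  return success();
}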
