Add View/StreamLayout Operation #2342

Open · nsmithtt wants to merge 2 commits into main
Conversation

nsmithtt (Contributor) commented Mar 2, 2025

This change adds two new TTIR layout-related ops and refactors existing code to better share common interface and verifier logic between them. The verifiers are also significantly improved and now check for many more illegal cases.

## StreamLayout Operation

The StreamLayout operation is similar to the ToLayout operation, but it is not eagerly evaluated; instead it serves as a means for defining a stream. Its primary use cases are streaming a large tensor out of DRAM through a small L1 buffer, and forming reduce or gather multicast operations. A stream definition includes:

- The tensor to be streamed.
- The storage buffer to be used for streaming.
- Backing memory for a list of DMA transactions to be filled in by the backend.
- A result, which is also able to take a view over the input, i.e. the same semantics as the ViewLayout op.

Additional constraints:

- It is not capable of changing the data type nor the memory space of the tensor.

```llvm
%alloc = memref.alloc() {alignment = 64 : i64} : memref<2x4x4x6x!tt.tile<32x32, f32>, #l1_>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x4x1x1x!tt.tile<32x32, f32>, #l1_>
%stream = "ttir.stream_layout"(%arg0, %alloc_0) : (memref<2x4x4x6x!tt.tile<32x32, f32>, #l1_>, memref<2x4x1x1x!tt.tile<32x32, f32>, #l1_>) -> memref<2x4x4x6x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3)
```

## ViewLayout Operation

The ViewLayout operation is nearly identical to the ToLayout operation, but it is not eagerly evaluated. Its primary use case is to allow reinterpreting the layout of a tensor without actually moving the data.

Additional notes/constraints:

- It is not capable of changing the data type nor the memory space of the tensor.
- All ViewLayout ops can trivially be converted to ToLayout ops (see the sketch after the example below).

```llvm
#layout = #tt.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>>
#layout1 = #tt.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>>
%1 = "ttir.view_layout"(%arg0, %0) : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1>
```

Closes #587

fullMemrefShape.append(gridShape.begin(), gridShape.end());
fullMemrefShape.append(shardShape.begin(), shardShape.end());
return buildMemRef<MemorySpace, MemorySpaceAttr>(
    getContext(), fullMemrefShape, getElementType(), getMemorySpace());

Contributor:

(I am finally trying to get to grips with the full MetalLayoutAttr API.)

This builds a memref buffer whose shape combines the original attr's grid and shard shapes. However, the shard shape will be computed in convert-tile-to-scalar mode, while buildMemRef() will use getElementType() unconditionally rather than getScalarElementType() -- is that (always) correct?

Contributor Author (nsmithtt):

Yes, it's a bit confusing and this actually tripped me up as I was making the change. buildMemRef always expects a scalar shape passed into it, which means we need to expand out any shard shapes that are tilized.
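As a purely illustrative sketch of what that expansion means, assuming 32x32 tiles and a standalone helper that is not part of the actual MetalLayoutAttr API:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not the MetalLayoutAttr API): a shard shape expressed
// in 32x32 tiles, e.g. {4, 6}, becomes the scalar shape {4 * 32, 6 * 32}, so
// that buildMemRef() can be handed a scalar shape alongside the scalar
// element type.
std::vector<int64_t> expandTiledShardToScalars(std::vector<int64_t> shardShape,
                                               int64_t tileH = 32,
                                               int64_t tileW = 32) {
  assert(shardShape.size() >= 2 && "expected at least two shard dimensions");
  shardShape[shardShape.size() - 2] *= tileH; // tile rows -> scalar rows
  shardShape[shardShape.size() - 1] *= tileW; // tile cols -> scalar cols
  return shardShape;
}
```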

void $cppClass::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getDpsEffects(*this, effects);

Contributor:

IMHO, a future reader of this code (who won't have access to this PR's change set) will find it difficult to understand where getDpsEffects() is coming from.

I have a similar utility method that I placed into include/ttmlir/Utils.h and that I invoke via full namespace prefix:

    let extraClassDeclaration = [{
      MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
...

If you think this makes sense, maybe you can do something similar with getDpsEffects(). At the very least, invoke it via ::ttmlir::..., maybe?

Contributor Author (nsmithtt):

I agree. The tricky thing is that it's in the same namespace right here, not in the ::ttmlir namespace. It also doesn't really feel like it belongs in utils, since it's very specific to TTIR ops implementing the memory effects interface. Perhaps I can give the function a more specific name, something that's easier to grep?
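For future readers without the diff, here is a hedged sketch of what a DPS-derived effects helper generally does: read effects for DPS inputs, write effects for DPS inits. The name, placement, and details below are illustrative; the actual helper in this PR may differ:

```cpp
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/SmallVector.h"

// Illustrative only: derive memory effects from DPS structure. Inputs are
// read; DPS inits (the outputs) are written.
inline void getDpsEffects(
    mlir::DestinationStyleOpInterface op,
    llvm::SmallVectorImpl<
        mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
        &effects) {
  for (mlir::OpOperand &operand : op->getOpOperands()) {
    if (op.isDpsInit(&operand))
      effects.emplace_back(mlir::MemoryEffects::Write::get(), &operand);
    else
      effects.emplace_back(mlir::MemoryEffects::Read::get(), &operand);
  }
}
```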

OpOperand &operand, const bufferization::AnalysisState &) {
  bufferization::AliasingValueList result;
  return result;
}

Contributor:

How did you arrive at this choice of BufferizableOpInterface methods to override? I am trying to relate it to the docs and/or LLVM sources and it is not "easy".

Contributor Author (nsmithtt):

I went through the interface tablegen definition: https://github.com/llvm/llvm-project/blob/313b71fc1a9ae17ea5ecba8afcb4e5b80e1f4043/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

And paid special attention to all the ones that have an "llvm_unreachable" implementation:

        /*defaultImplementation=*/[{
          llvm_unreachable("bufferize not implemented");
          return ::mlir::failure();
        }]

That, in combination with running the pass and seeing which ones failed.
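To make the resulting method selection concrete for future readers, here is a hedged sketch of the typical override set, written as an external model for a hypothetical op. The PR may attach the interface differently, and the aliasing and bufferize bodies below are placeholders rather than the actual implementation:

```cpp
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

// Hedged sketch: the methods whose defaults are llvm_unreachable are the ones
// that typically need real implementations.
struct ViewLikeBufferizationModel
    : public mlir::bufferization::BufferizableOpInterface::ExternalModel<
          ViewLikeBufferizationModel, /*hypothetical op=*/ttir::ViewLayoutOp> {
  bool bufferizesToMemoryRead(
      mlir::Operation *, mlir::OpOperand &,
      const mlir::bufferization::AnalysisState &) const {
    return false; // a pure layout view reads no memory by itself
  }
  bool bufferizesToMemoryWrite(
      mlir::Operation *, mlir::OpOperand &,
      const mlir::bufferization::AnalysisState &) const {
    return false; // and writes none either
  }
  mlir::bufferization::AliasingValueList
  getAliasingValues(mlir::Operation *, mlir::OpOperand &,
                    const mlir::bufferization::AnalysisState &) const {
    return {}; // conservative: report no aliasing, as in the hunk above
  }
  mlir::LogicalResult
  bufferize(mlir::Operation *op, mlir::RewriterBase &rewriter,
            const mlir::bufferization::BufferizationOptions &options) const {
    // Placeholder: replace tensor operands/results with memrefs here.
    return mlir::success();
  }
};
```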

```llvm
%alloc = memref.alloc() {alignment = 64 : i64} : memref<2x4x4x6x!tt.tile<32x32, f32>, #l1_>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x4x1x1x!tt.tile<32x32, f32>, #l1_>
%stream = "ttir.stream_layout"(%arg0, %alloc_0) : (memref<2x4x4x6x!tt.tile<32x32, f32>, #l1_>, memref<2x4x1x1x!tt.tile<32x32, f32>, #l1_>) -> memref<2x4x4x6x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3)

Contributor:

  1. I believe you meant (%alloc, %alloc_0)?
  2. I think it would be helpful to finish spelling out that tt.stream result.

nsmithtt (Contributor Author) left a comment:

Thank you for the feedback!

