Add View/StreamLayout Operation #2342
Conversation
This change adds two new TTIR layout-related ops and refactors the existing code so that common interface and verifier logic is shared between them. The verifiers are also significantly improved and check for many more illegal cases.

## StreamLayout Operation

The StreamLayout operation is similar to the ToLayout operation, but it is not eagerly evaluated; instead it is used as a means for defining a stream. The primary use cases are streaming a large tensor out of DRAM via a small L1 buffer, and forming reduce or gather multicast operations. A stream definition includes:

- The tensor to be streamed.
- The storage buffer to be used for streaming.
- Backing memory for a list of DMA transactions to be filled in by the backend.
- A result, which is also able to take a view over the input, i.e. the same semantics as the ViewLayout op.

Additional constraints:

- It cannot change the data type or the memory space of the tensor.

```llvm
%alloc = memref.alloc() {alignment = 64 : i64} : memref<2x4x4x6x!tt.tile<32x32, f32>, #l1_>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x4x1x1x!tt.tile<32x32, f32>, #l1_>
%stream = "ttir.stream_layout"(%arg0, %alloc_0) : (memref<2x4x4x6x!tt.tile<32x32, f32>, #l1_>, memref<2x4x1x1x!tt.tile<32x32, f32>, #l1_>) -> memref<2x4x4x6x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3)
```

## ViewLayout Operation

The ViewLayout operation is nearly identical to the ToLayout operation, but it is not eagerly evaluated. Its primary use case is to allow reinterpreting the layout of a tensor without actually moving the data.

Additional notes/constraints:

- It cannot change the data type or the memory space of the tensor.
- All ViewLayout ops can trivially be converted to ToLayout ops.

```llvm
#layout = #tt.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>>
#layout1 = #tt.metal_layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>>
%1 = "ttir.view_layout"(%arg0, %0) : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1>
```

Closes #587
```cpp
  fullMemrefShape.append(gridShape.begin(), gridShape.end());
  fullMemrefShape.append(shardShape.begin(), shardShape.end());
  return buildMemRef<MemorySpace, MemorySpaceAttr>(
      getContext(), fullMemrefShape, getElementType(), getMemorySpace());
```
(I am finally trying to get to grips with the full MetalLayoutAttr API.)
This builds a memref buffer of a shape that combines the original attr's grid and shard shapes. However, the shard shape will be computed in the convert-tile-to-scalar mode while `buildMemRef()` will use `getElementType()` unconditionally and not `getScalarElementType()` -- is that (always) correct?
Yes, it's a bit confusing and this actually tripped me up as I was making the change. `buildMemRef` always expects a scalar shape passed into it, which means we need to expand out any shard shapes that are tilized.
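For illustration, here is a minimal standalone sketch of the expansion being discussed: a shard shape expressed in 32x32 tiles is widened to its scalar-element equivalent before being handed to something like `buildMemRef`. The helper name and free-standing form are assumptions for this example, not the actual `MetalLayoutAttr` API.

```cpp
// Hypothetical sketch: expand a shard shape given in tiles into scalar
// element counts, e.g. a 4x6 shard of 32x32 tiles becomes 128x192 scalars.
// The helper name is made up for illustration; it is not this PR's API.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> expandTiledShardToScalar(std::vector<int64_t> shardShape,
                                              int64_t tileH = 32,
                                              int64_t tileW = 32) {
  assert(shardShape.size() >= 2 && "expect at least 2 tiled dims");
  // Only the innermost two dims are tiled; outer dims are unchanged.
  shardShape[shardShape.size() - 2] *= tileH;
  shardShape[shardShape.size() - 1] *= tileW;
  return shardShape;
}

int main() {
  // A shard of 4x6 tiles of !tt.tile<32x32, f32> expands to a 128x192
  // scalar shape, matching the memref<2x4x4x6x...> example in this PR.
  auto scalar = expandTiledShardToScalar({4, 6});
  assert(scalar[0] == 128 && scalar[1] == 192);
  return 0;
}
```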
```cpp
void $cppClass::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getDpsEffects(*this, effects);
```
IMHO, a future reader of this code (who won't have access to this PR's change set) will find it difficult to understand where `getDpsEffects()` is coming from.
I have a similar utility method that I placed into include/ttmlir/Utils.h and that I invoke via the full namespace prefix:

```tablegen
let extraClassDeclaration = [{
  MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
  ...
```

If you think this makes sense maybe you can do something similar with `getDpsEffects()`. At the very least, invoke it via `::ttmlir::...`, maybe?
I agree, the tricky thing is that it's in the same namespace right here, not in the `::ttmlir` namespace. It also doesn't really feel like it belongs in utils since it's very specific to TTIR ops implementing the memory effects interface. Perhaps I can give the function a more specific name, something that's easier to grep?
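For readers without the change set: a helper like this typically just walks the DPS operands and records reads for inputs and writes for inits. The following is a rough sketch of that shape, not this PR's actual implementation; the function name is hypothetical and the exact `DestinationStyleOpInterface` accessors and `EffectInstance` constructors should be checked against the LLVM version in use.

```cpp
// Sketch only: a DPS-based memory-effects helper, roughly what a function
// named like getDpsEffects might do. Not this PR's implementation.
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

using namespace mlir;

static void populateDpsMemoryEffects(
    DestinationStyleOpInterface op,
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // DPS inputs are only read.
  for (Value input : op.getDpsInputs())
    effects.emplace_back(MemoryEffects::Read::get(), input,
                         SideEffects::DefaultResource::get());
  // DPS inits (outputs) are written; an op that accumulates into them
  // would additionally record a read, which a real helper could refine.
  for (Value init : op.getDpsInits())
    effects.emplace_back(MemoryEffects::Write::get(), init,
                         SideEffects::DefaultResource::get());
}
```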
```cpp
    OpOperand &operand, const bufferization::AnalysisState &) {
  bufferization::AliasingValueList result;
  return result;
}
```
How did you arrive at this choice of `BufferizableOpInterface` methods to override? I am trying to relate it to the docs and/or LLVM sources and it is not "easy".
I went through the interface tablegen definition: https://github.com/llvm/llvm-project/blob/313b71fc1a9ae17ea5ecba8afcb4e5b80e1f4043/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td and paid special attention to all the ones that have an `llvm_unreachable` default implementation:

```tablegen
/*defaultImplementation=*/[{
  llvm_unreachable("bufferize not implemented");
  return ::mlir::failure();
}]
```

That, in combination with running the pass and seeing which ones failed.
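To make the outcome of that exercise concrete, here is a rough sketch of the query hooks a view-like op typically ends up overriding, mirroring the direct-on-op signatures quoted above. The op name `ViewLikeOp` and the return values are assumptions for illustration, not this PR's actual choices, and the required `bufferize()` override is elided.

```cpp
// Rough sketch (hypothetical op name ViewLikeOp): query hooks a view-like op
// commonly overrides once the llvm_unreachable defaults are discovered.
// Bodies are illustrative assumptions, not this PR's code.
bool ViewLikeOp::bufferizesToMemoryRead(
    OpOperand &, const bufferization::AnalysisState &) {
  return false; // a pure view neither reads...
}

bool ViewLikeOp::bufferizesToMemoryWrite(
    OpOperand &, const bufferization::AnalysisState &) {
  return false; // ...nor writes the underlying memory itself
}

bufferization::AliasingValueList ViewLikeOp::getAliasingValues(
    OpOperand &operand, const bufferization::AnalysisState &) {
  // For the viewed operand, the result aliases it; the snippet quoted above
  // returns an empty list, which is also valid for non-aliasing operands.
  bufferization::AliasingValueList result;
  result.addAlias(bufferization::AliasingValue(
      operand.getOwner()->getResult(0),
      bufferization::BufferRelation::Equivalent));
  return result;
}

// bufferize(RewriterBase &, const bufferization::BufferizationOptions &)
// also has an llvm_unreachable default and must be overridden; elided here.
```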
```llvm
%alloc = memref.alloc() {alignment = 64 : i64} : memref<2x4x4x6x!tt.tile<32x32, f32>, #l1_>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x4x1x1x!tt.tile<32x32, f32>, #l1_>
%stream = "ttir.stream_layout"(%arg0, %alloc_0) : (memref<2x4x4x6x!tt.tile<32x32, f32>, #l1_>, memref<2x4x1x1x!tt.tile<32x32, f32>, #l1_>) -> memref<2x4x4x6x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3)
```
- I believe you meant `(%alloc, %alloc_0)`?
- I think it would be helpful to finish spelling out that `tt.stream` result.
Thank you for the feedback!