-
Notifications
You must be signed in to change notification settings - Fork 2.9k
[TritonGPU] Preserve memdesc_reshape encoding on type propagation #9973
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1456,9 +1456,27 @@ void replaceUsesAndPropagateType( | |
| newVal = ttg::MemDescTransOp::create(builder, trans.getLoc(), val, | ||
| trans.getOrder()); | ||
| } else if (auto reshape = dyn_cast<ttg::MemDescReshapeOp>(user)) { | ||
| auto shape = reshape.getType().getShape(); | ||
| newVal = | ||
| ttg::MemDescReshapeOp::create(builder, reshape.getLoc(), val, shape); | ||
| // Use inferReturnTypes to compute the correct allocShape and mutability | ||
| // from the new source, but preserve the original reshape's encoding | ||
| // rather than re-inferring it (which can change e.g. nvmma_shared to | ||
| // shared_linear). | ||
| ttg::MemDescType inferredType; | ||
| LogicalResult result = ttg::MemDescReshapeOp::inferReturnTypes( | ||
| builder.getContext(), reshape.getLoc(), | ||
| cast<ttg::MemDescType>(val.getType()), | ||
| reshape.getType().getShape(), inferredType); | ||
| assert(succeeded(result) && "failed to infer reshape return type"); | ||
| assert(ttg::areLayoutsEquivalent( | ||
| inferredType.getShape(), | ||
| cast<ttg::LayoutEncodingTrait>(reshape.getType().getEncoding()), | ||
| cast<ttg::LayoutEncodingTrait>(inferredType.getEncoding())) && | ||
| "preserved encoding is not equivalent to inferred encoding"); | ||
| Type newDstType = ttg::MemDescType::get( | ||
| inferredType.getShape(), inferredType.getElementType(), | ||
| reshape.getType().getEncoding(), inferredType.getMemorySpace(), | ||
| inferredType.getMutableMemory(), inferredType.getAllocShape()); | ||
|
Comment on lines
+1474
to
+1477
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also unnecessary as you are creating the same type we already had?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We want to use the explicit type overload |
||
| newVal = ttg::MemDescReshapeOp::create(builder, reshape.getLoc(), | ||
| newDstType, val); | ||
| } | ||
| assert(newVal && "unhandled memdesc view"); | ||
| newVal.getDefiningOp()->setAttrs(user->getAttrs()); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All this is unnecessary. You can asser tthat the initial types are the same and that's it.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need the
inferReturnTypesto re-compute theallocShapefield which changes with aref insertion, while still keep the#ttg.nvvma_shared. The issue is that prior to this fix, the shape-only overloadMemDescReshapeOp::create(builder, loc, val, shape)callsinferReturnTypesand it falls back the encoding to#ttg.shared_linear, which downstream ops such as TMA ops would reject because they check for NVMMASharedEncoding, even though verifier ofMemDescReshapeOpalready checks the inferred#ttg.shared_linearand#ttg.nvmma_sharedare layout-equivalent.