[TritonGPU] Split MemDescSubview into MemDescIndex and MemDescSubslice #7622
Conversation
The first one will be used just for pipelining and is equivalent to `x[i]`; the second one takes a full slice of constant shape, e.g. `x[:i1, :i2]`.
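The intended semantics of the two new ops can be pictured with ordinary array slicing. The sketch below is a hedged NumPy model, not the actual MLIR operations; the function names and the 3-buffer allocation are illustrative stand-ins.

```python
import numpy as np

# Stand-in for a triple-buffered shared-memory allocation of shape (3, 128, 32).
buffers = np.zeros((3, 128, 32), dtype=np.float16)

def memdesc_index(x, i):
    # Models ttg.memdesc_index: pick one buffer along the leading
    # dimension, dropping that dimension -- like x[i]. Used for pipelining.
    return x[i]

def memdesc_subslice(x, offsets, shape):
    # Models ttg.memdesc_subslice: a full slice of constant shape at a
    # constant offset -- like x[o1:o1+s1, o2:o2+s2].
    slices = tuple(slice(o, o + s) for o, s in zip(offsets, shape))
    return x[slices]

one_buffer = memdesc_index(buffers, 1)                  # shape (128, 32)
tile = memdesc_subslice(one_buffer, (0, 0), (64, 32))   # shape (64, 32)
```

Note how `memdesc_index` reduces the rank by one while `memdesc_subslice` preserves it, which is the key difference between the two ops.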
ThomasRaoux left a comment:
Awesome, great cleanup! A few minor comments.
lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp
test/Analysis/test-alias.mlir
%cst_0 = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
// expected-remark @below {{%2 -> %0}}
- %0 = ttg.memdesc_subview %cst[%idx, %idx] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
+ %0 = ttg.memdesc_subslice %cst {offsets=array<i32: 0, 0>} : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
nit: we could probably have better printing, like `%cst[0, 0]`, but it can be done as a follow-up.
Vibe-coded it with an agent in 5 minutes in c89d9c6.
I didn't even know about that custom API...
bool is1D =
    srcTy.getRank() == 1 && dstTy.getRank() == 1 && dstTy.getDimSize(0) == 1;
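The predicate above can be modeled with plain shape tuples. This is a hypothetical Python sketch; `src_shape`/`dst_shape` are stand-ins for the MLIR memdesc type's rank and dimension-size accessors.

```python
def is_1d_index(src_shape, dst_shape):
    # Mirrors the C++ predicate: both source and destination are rank-1,
    # and the destination has a single element along its only dimension.
    # This is the case hit when indexing a single barrier out of a
    # 1D allocation of barriers during pipelining.
    return len(src_shape) == 1 and len(dst_shape) == 1 and dst_shape[0] == 1

# Indexing one barrier out of a 1D allocation of 4 barriers:
print(is_1d_index((4,), (1,)))
```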
when do we need the 1d case?
when we pipeline barriers and things like that.
In Gluon we do Nx1xi64 to get around having to support this in the APIs. Changing that in the compiler however would mean needing to update a LOT of tests...
After this PR I'm not scared of having to change a lot of tests (in reality it was horrible).
On a different note, this would be a lovely task for an agent.
yeah would be nice to clean up
ThomasRaoux left a comment:
Looks great
#7622 introduced `ttg.memdesc_index` which applies a constant offset to the base pointer of the smem object. For padded layouts we need to add padding based on the offset, similar to what #7404 did for the old subview operation. I also adjusted the lit test to check we actually generate padding from the ttg.memdesc_index. The previous version did not fail because it matched the lowering of the `ttg.local_load/store` as well.
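One way to picture the padding adjustment: a padded layout inserts extra elements at regular strides, so the constant base offset that `ttg.memdesc_index` applies must be shifted past the padding it skips over. The sketch below is an illustrative model under the assumption of a single (interval, pad) pair; the function name and parameters are not the actual compiler API.

```python
def padded_offset(linear_offset, interval, pad):
    # Assume the layout inserts `pad` extra elements after every
    # `interval` elements. An index into the unpadded element space
    # must be shifted by one pad per full interval it crosses.
    return linear_offset + (linear_offset // interval) * pad

# Indexing buffer 2 of a pipelined allocation whose per-buffer size is
# 1024 elements, with 4 elements of padding every 64 elements:
base = 2 * 1024
print(padded_offset(base, interval=64, pad=4))  # prints 2176
```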
third_party/tlx/run_all.sh

[TLX-3.5] Fix memdesc_subview refactoring from triton-lang#7622
  pytest python/test/unit/language/test_tlx.py::test_load_store_smem_with_tl_load
  pytest python/test/unit/language/test_tlx.py::test_local_store
  pytest python/test/unit/language/test_tlx.py::test_local_load
TODO: fix TLX layout propagation LITs using memdesc_subview

[TLX-3.5] Fix barrier ops broken by 1D tensor handling in memdesc_index
  pytest python/test/unit/language/test_tlx.py::test_wait_arrive_non_ws
The root cause is that memdesc_index fails its `verify()` in the 1D tensor case, caused by a bug introduced while merging conflicts. More related discussion: https://github.com/triton-lang/triton/pull/7622/files#r2227788997

[TLX-3.5] Fix all UTs
  pytest python/test/unit/language/test_tlx.py::test_async_dot