Skip to content

[BACKEND] Implement multiCTA support for TMA gather/scatter#9977

Merged
lezcano merged 3 commits into
mainfrom
scatter_gather_multicast
Apr 16, 2026
Merged

[BACKEND] Implement multiCTA support for TMA gather/scatter#9977
lezcano merged 3 commits into
mainfrom
scatter_gather_multicast

Conversation

@lezcano
Copy link
Copy Markdown
Contributor

@lezcano lezcano commented Apr 9, 2026

We also add tighter invariants for gather/scatter ops as well

@lezcano
Copy link
Copy Markdown
Contributor Author

lezcano commented Apr 9, 2026

@peterbell10 can you review this one?

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e63aee2ccb

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
@lezcano lezcano force-pushed the scatter_gather_multicast branch from e63aee2 to 93b0fb1 Compare April 9, 2026 17:33
Comment thread third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
Copy link
Copy Markdown
Collaborator

@Mogball Mogball left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@Mogball
Copy link
Copy Markdown
Collaborator

Mogball commented Apr 10, 2026

I might get a chance to try this out actually

@lezcano lezcano force-pushed the scatter_gather_multicast branch from 9152ee1 to 52744e4 Compare April 13, 2026 09:41
@lezcano lezcano changed the title [BACKEND] Implement multiCTA support for TMA gather [BACKEND] Implement multiCTA support for TMA gather/scatter Apr 13, 2026
Comment thread lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp Outdated
Comment thread lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp Outdated
Comment thread lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp Outdated
Comment thread lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp Outdated
Comment thread test/Conversion/tma_to_llvm.mlir Outdated
Comment thread third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp Outdated
Comment thread python/test/gluon/test_core.py
@lezcano lezcano force-pushed the scatter_gather_multicast branch from 43bbba8 to bfe7657 Compare April 14, 2026 08:28
@lezcano lezcano requested a review from peterbell10 April 14, 2026 08:46
@lezcano
Copy link
Copy Markdown
Contributor Author

lezcano commented Apr 14, 2026

done, added these ops to the comprehensive TMA into mma test.

@lezcano
Copy link
Copy Markdown
Contributor Author

lezcano commented Apr 14, 2026

ugh, the tests are failing. Let me look into those.

@lezcano lezcano force-pushed the scatter_gather_multicast branch from 3967b38 to d55f753 Compare April 14, 2026 15:14
@lezcano
Copy link
Copy Markdown
Contributor Author

lezcano commented Apr 14, 2026

found a real latent multicta issue and fixed it. Thank you for pushing for more comprehensive tests.

Comment thread python/test/gluon/test_core.py Outdated
mbarrier.init(bar, count=1)

gather_offsets = ttgl.load(gather_idx_ptr + ttgl.arange(0, BLOCK_M, layout=x_offsets_layout))
mbarrier.expect(bar, blackwell_tma.nbytes_per_cta_gather(in_desc, gather_offsets))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should have smem.nbytes_per_cta instead?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I was thinking about that the other day. Let me do that.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added, for SharedLinearEncoding we compute a pseudo cga_layout of sorts to divide the shape by it.

We also add tighter invariants for the cga_layout part of all TMA ops
@lezcano lezcano force-pushed the scatter_gather_multicast branch 2 times, most recently from c1148c1 to 6a54bcc Compare April 15, 2026 16:09
lezcano added a commit that referenced this pull request Apr 15, 2026
tbh, I think this is also the way we should represent this in our IR.

I will add an end-to-end test once
#9977 is merged (I will extend
the test in that PR to test this op)
@lezcano lezcano requested a review from peterbell10 April 15, 2026 19:54
@lezcano lezcano force-pushed the scatter_gather_multicast branch from 89e495a to 8523533 Compare April 16, 2026 12:09
Comment thread python/triton/experimental/gluon/language/_core.py Outdated
@lezcano lezcano enabled auto-merge (squash) April 16, 2026 12:38
@lezcano lezcano merged commit eb5efe2 into main Apr 16, 2026
23 of 27 checks passed
@lezcano lezcano deleted the scatter_gather_multicast branch April 16, 2026 17:03
raymondtay pushed a commit to raymondtay/triton that referenced this pull request Apr 18, 2026
tbh, I think this is also the way we should represent this in our IR.

I will add an end-to-end test once
triton-lang#9977 is merged (I will extend
the test in that PR to test this op)
raymondtay pushed a commit to raymondtay/triton that referenced this pull request Apr 18, 2026
…ang#9977)

We also add tighter invariants for gather/scatter ops as well
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants