
Support returning tensors in TritonToTritonGPU #10189

Merged
ThomasRaoux merged 2 commits into
triton-lang:main from
leijurv:support-returning-tensors-in-triton-to-tritongpu
May 12, 2026

Conversation

@leijurv
Contributor

@leijurv leijurv commented Apr 30, 2026

The conversion patterns in TritonToTritonGPU (TritonFuncOpPattern, TritonCallOpPattern, TritonReturnOpPattern) did not handle tt.func ops that return tensors: TritonFuncOpPattern reused the original FunctionType verbatim, and the triton::FuncOp legality check inspected only argument types, so tensor result types came out without any layout encoding. A fork of upstream MLIR's signature-conversion pattern already exists in Dialect/Triton/Transforms/FunctionTypeConversion.h (added because upstream is unaware of tt.call/tt.return); it handles inputs, results, one-to-many conversions, and arg-attribute remapping. This PR therefore drops the three custom patterns and reuses that one. The legality predicate is also extended to require encodings on result tensor types.

Adds a regression test in test/Conversion/triton_to_tritongpu.mlir covering a tt.func returning a tensor and a tt.call site picking up the encoded result type.

Making this PR on behalf of @saagarjha (he is waiting on #8913)

  • I am not making a trivial change, such as fixing a typo in a comment.
  • I have written a PR description following these
    rules.
  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.
  • I have added tests.
  • The lit tests I have added follow these best practices,
    including the "tests should be minimal" section. (Usually running Python code
    and using the instructions it generates is not minimal.)

Co-Authored-By: Saagar Jha <saagar@saagarjha.com>
@leijurv leijurv requested a review from ptillet as a code owner April 30, 2026 23:55
Contributor

@Jokeren Jokeren left a comment


I'm not clear about this PR. Do you need to handle returned tensors in LLVM lowering?

@saagarjha
Contributor

If you don't apply this PR then TritonToTritonGPU will fail, because the returned tensor will not have a layout.

@Jokeren
Contributor

Jokeren commented May 1, 2026

If you don't apply this PR then TritonToTritonGPU will fail, because the returned tensor will not have a layout.

Why do we need to return tensors? Even if we did, I'm not sure if LLVM lowering handles return ops well.

@saagarjha
Contributor

If you make a call it's going to return a tensor? We have some code that does this and it works fine if you fix Triton to handle it. This is one of those fixes.

@ThomasRaoux
Collaborator

Could you make a Python-level test?

@ThomasRaoux
Collaborator

BTW, I thought we had run into that when calling a function returning a tensor with an early return within a loop. I can't find the patch though, so maybe it never landed.

@Jokeren
Contributor

Jokeren commented May 1, 2026

If you make a call it's going to return a tensor? We have some code that does this and it works fine if you fix Triton to handle it. This is one of those fixes.

We need a python test. If a function is not inlined, I'm not confident that triton can return tensors properly.

@leijurv
Contributor Author

leijurv commented May 1, 2026

Python-level test added. The lit test is also improved: it now gives an example of layouts being converted at the return, as was brought up in the previous related PR #7996.

@saagarjha saagarjha force-pushed the support-returning-tensors-in-triton-to-tritongpu branch 2 times, most recently from a0e7d43 to 16def4a Compare May 4, 2026 07:34
Comment thread python/test/unit/language/test_core.py Outdated
Comment on lines +1315 to +1333
@triton.jit
def tensor_return_inner():
    if tl.program_id(0) == 0:
        return tl.arange(0, 16)
    else:
        return tl.arange(0, 16) * 2


@triton.jit
def tensor_return_outer(x, BLOCK: tl.constexpr):
    if tl.program_id(0) < 2:
        tl.atomic_add(x + tl.arange(0, BLOCK), tensor_return_inner())


@pytest.mark.interpreter
def test_return_tensor(device):
    x = torch.zeros(16, device=device)
    tensor_return_outer[(3, )](x, x.numel())
    assert torch.equal(x, torch.arange(x.numel(), device=device) * 3)
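The expected final value of x can be checked with a plain-Python model of the launch (a sketch only, assuming the same grid of 3 program ids; the helper name is hypothetical and no Triton is required):

```python
def model_test_return_tensor(n=16, grid=3):
    # Plain-Python model of the kernel above: each "program id" with
    # pid < 2 adds the inner function's result into x, element-wise.
    x = [0.0] * n
    for pid in range(grid):
        if pid < 2:
            # tensor_return_inner: arange(16) for pid 0, arange(16) * 2 otherwise
            inner = [i if pid == 0 else i * 2 for i in range(n)]
            for i in range(n):
                x[i] += inner[i]
    return x

print(model_test_return_tensor())  # element i ends up as 3 * i
```

This is why the test asserts against torch.arange(x.numel()) * 3: the contributions i and 2 * i from the two active program ids sum to 3 * i.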
Collaborator


Can we use noinline=True instead? I would expect those to be fully inlined, so I don't get what this tests.

Contributor Author


Sure, reverted to that. We thought that noinline=True might be unsupported, so Saagar's approach was to force a function not to inline by having early returns in control flow. It's much more straightforward to test noinline=True directly:

@triton.jit(noinline=True)
def noinline_load_block_fn(ptr, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    return tl.load(ptr + offsets)
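For intuition about what the non-inlined helper returns, here is a plain-Python model (a hypothetical sketch, not Triton code): tl.load(ptr + tl.arange(0, BLOCK_SIZE)) reads a contiguous block of BLOCK_SIZE elements starting at the base pointer, and that whole block is the tensor return value crossing the call boundary.

```python
def load_block(buf, block_size):
    # Models tl.load(ptr + tl.arange(0, BLOCK_SIZE)): gather the
    # elements at offsets 0 .. block_size - 1 from the base "pointer".
    offsets = range(block_size)
    return [buf[o] for o in offsets]

print(load_block([10, 20, 30, 40, 50], 4))  # → [10, 20, 30, 40]
```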

Collaborator


I think supporting noinline should be the goal here

@leijurv leijurv force-pushed the support-returning-tensors-in-triton-to-tritongpu branch from 16def4a to 3efb768 Compare May 4, 2026 17:19
Contributor

@peterbell10 peterbell10 left a comment


Makes sense to me. A tensor is lowered as a vector, and LLVM can return vectors just fine.

I doubt the layout choices will be particularly good, but you would usually expect a performance trade-off when marking something as explicitly noinline. Or, if the call comes about due to a failure to inline, that shouldn't mean we fail to compile.

@leijurv leijurv requested a review from ThomasRaoux May 8, 2026 22:31
Collaborator

@ThomasRaoux ThomasRaoux left a comment


LGTM

@leijurv
Contributor Author

leijurv commented May 12, 2026

merge?

@ThomasRaoux ThomasRaoux merged commit 215c162 into triton-lang:main May 12, 2026
38 of 45 checks passed

5 participants