Support returning tensors in TritonToTritonGPU #10189
Conversation
Co-Authored-By: Saagar Jha <saagar@saagarjha.com>
Jokeren
left a comment
I'm not clear about this PR. Do you need to handle returned tensors in LLVM lowering?
If you don't apply this PR, then TritonToTritonGPU will fail because the tensor that is returned will not have a layout.

Why do we need to return tensors? Even if we did, I'm not sure if LLVM lowering handles return ops well.

If you make a call, it's going to return a tensor? We have some code that does this and it works fine if you fix Triton to handle it. This is one of those fixes.

Could you make a Python-level test?

BTW I thought we had run into that when calling a function returning a tensor with early return within a loop. I can't find the patch though, so maybe it never landed.

We need a Python test. If a function is not inlined, I'm not confident that Triton can return tensors properly.

Python-level test added. Lit test improved to include an example showing that layouts are converted at the return, as was brought up in the previous related PR #7996.
Force-pushed from a0e7d43 to 16def4a.
```python
@triton.jit
def tensor_return_inner():
    if tl.program_id(0) == 0:
        return tl.arange(0, 16)
    else:
        return tl.arange(0, 16) * 2


@triton.jit
def tensor_return_outer(x, BLOCK: tl.constexpr):
    if tl.program_id(0) < 2:
        tl.atomic_add(x + tl.arange(0, BLOCK), tensor_return_inner())


@pytest.mark.interpreter
def test_return_tensor(device):
    x = torch.zeros(16, device=device)
    tensor_return_outer[(3, )](x, x.numel())
    assert torch.equal(x, torch.arange(x.numel(), device=device) * 3)
```
Can we use noinline=True instead? I would expect those to be fully inlined, so I don't get what this tests.
Sure, reverted to that. We thought that using noinline=True might be unsupported, so Saagar's approach was to force a function not to inline by having early returns in control flow. Much more straightforward to test noinline=True directly:
triton/python/test/unit/language/test_core.py, lines 1315 to 1318 in 3efb768
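(The exact lines at that permalink aren't reproduced here; the following is a rough sketch of what a noinline variant of the test above could look like, not the verbatim code at the link.)

```python
import triton
import triton.language as tl


# Sketch only: same kernels as the diff above, with the early returns
# dropped in favor of an explicit noinline marker.
@triton.jit(noinline=True)
def tensor_return_inner():
    # noinline keeps this as a real tt.call/tt.return pair, so the returned
    # tensor has to pick up a layout encoding in TritonToTritonGPU.
    return tl.arange(0, 16)


@triton.jit
def tensor_return_outer(x, BLOCK: tl.constexpr):
    tl.atomic_add(x + tl.arange(0, BLOCK), tensor_return_inner())
```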
I think supporting noinline should be the goal here.
Force-pushed from 16def4a to 3efb768.
peterbell10
left a comment
Makes sense to me. A tensor is lowered as a vector, and LLVM can return vectors just fine.
I doubt the layout choices will be particularly good, but you would usually expect a performance trade-off when marking something as explicitly noinline. Or, if it comes about due to a failure to inline, that shouldn't mean we fail to compile.
Merge?
The conversion patterns in TritonToTritonGPU (`TritonFuncOpPattern`, `TritonCallOpPattern`, `TritonReturnOpPattern`) did not handle `tt.func` ops returning tensors: `TritonFuncOpPattern` reused the original `FunctionType` verbatim, and the `triton::FuncOp` legality check inspected only argument types, so tensor result types came out without any layout encoding. There was already a fork of MLIR upstream's signature-conversion pattern in `Dialect/Triton/Transforms/FunctionTypeConversion.h` (added because upstream is unaware of `tt.call`/`tt.return`) which handles inputs, results, one-to-many conversions, and arg-attribute remapping, so this PR drops those three custom patterns and reuses this one. The legality predicate is also extended to require encodings on result tensor types.

Adds a regression test in `test/Conversion/triton_to_tritongpu.mlir` covering a `tt.func` returning a tensor and a `tt.call` site picking up the encoded result type.

Making this PR on behalf of @saagarjha (he is waiting on #8913).
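For context, a minimal Python-level sketch of the failure mode this fixes; the kernel names below are illustrative and not taken from the patch. With noinline=True the call is kept as a real `tt.call`, so before this change the returned tensor's type reached TritonToTritonGPU without a layout encoding.

```python
import torch
import triton
import triton.language as tl


@triton.jit(noinline=True)
def produce_values():  # hypothetical helper returning a tensor across tt.call
    return tl.arange(0, 16)


@triton.jit
def consume_values(out_ptr):
    offs = tl.arange(0, 16)
    # The non-inlined call returns a tensor whose type must be rewritten
    # with a layout encoding during TritonToTritonGPU.
    tl.store(out_ptr + offs, produce_values())


out = torch.zeros(16, dtype=torch.int32, device="cuda")
consume_values[(1, )](out)
```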
- I have written a PR description following these rules.
- I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- The `lit` tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)