[BACKEND] Fix SelectOp conversion with scalar conditional (#1880)
`arith::SelectOp` supports a form where the condition argument is a
scalar and the result is a tensor. This isn't generated from `tl.where`,
but can still show up from canonicalization of `scf.if`.
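
As a rough sketch (illustrative SSA names and shapes, not output taken from the compiler), the canonicalizer folds an `scf.if` whose branches only yield values into a single `arith.select` driven by the scalar condition:

```mlir
// Before canonicalization: a branch that merely yields one of two tensors.
%a = scf.if %cond -> (tensor<128xf32>) {
  scf.yield %zeros : tensor<128xf32>
} else {
  scf.yield %tmp : tensor<128xf32>
}

// After canonicalization: a select with a scalar i1 condition
// but a tensor result.
%a2 = arith.select %cond, %zeros, %tmp : tensor<128xf32>
```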

Currently, if this happens, the conversion to TritonGPU IR fails because
`triton_gpu.select` doesn't support this form. For example,
```python
import triton
import triton.language as tl
import torch

@triton.jit
def _triton_test(
    in_ptr, out_ptr, cond, XBLOCK: tl.constexpr
):
    xindex = tl.arange(0, XBLOCK)
    tmp = tl.load(in_ptr + xindex)
    if cond:
        a = tl.zeros_like(tmp)
    else:
        a = tmp
    tl.store(out_ptr + xindex, a)

t = torch.randn(128, device="cuda")
out = torch.empty(128, device="cuda")
_triton_test[(1,)](t, out, True, t.numel())
```

This fails with the error:
```
error: 'triton_gpu.select' op requires the same shape for all operands and results
```
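
Concretely, the old pattern forwarded the scalar condition straight into `triton_gpu.select`, producing IR along these lines (a hand-written sketch with an assumed `#blocked` encoding, not actual compiler output), which the op's verifier rejects:

```mlir
// Invalid: the i1 condition is a scalar while the other operands
// and the result are 128-element tensors.
%4 = "triton_gpu.select"(%cond, %cst, %3)
    : (i1, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>)
    -> tensor<128xf32, #blocked>
```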

Co-authored-by: Keren Zhou <[email protected]>
peterbell10 and Jokeren authored Jul 6, 2023
1 parent a1301c9 commit b45baa5
Showing 2 changed files with 39 additions and 4 deletions.
22 changes: 18 additions & 4 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
```diff
@@ -154,10 +154,24 @@ class StdSelectPattern : public OpConversionPattern<arith::SelectOp> {
   matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type retType = this->getTypeConverter()->convertType(op.getType());
-    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
-                      op, retType, adaptor.getCondition(),
-                      adaptor.getTrueValue(), adaptor.getFalseValue()),
-                  adaptor.getAttributes());
+
+    Value cond = adaptor.getCondition();
+    if (llvm::isa<RankedTensorType>(retType) &&
+        !llvm::isa<TensorType>(cond.getType())) {
+      // triton_gpu.select doesn't support scalar condition values, so add a
+      // splat
+      auto retTypeTensor = llvm::cast<RankedTensorType>(retType);
+      auto retShape = retTypeTensor.getShape();
+      auto retEncoding = retTypeTensor.getEncoding();
+      Type condTy =
+          RankedTensorType::get(retShape, cond.getType(), retEncoding);
+      cond = rewriter.create<triton::SplatOp>(op.getLoc(), condTy, cond);
+    }
+
+    addNamedAttrs(
+        rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
+            op, retType, cond, adaptor.getTrueValue(), adaptor.getFalseValue()),
+        adaptor.getAttributes());
     return success();
   }
 };
```
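
In IR terms, the rewritten pattern first splats the scalar condition to the result shape and then selects. A sketch of the intended output (the lit test below checks this exact pattern; the names and `#blocked` encoding here are illustrative):

```mlir
%splat = tt.splat %cond : (i1) -> tensor<128xi1, #blocked>
%sel = "triton_gpu.select"(%splat, %cst, %3)
    : (tensor<128xi1, #blocked>, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>)
    -> tensor<128xf32, #blocked>
```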
21 changes: 21 additions & 0 deletions test/Conversion/triton_to_tritongpu.mlir
```diff
@@ -67,3 +67,24 @@ tt.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
 
   tt.return
 }
+
+
+// -----
+
+tt.func public @select_op(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i1) attributes {noinline = false} {
+  // CHECK-LABEL: select_op
+  %cst = arith.constant dense<0.000000e+00> : tensor<128xf32>
+  %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+  %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
+  %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
+  %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32>
+
+  // CHECK: %[[splat:.*]] = tt.splat %arg2 : (i1) -> tensor<128xi1, #blocked>
+  // CHECK-NEXT: %{{.*}} = "triton_gpu.select"(%[[splat]], %{{.*}}, %{{.*}}) : (tensor<128xi1, #blocked>, tensor<128xf32, #blocked>, tensor<128xf32, #blocked>) -> tensor<128xf32, #blocked>
+  %4 = arith.select %arg2, %cst, %3 : tensor<128xf32>
+
+  %5 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
+  %6 = tt.addptr %5, %0 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
+  tt.store %6, %4 {cache = 1 : i32, evict = 1 : i32} : tensor<128xf32>
+  tt.return
+}
```
