diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index cfb31995bc3a..dbb4c24a22ce 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -510,6 +510,19 @@ def TT_BroadcastOp : TT_Op<"broadcast", [Pure, let hasVerifier = 1; } +// Cat is not pure because it may reorder elements. +def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, + SameTypeOperands, + SameOperandsAndResultElementType]> { + let summary = "concatenate 2 tensors"; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + def TT_JoinOp : TT_Op<"join", [ Pure, SameTypeOperands]> { let summary = "join two tensors along a new, minor dimension"; diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 0791d5490d47..db6241540314 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -250,6 +250,8 @@ SmallVector getMatrixOrder(unsigned rank, bool rowMajor); SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, bool kContig); +bool isExpensiveCat(CatOp cat, Attribute targetEncoding); + // Return true if a view between the two types cannot be implemented as a no-op. bool isExpensiveView(Type srcType, Type dstType); diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index ffeb71f54c46..9ea9e1994097 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -159,6 +159,35 @@ struct ArithConstantArrayOpConversion } }; +struct CatOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename CatOp::Adaptor; + explicit CatOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + unsigned elems = getTotalElemsPerThread(resultTy); + auto typeConverter = getTypeConverter(); + Type elemTy = typeConverter->convertType(resultTy.getElementType()); + SmallVector types(elems, elemTy); + // unpack input values + auto lhsVals = unpackLLElements(loc, adaptor.getLhs(), rewriter); + auto rhsVals = unpackLLElements(loc, adaptor.getRhs(), rewriter); + // concatenate (and potentially reorder) values + SmallVector retVals; + for (Value v : lhsVals) + retVals.push_back(v); + for (Value v : rhsVals) + retVals.push_back(v); + // pack and replace + Value ret = packLLElements(loc, typeConverter, retVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; struct JoinOpConversion : public ConvertOpToLLVMPattern { using OpAdaptor = typename JoinOp::Adaptor; explicit JoinOpConversion(LLVMTypeConverter &typeConverter, @@ -561,6 +590,7 @@ void mlir::triton::populateViewOpToLLVMPatterns( patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add( diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index f25f1754d445..c5bdc9e7b34c 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -279,6 +279,51 @@ struct TritonDotPattern : public OpConversionPattern { } }; +struct TritonCatPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The cat op satisfy two conditions: + // 1. output.numel = lhs.numel + rhs.numel + // 2. output.total_elems_per_thread = + // next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread) + // For now, this behaves like generic, but this + // will evolve when we add support for `can_reorder=False`. + auto retType = cast( + this->getTypeConverter()->convertType(op.getType())); + auto retEncoding = + cast(retType.getEncoding()); + auto lhsType = adaptor.getLhs().getType(); + auto rhsType = adaptor.getRhs().getType(); + auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType); + auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType); + auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType); + auto retShape = retType.getShape(); + auto retOrder = retEncoding.getOrder(); + auto retThreadsPerWarp = retEncoding.getThreadsPerWarp(); + auto retWarpsPerCTA = retEncoding.getWarpsPerCTA(); + // Get new retSizePerThread if ret elems per thread is not enough. + // We have to round it up to the next power of 2 due to triton's tensor size + // constraint. + auto newRetTotalElemsPerThread = + nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread); + auto newRetSizePerThread = llvm::to_vector(retEncoding.getSizePerThread()); + newRetSizePerThread[retOrder[0]] *= + newRetTotalElemsPerThread / retTotalElemsPerThread; + triton::gpu::BlockedEncodingAttr newRetEncoding = + triton::gpu::BlockedEncodingAttr::get( + getContext(), newRetSizePerThread, retThreadsPerWarp, + retWarpsPerCTA, retOrder, retEncoding.getCGALayout()); + auto newRetType = retType.cloneWithEncoding(newRetEncoding); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newRetType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + struct TritonJoinOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -522,6 +567,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, TritonBroadcastPattern, + TritonCatPattern, TritonJoinOpPattern, TritonSplitOpPattern, GenericOpPattern, diff --git a/lib/Dialect/Gluon/Transforms/InferLayoutUtils.cpp b/lib/Dialect/Gluon/Transforms/InferLayoutUtils.cpp index d84df8169811..bff4e64a4bbc 100644 --- a/lib/Dialect/Gluon/Transforms/InferLayoutUtils.cpp +++ b/lib/Dialect/Gluon/Transforms/InferLayoutUtils.cpp @@ -76,7 +76,7 @@ LayoutInfo combineInfo(LayoutInfo lhs, LayoutInfo rhs, Operation *op, } bool encodingsMayVary(Operation *op) { - return isa(op); } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 0140907ca8bb..4312e6c643fb 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -360,6 +360,18 @@ SmallVector orderPerDimImpl(const LinearLayout &ll, return order.takeVector(); } +bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { + // If the new elements per thread is less than the old one, we will need to + // do convert encoding that goes through shared memory anyway. So we + // consider it as expensive. + RankedTensorType tensorTy = cat.getType(); + auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); + auto shape = tensorTy.getShape(); + auto newTotalElemsPerThread = + gpu::getTotalElemsPerThread(targetEncoding, shape); + return newTotalElemsPerThread < totalElemsPerThread; +} + static LogicalResult verifyLayoutOrder(function_ref emitError, ArrayRef order) { diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 32b5ab7a6054..fd8008ff9203 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -342,6 +342,16 @@ struct CanonicalizeConvertFromConvert return success(); } + // cvt(cat) -> cat + if (auto cat = dyn_cast(arg)) { + if (isExpensiveCat(cat, op.getType().getEncoding())) + return failure(); + + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + cat.getOperands()); + return success(); + } + // cvt(cvt(x, type1), type2) -> cvt(x, type2) if (auto cvt = dyn_cast(arg)) { rewriter.replaceOpWithNewOp( diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index d1241908993a..61d15d1d6e34 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -614,6 +614,8 @@ bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { return true; if (isa(op)) return isExpensiveLoadOrStore(op); + if (isa(op)) + return triton::gpu::isExpensiveCat(cast(op), targetEncoding); if (isa(op)) return true; @@ -624,6 +626,9 @@ bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { } bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { + if (isa(op)) + return !triton::gpu::isExpensiveCat(cast(op), + targetEncoding); if (auto convert = dyn_cast(op)) { if (mlir::isa(targetEncoding)) { auto srcEncoding = convert.getSrc().getType().getEncoding(); @@ -927,6 +932,8 @@ LogicalResult getConvertBackwardSlice( continue; if (stopPropagation && stopPropagation(definingOp)) continue; + if (isa(definingOp)) + return failure(); if (auto gather = dyn_cast(definingOp)) { // Specially handle gather since its transfer function only applies // between its index operand and result. diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index 481d9ab5f677..4dadf5d2ad25 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -579,6 +579,11 @@ void init_gluon_ir(py::module &&m) { *mask); } }) + .def("create_cat", + [](GluonOpBuilder &self, Value &lhs, Value &rhs, + Type retType) -> Value { + return self.create(retType, lhs, rhs); + }) .def("create_fp4_to_fp", [](GluonOpBuilder &self, Value src, Type elemType, int axis) -> Value { diff --git a/python/src/ir.cc b/python/src/ir.cc index 7b02040d3cf2..bf662affa343 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1551,6 +1551,18 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &arg, int axis) -> Value { return self.create(arg, axis); }) + .def("create_cat", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + auto lhsType = dyn_cast(lhs.getType()); + auto rhsType = dyn_cast(rhs.getType()); + if (!(lhsType.getShape().size() == 1 && + rhsType.getShape().size() == 1)) + throw std::invalid_argument( + "shape not supported by cat. Expecting rank-1 inputs"); + std::vector shape{lhsType.getShape()[0] + + rhsType.getShape()[0]}; + return self.create(lhsType.clone(shape), lhs, rhs); + }) .def("create_join", [](TritonOpBuilder &self, Value &a, Value &b) -> Value { return self.create(a, b); diff --git a/python/triton/experimental/gluon/language/_semantic.py b/python/triton/experimental/gluon/language/_semantic.py index 5f3ebb626695..f81f98d585ba 100644 --- a/python/triton/experimental/gluon/language/_semantic.py +++ b/python/triton/experimental/gluon/language/_semantic.py @@ -469,6 +469,13 @@ def histogram(self, input: TensorTy, num_bins: int, mask: TensorTy, layout) -> T handle = self.builder.create_histogram(input.handle, num_bins, mask, layout_attr) return self.wrap_tensor(handle, ttgl.int32, [num_bins], layout) + def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool, layout) -> TensorTy: + _check(layout is not None, lambda: "cat requires a destination layout") + _check(can_reorder, lambda: "current implementation of `cat` always may reorder elements") + _check(len(lhs.shape) == 1, lambda: "cat requires a rank-1 input") + ret_type = ttgl.distributed_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]], layout) + return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle, ret_type.to_ir(self.builder)), ret_type) + def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy: _check(isinstance(src.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {src.type!r}") _check(isinstance(index.type, ttgl.distributed_type), diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e4115e82e111..6401d9e25ca6 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1795,7 +1795,7 @@ def permute(input, *dims, _semantic=None): @builtin -def cat(input, other, can_reorder=False, dim=0, _semantic=None): +def cat(input, other, can_reorder=False, _semantic=None): """ Concatenate the given blocks @@ -1803,30 +1803,12 @@ def cat(input, other, can_reorder=False, dim=0, _semantic=None): :type input: Tensor :param other: The second input tensor. :type other: Tensor - :param can_reorder: Deprecated option. Elements are never reordered. - :type can_reorder: bool - :param dim: The dimension to concatenate along. - :type dim: int - """ - rank = len(input.shape) - assert rank == len(other.shape), f"tensors must have the same rank, got {rank} and {len(other.shape)}" - assert all(input.shape[i] == other.shape[i] for i in builtins.range(rank) if i != - dim), f"tensor dims must match except in the concat dimension {dim}, got {input.shape} and {other.shape}" - - order = list(builtins.range(rank)) - order[dim], order[-1] = order[-1], order[dim] - inv_order = [order.index(i) for i in builtins.range(rank)] - - a = permute(input, order, _semantic=_semantic) - b = permute(other, order, _semantic=_semantic) - - leading = a.shape[:-1] - a = reshape(a, (math.prod(leading), a.shape[-1]), _semantic=_semantic) - b = reshape(b, (math.prod(leading), b.shape[-1]), _semantic=_semantic) - - c = join(a, b, _semantic=_semantic) - c = reshape(c, leading + [a.shape[-1] + b.shape[-1]], _semantic=_semantic) - return permute(c, inv_order, _semantic=_semantic) + :param reorder: Compiler hint. If true, the compiler is + allowed to reorder elements while concatenating inputs. Only use if the + order does not matter (e.g., result is only used in reduction ops). + Current implementation of `cat` supports only can_reorder=True. + """ + return _semantic.cat(input, other, can_reorder) @builtin diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 6db9ffb25fec..68491acabdde 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -647,6 +647,12 @@ def expand_dims(self, input: TensorTy, axis: int) -> TensorTy: ret_ty = tl.block_type(input.type.scalar, dst_shape) return self.tensor(self.builder.create_expand_dims(input.handle, axis), ret_ty) + def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool) -> TensorTy: + assert can_reorder, "current implementation of `cat` always may reorder elements" + assert len(lhs.shape) == 1 + ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) + return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle), ret_type) + def join(self, a: TensorTy, b: TensorTy) -> TensorTy: a, b = self.broadcast_impl_value(a, b) diff --git a/python/triton/tools/triton_to_gluon_translater/translator_helpers.py b/python/triton/tools/triton_to_gluon_translater/translator_helpers.py index e4a2fe473e28..2b946ee3bf9f 100644 --- a/python/triton/tools/triton_to_gluon_translater/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translater/translator_helpers.py @@ -519,6 +519,30 @@ def tl_trans(value, *dims, _semantic=None): return value.trans(*dims, _semantic=_semantic) +@ttgl._core.builtin +def cat(input, other, can_reorder=False, layout=None, _semantic=None): + """ + Concatenate the two tensors. + + Args: + input (tensor): The first input tensor. + other (tensor): The second input tensor. + can_reorder (bool): Compiler hint. If true, the compiler is allowed to reorder elements while concatenating inputs. Only use if the order does not matter (e.g., result is only used in reduction ops). Current implementation of `cat` supports only can_reorder=True. + layout (DistributedLayout): The destination layout of the output tensor. + + Returns: + tensor: The concatenated tensor. + """ + can_reorder = ttgl._core._unwrap_if_constexpr(can_reorder) + layout = ttgl._core._unwrap_if_constexpr(layout) + return _semantic.cat(input, other, can_reorder, layout) + + +@gluon.jit +def tl_cat(lhs, rhs, can_reorder=False): + return cat(lhs, rhs, can_reorder, layout=default_blocked_layout([lhs.shape[0] + rhs.shape[0]], ttgl.num_warps())) + + @gluon.jit def reset_to_default_layout(value): ty: ttgl.constexpr = value.type diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir index feb576078742..3e25166a1ca4 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir @@ -295,8 +295,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [2, 2, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> -#blocked3 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON-LABEL: join_cat_transitive_nonneg tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { @@ -306,9 +304,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1> %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1> %5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked> - %tmp = tt.join %5, %5 : tensor<4x2xi32, #blocked> -> tensor<4x2x2xi32, #blocked2> - %tmp1 = tt.reshape %tmp : tensor<4x2x2xi32, #blocked2> -> tensor<8x2xi32, #blocked3> - %6 = ttg.convert_layout %tmp1 : tensor<8x2xi32, #blocked3> -> tensor<8x2xi32, #blocked> + %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked> %7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked> %zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked> %ones = arith.constant dense<1> : tensor<8x1xi32, #blocked> diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index eed19d6d8d00..3d237f39866b 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -334,8 +334,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [2, 2, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> -#blocked3 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON-LABEL: join_cat_transitive_nonneg tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { @@ -345,9 +343,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1> %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1> %5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked> - %tmp = tt.join %5, %5 : tensor<4x2xi32, #blocked> -> tensor<4x2x2xi32, #blocked2> - %tmp1 = tt.reshape %tmp : tensor<4x2x2xi32, #blocked2> -> tensor<8x2xi32, #blocked3> - %6 = ttg.convert_layout %tmp1 : tensor<8x2xi32, #blocked3> -> tensor<8x2xi32, #blocked> + %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked> %7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked> %zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked> %ones = arith.constant dense<1> : tensor<8x1xi32, #blocked> diff --git a/test/TritonGPU/amd/amd-range-analysis.mlir b/test/TritonGPU/amd/amd-range-analysis.mlir index 4f2910543022..5d975534e948 100644 --- a/test/TritonGPU/amd/amd-range-analysis.mlir +++ b/test/TritonGPU/amd/amd-range-analysis.mlir @@ -1245,8 +1245,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %5 = tt.join %3, %4 : tensor<4xi32> -> tensor<4x2xi32> // expected-remark@+2 {{unsigned : [0, 7] signed : [0, 7]}} // expected-remark@+1 {{non-neg}} - %tmp = tt.join %5, %5 : tensor<4x2xi32> -> tensor<4x2x2xi32> - %6 = tt.reshape %tmp : tensor<4x2x2xi32> -> tensor<8x2xi32> + %6 = tt.cat %5, %5 : tensor<4x2xi32> -> tensor<8x2xi32> // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} // expected-remark@+1 {{non-neg}} %7 = arith.addi %2, %6 : tensor<8x2xi32> diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp index c702e228b66e..b1057774cbbd 100644 --- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp +++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp @@ -568,7 +568,7 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperationHelper( // Ops with actually changing/variable input/output ranges. if (llvm::isa(op)) { + SplatOp, ExpandDimsOp, JoinOp, CatOp, GatherOp>(op)) { SmallVector argConstIntRanges; for (const auto &r : argIntValueRanges) { if (r.isUninitialized()) { @@ -583,7 +583,7 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperationHelper( return inferResultRangesUnaryOpForwardArgRange(op, argConstIntRanges, joinCallback); }) - .Case([&](auto joinOp) { + .Case([&](auto joinOp) { return inferResultRangesBinaryOpUnionArgRanges( joinOp, argConstIntRanges, joinCallback); })