diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index dbb4c24a22ce..cfb31995bc3a 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -510,19 +510,6 @@ def TT_BroadcastOp : TT_Op<"broadcast", [Pure, let hasVerifier = 1; } -// Cat is not pure because it may reorder elements. -def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, - SameTypeOperands, - SameOperandsAndResultElementType]> { - let summary = "concatenate 2 tensors"; - - let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; -} - def TT_JoinOp : TT_Op<"join", [ Pure, SameTypeOperands]> { let summary = "join two tensors along a new, minor dimension"; diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index db6241540314..0791d5490d47 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -250,8 +250,6 @@ SmallVector getMatrixOrder(unsigned rank, bool rowMajor); SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, bool kContig); -bool isExpensiveCat(CatOp cat, Attribute targetEncoding); - // Return true if a view between the two types cannot be implemented as a no-op. bool isExpensiveView(Type srcType, Type dstType); diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 9ea9e1994097..ffeb71f54c46 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -159,35 +159,6 @@ struct ArithConstantArrayOpConversion } }; -struct CatOpConversion : public ConvertOpToLLVMPattern { - using OpAdaptor = typename CatOp::Adaptor; - explicit CatOpConversion(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = patternBenefitDefault) - : ConvertOpToLLVMPattern(typeConverter, benefit) {} - LogicalResult - matchAndRewrite(CatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - auto resultTy = cast(op.getType()); - unsigned elems = getTotalElemsPerThread(resultTy); - auto typeConverter = getTypeConverter(); - Type elemTy = typeConverter->convertType(resultTy.getElementType()); - SmallVector types(elems, elemTy); - // unpack input values - auto lhsVals = unpackLLElements(loc, adaptor.getLhs(), rewriter); - auto rhsVals = unpackLLElements(loc, adaptor.getRhs(), rewriter); - // concatenate (and potentially reorder) values - SmallVector retVals; - for (Value v : lhsVals) - retVals.push_back(v); - for (Value v : rhsVals) - retVals.push_back(v); - // pack and replace - Value ret = packLLElements(loc, typeConverter, retVals, rewriter, resultTy); - rewriter.replaceOp(op, ret); - return success(); - } -}; struct JoinOpConversion : public ConvertOpToLLVMPattern { using OpAdaptor = typename JoinOp::Adaptor; explicit JoinOpConversion(LLVMTypeConverter &typeConverter, @@ -590,7 +561,6 @@ void mlir::triton::populateViewOpToLLVMPatterns( patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add( diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index c5bdc9e7b34c..f25f1754d445 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -279,51 +279,6 @@ struct TritonDotPattern : public OpConversionPattern { } }; -struct TritonCatPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // The cat op satisfy two conditions: - // 1. output.numel = lhs.numel + rhs.numel - // 2. output.total_elems_per_thread = - // next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread) - // For now, this behaves like generic, but this - // will evolve when we add support for `can_reorder=False`. - auto retType = cast( - this->getTypeConverter()->convertType(op.getType())); - auto retEncoding = - cast(retType.getEncoding()); - auto lhsType = adaptor.getLhs().getType(); - auto rhsType = adaptor.getRhs().getType(); - auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType); - auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType); - auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType); - auto retShape = retType.getShape(); - auto retOrder = retEncoding.getOrder(); - auto retThreadsPerWarp = retEncoding.getThreadsPerWarp(); - auto retWarpsPerCTA = retEncoding.getWarpsPerCTA(); - // Get new retSizePerThread if ret elems per thread is not enough. - // We have to round it up to the next power of 2 due to triton's tensor size - // constraint. - auto newRetTotalElemsPerThread = - nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread); - auto newRetSizePerThread = llvm::to_vector(retEncoding.getSizePerThread()); - newRetSizePerThread[retOrder[0]] *= - newRetTotalElemsPerThread / retTotalElemsPerThread; - triton::gpu::BlockedEncodingAttr newRetEncoding = - triton::gpu::BlockedEncodingAttr::get( - getContext(), newRetSizePerThread, retThreadsPerWarp, - retWarpsPerCTA, retOrder, retEncoding.getCGALayout()); - auto newRetType = retType.cloneWithEncoding(newRetEncoding); - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, newRetType, adaptor.getOperands()), - adaptor.getAttributes()); - return success(); - } -}; - struct TritonJoinOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -567,7 +522,6 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, TritonBroadcastPattern, - TritonCatPattern, TritonJoinOpPattern, TritonSplitOpPattern, GenericOpPattern, diff --git a/lib/Dialect/Gluon/Transforms/InferLayoutUtils.cpp b/lib/Dialect/Gluon/Transforms/InferLayoutUtils.cpp index bff4e64a4bbc..d84df8169811 100644 --- a/lib/Dialect/Gluon/Transforms/InferLayoutUtils.cpp +++ b/lib/Dialect/Gluon/Transforms/InferLayoutUtils.cpp @@ -76,7 +76,7 @@ LayoutInfo combineInfo(LayoutInfo lhs, LayoutInfo rhs, Operation *op, } bool encodingsMayVary(Operation *op) { - return isa(op); } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 4312e6c643fb..0140907ca8bb 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -360,18 +360,6 @@ SmallVector orderPerDimImpl(const LinearLayout &ll, return order.takeVector(); } -bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { - // If the new elements per thread is less than the old one, we will need to - // do convert encoding that goes through shared memory anyway. So we - // consider it as expensive. - RankedTensorType tensorTy = cat.getType(); - auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); - auto shape = tensorTy.getShape(); - auto newTotalElemsPerThread = - gpu::getTotalElemsPerThread(targetEncoding, shape); - return newTotalElemsPerThread < totalElemsPerThread; -} - static LogicalResult verifyLayoutOrder(function_ref emitError, ArrayRef order) { diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index fd8008ff9203..32b5ab7a6054 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -342,16 +342,6 @@ struct CanonicalizeConvertFromConvert return success(); } - // cvt(cat) -> cat - if (auto cat = dyn_cast(arg)) { - if (isExpensiveCat(cat, op.getType().getEncoding())) - return failure(); - - rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), - cat.getOperands()); - return success(); - } - // cvt(cvt(x, type1), type2) -> cvt(x, type2) if (auto cvt = dyn_cast(arg)) { rewriter.replaceOpWithNewOp( diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index ddab8f75d5b3..27365047ac7f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -614,8 +614,6 @@ bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { return true; if (isa(op)) return isExpensiveLoadOrStore(op); - if (isa(op)) - return triton::gpu::isExpensiveCat(cast(op), targetEncoding); if (isa(op)) return true; @@ -626,9 +624,6 @@ bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { } bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { - if (isa(op)) - return !triton::gpu::isExpensiveCat(cast(op), - targetEncoding); if (auto convert = dyn_cast(op)) { if (mlir::isa(targetEncoding)) { auto srcEncoding = convert.getSrc().getType().getEncoding(); @@ -932,8 +927,6 @@ LogicalResult getConvertBackwardSlice( continue; if (stopPropagation && stopPropagation(definingOp)) continue; - if (isa(definingOp)) - return failure(); if (auto gather = dyn_cast(definingOp)) { // Specially handle gather since its transfer function only applies // between its index operand and result. diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index 4dadf5d2ad25..481d9ab5f677 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -579,11 +579,6 @@ void init_gluon_ir(py::module &&m) { *mask); } }) - .def("create_cat", - [](GluonOpBuilder &self, Value &lhs, Value &rhs, - Type retType) -> Value { - return self.create(retType, lhs, rhs); - }) .def("create_fp4_to_fp", [](GluonOpBuilder &self, Value src, Type elemType, int axis) -> Value { diff --git a/python/src/ir.cc b/python/src/ir.cc index bf662affa343..7b02040d3cf2 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1551,18 +1551,6 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &arg, int axis) -> Value { return self.create(arg, axis); }) - .def("create_cat", - [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { - auto lhsType = dyn_cast(lhs.getType()); - auto rhsType = dyn_cast(rhs.getType()); - if (!(lhsType.getShape().size() == 1 && - rhsType.getShape().size() == 1)) - throw std::invalid_argument( - "shape not supported by cat. Expecting rank-1 inputs"); - std::vector shape{lhsType.getShape()[0] + - rhsType.getShape()[0]}; - return self.create(lhsType.clone(shape), lhs, rhs); - }) .def("create_join", [](TritonOpBuilder &self, Value &a, Value &b) -> Value { return self.create(a, b); diff --git a/python/triton/experimental/gluon/language/_semantic.py b/python/triton/experimental/gluon/language/_semantic.py index f81f98d585ba..5f3ebb626695 100644 --- a/python/triton/experimental/gluon/language/_semantic.py +++ b/python/triton/experimental/gluon/language/_semantic.py @@ -469,13 +469,6 @@ def histogram(self, input: TensorTy, num_bins: int, mask: TensorTy, layout) -> T handle = self.builder.create_histogram(input.handle, num_bins, mask, layout_attr) return self.wrap_tensor(handle, ttgl.int32, [num_bins], layout) - def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool, layout) -> TensorTy: - _check(layout is not None, lambda: "cat requires a destination layout") - _check(can_reorder, lambda: "current implementation of `cat` always may reorder elements") - _check(len(lhs.shape) == 1, lambda: "cat requires a rank-1 input") - ret_type = ttgl.distributed_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]], layout) - return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle, ret_type.to_ir(self.builder)), ret_type) - def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy: _check(isinstance(src.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {src.type!r}") _check(isinstance(index.type, ttgl.distributed_type), diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 29f6a367f368..2699c0c6d038 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1782,7 +1782,7 @@ def permute(input, *dims, _semantic=None): @builtin -def cat(input, other, can_reorder=False, _semantic=None): +def cat(input, other, can_reorder=False, dim=0, _semantic=None): """ Concatenate the given blocks @@ -1790,12 +1790,30 @@ def cat(input, other, can_reorder=False, _semantic=None): :type input: Tensor :param other: The second input tensor. :type other: Tensor - :param reorder: Compiler hint. If true, the compiler is - allowed to reorder elements while concatenating inputs. Only use if the - order does not matter (e.g., result is only used in reduction ops). - Current implementation of `cat` supports only can_reorder=True. - """ - return _semantic.cat(input, other, can_reorder) + :param can_reorder: Deprecated option. Elements are never reordered. + :type can_reorder: bool + :param dim: The dimension to concatenate along. + :type dim: int + """ + rank = len(input.shape) + assert rank == len(other.shape), f"tensors must have the same rank, got {rank} and {len(other.shape)}" + assert all(input.shape[i] == other.shape[i] for i in builtins.range(rank) if i != + dim), f"tensor dims must match except in the concat dimension {dim}, got {input.shape} and {other.shape}" + + order = list(builtins.range(rank)) + order[dim], order[-1] = order[-1], order[dim] + inv_order = [order.index(i) for i in builtins.range(rank)] + + a = permute(input, order, _semantic=_semantic) + b = permute(other, order, _semantic=_semantic) + + leading = a.shape[:-1] + a = reshape(a, (math.prod(leading), a.shape[-1]), _semantic=_semantic) + b = reshape(b, (math.prod(leading), b.shape[-1]), _semantic=_semantic) + + c = join(a, b, _semantic=_semantic) + c = reshape(c, leading + [a.shape[-1] + b.shape[-1]], _semantic=_semantic) + return permute(c, inv_order, _semantic=_semantic) @builtin diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 68491acabdde..6db9ffb25fec 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -647,12 +647,6 @@ def expand_dims(self, input: TensorTy, axis: int) -> TensorTy: ret_ty = tl.block_type(input.type.scalar, dst_shape) return self.tensor(self.builder.create_expand_dims(input.handle, axis), ret_ty) - def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool) -> TensorTy: - assert can_reorder, "current implementation of `cat` always may reorder elements" - assert len(lhs.shape) == 1 - ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) - return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle), ret_type) - def join(self, a: TensorTy, b: TensorTy) -> TensorTy: a, b = self.broadcast_impl_value(a, b) diff --git a/python/triton/tools/triton_to_gluon_translater/translator_helpers.py b/python/triton/tools/triton_to_gluon_translater/translator_helpers.py index 2b946ee3bf9f..e4a2fe473e28 100644 --- a/python/triton/tools/triton_to_gluon_translater/translator_helpers.py +++ b/python/triton/tools/triton_to_gluon_translater/translator_helpers.py @@ -519,30 +519,6 @@ def tl_trans(value, *dims, _semantic=None): return value.trans(*dims, _semantic=_semantic) -@ttgl._core.builtin -def cat(input, other, can_reorder=False, layout=None, _semantic=None): - """ - Concatenate the two tensors. - - Args: - input (tensor): The first input tensor. - other (tensor): The second input tensor. - can_reorder (bool): Compiler hint. If true, the compiler is allowed to reorder elements while concatenating inputs. Only use if the order does not matter (e.g., result is only used in reduction ops). Current implementation of `cat` supports only can_reorder=True. - layout (DistributedLayout): The destination layout of the output tensor. - - Returns: - tensor: The concatenated tensor. - """ - can_reorder = ttgl._core._unwrap_if_constexpr(can_reorder) - layout = ttgl._core._unwrap_if_constexpr(layout) - return _semantic.cat(input, other, can_reorder, layout) - - -@gluon.jit -def tl_cat(lhs, rhs, can_reorder=False): - return cat(lhs, rhs, can_reorder, layout=default_blocked_layout([lhs.shape[0] + rhs.shape[0]], ttgl.num_warps())) - - @gluon.jit def reset_to_default_layout(value): ty: ttgl.constexpr = value.type diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir index 3e25166a1ca4..feb576078742 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir @@ -295,6 +295,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [2, 2, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON-LABEL: join_cat_transitive_nonneg tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { @@ -304,7 +306,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1> %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1> %5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked> - %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked> + %tmp = tt.join %5, %5 : tensor<4x2xi32, #blocked> -> tensor<4x2x2xi32, #blocked2> + %tmp1 = tt.reshape %tmp : tensor<4x2x2xi32, #blocked2> -> tensor<8x2xi32, #blocked3> + %6 = ttg.convert_layout %tmp1 : tensor<8x2xi32, #blocked3> -> tensor<8x2xi32, #blocked> %7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked> %zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked> %ones = arith.constant dense<1> : tensor<8x1xi32, #blocked> diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index 3d237f39866b..eed19d6d8d00 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -334,6 +334,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [2, 2, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON-LABEL: join_cat_transitive_nonneg tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { @@ -343,7 +345,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1> %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1> %5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked> - %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked> + %tmp = tt.join %5, %5 : tensor<4x2xi32, #blocked> -> tensor<4x2x2xi32, #blocked2> + %tmp1 = tt.reshape %tmp : tensor<4x2x2xi32, #blocked2> -> tensor<8x2xi32, #blocked3> + %6 = ttg.convert_layout %tmp1 : tensor<8x2xi32, #blocked3> -> tensor<8x2xi32, #blocked> %7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked> %zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked> %ones = arith.constant dense<1> : tensor<8x1xi32, #blocked> diff --git a/test/TritonGPU/amd/amd-range-analysis.mlir b/test/TritonGPU/amd/amd-range-analysis.mlir index 5d975534e948..4f2910543022 100644 --- a/test/TritonGPU/amd/amd-range-analysis.mlir +++ b/test/TritonGPU/amd/amd-range-analysis.mlir @@ -1245,7 +1245,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %5 = tt.join %3, %4 : tensor<4xi32> -> tensor<4x2xi32> // expected-remark@+2 {{unsigned : [0, 7] signed : [0, 7]}} // expected-remark@+1 {{non-neg}} - %6 = tt.cat %5, %5 : tensor<4x2xi32> -> tensor<8x2xi32> + %tmp = tt.join %5, %5 : tensor<4x2xi32> -> tensor<4x2x2xi32> + %6 = tt.reshape %tmp : tensor<4x2x2xi32> -> tensor<8x2xi32> // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} // expected-remark@+1 {{non-neg}} %7 = arith.addi %2, %6 : tensor<8x2xi32> diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp index b1057774cbbd..c702e228b66e 100644 --- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp +++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp @@ -568,7 +568,7 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperationHelper( // Ops with actually changing/variable input/output ranges. if (llvm::isa(op)) { + SplatOp, ExpandDimsOp, JoinOp, GatherOp>(op)) { SmallVector argConstIntRanges; for (const auto &r : argIntValueRanges) { if (r.isUninitialized()) { @@ -583,7 +583,7 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperationHelper( return inferResultRangesUnaryOpForwardArgRange(op, argConstIntRanges, joinCallback); }) - .Case([&](auto joinOp) { + .Case([&](auto joinOp) { return inferResultRangesBinaryOpUnionArgRanges( joinOp, argConstIntRanges, joinCallback); })