Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,19 @@ def TT_BroadcastOp : TT_Op<"broadcast", [Pure,
let hasVerifier = 1;
}

// Cat is not marked Pure — only NoMemoryEffect — because the lowering is
// allowed to reorder elements while concatenating the inputs, so the op must
// not be treated as a pure value computation.
def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
                             SameTypeOperands,
                             SameOperandsAndResultElementType]> {
  let summary = "concatenate 2 tensors";

  // Two operands with identical types; the result has the same element type.
  let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);

  let results = (outs TT_Tensor:$result);

  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)";
}

def TT_JoinOp : TT_Op<"join", [
Pure, SameTypeOperands]> {
let summary = "join two tensors along a new, minor dimension";
Expand Down
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kContig);

bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

// Return true if a view between the two types cannot be implemented as a no-op.
bool isExpensiveView(Type srcType, Type dstType);

Expand Down
30 changes: 30 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,35 @@ struct ArithConstantArrayOpConversion
}
};

/// Lowers tt.cat to LLVM by appending the per-thread values of the RHS after
/// those of the LHS and repacking them into the result struct. No data
/// movement is emitted; the conversion to TritonGPU is expected to have given
/// the result an encoding that holds lhs + rhs elements per thread.
struct CatOpConversion : public ConvertOpToLLVMPattern<CatOp> {
  using OpAdaptor = typename CatOp::Adaptor;
  explicit CatOpConversion(LLVMTypeConverter &typeConverter,
                           PatternBenefit benefit = patternBenefitDefault)
      : ConvertOpToLLVMPattern<CatOp>(typeConverter, benefit) {}
  LogicalResult
  matchAndRewrite(CatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto resultTy = cast<RankedTensorType>(op.getType());
    auto typeConverter = getTypeConverter();
    // Unpack the per-thread values of both inputs.
    auto lhsVals = unpackLLElements(loc, adaptor.getLhs(), rewriter);
    auto rhsVals = unpackLLElements(loc, adaptor.getRhs(), rewriter);
    // Concatenate (and potentially reorder) values.
    SmallVector<Value> retVals;
    retVals.reserve(lhsVals.size() + rhsVals.size());
    retVals.append(lhsVals.begin(), lhsVals.end());
    retVals.append(rhsVals.begin(), rhsVals.end());
    // Pack into the result struct and replace the op.
    Value ret = packLLElements(loc, typeConverter, retVals, rewriter, resultTy);
    rewriter.replaceOp(op, ret);
    return success();
  }
};
struct JoinOpConversion : public ConvertOpToLLVMPattern<JoinOp> {
using OpAdaptor = typename JoinOp::Adaptor;
explicit JoinOpConversion(LLVMTypeConverter &typeConverter,
Expand Down Expand Up @@ -561,6 +590,7 @@ void mlir::triton::populateViewOpToLLVMPatterns(
patterns.add<UnsplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantArrayOpConversion>(typeConverter, benefit);
patterns.add<CatOpConversion>(typeConverter, benefit);
patterns.add<JoinOpConversion>(typeConverter, benefit);
patterns.add<SplitOpConversion>(typeConverter, benefit);
patterns.add<MemDescTransOpConversion, MemDescReshapeOpConversion>(
Expand Down
46 changes: 46 additions & 0 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,51 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
}
};

/// Converts tt.cat from the Triton dialect to TritonGPU by assigning the
/// result a blocked encoding large enough to hold all concatenated elements
/// in-registers.
struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The cat op satisfies two conditions:
    // 1. output.numel = lhs.numel + rhs.numel
    // 2. output.total_elems_per_thread =
    //    next_power_of_2(lhs.total_elems_per_thread +
    //                    rhs.total_elems_per_thread)
    // For now, this behaves like generic, but this
    // will evolve when we add support for `can_reorder=False`.
    auto retType = cast<RankedTensorType>(
        this->getTypeConverter()->convertType(op.getType()));
    auto retEncoding =
        cast<triton::gpu::BlockedEncodingAttr>(retType.getEncoding());
    auto lhsType = adaptor.getLhs().getType();
    auto rhsType = adaptor.getRhs().getType();
    auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType);
    auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType);
    auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType);
    auto retOrder = retEncoding.getOrder();
    auto retThreadsPerWarp = retEncoding.getThreadsPerWarp();
    auto retWarpsPerCTA = retEncoding.getWarpsPerCTA();
    // Get new retSizePerThread if ret elems per thread is not enough.
    // We have to round it up to the next power of 2 due to triton's tensor
    // size constraint.
    auto newRetTotalElemsPerThread =
        nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread);
    auto newRetSizePerThread = llvm::to_vector(retEncoding.getSizePerThread());
    // Scale the size-per-thread along the fastest-varying dimension so the
    // encoding holds all concatenated elements.
    newRetSizePerThread[retOrder[0]] *=
        newRetTotalElemsPerThread / retTotalElemsPerThread;
    triton::gpu::BlockedEncodingAttr newRetEncoding =
        triton::gpu::BlockedEncodingAttr::get(
            getContext(), newRetSizePerThread, retThreadsPerWarp,
            retWarpsPerCTA, retOrder, retEncoding.getCGALayout());
    auto newRetType = retType.cloneWithEncoding(newRetEncoding);
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::CatOp>(
                      op, newRetType, adaptor.getOperands()),
                  adaptor.getAttributes());
    return success();
  }
};

struct TritonJoinOpPattern : public OpConversionPattern<triton::JoinOp> {
using OpConversionPattern::OpConversionPattern;

Expand Down Expand Up @@ -522,6 +567,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::UnsplatOp>,
GenericOpPattern<triton::AddPtrOp>,
TritonBroadcastPattern,
TritonCatPattern,
TritonJoinOpPattern,
TritonSplitOpPattern,
GenericOpPattern<triton::ClampFOp>,
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Gluon/Transforms/InferLayoutUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ LayoutInfo combineInfo(LayoutInfo lhs, LayoutInfo rhs, Operation *op,
}

bool encodingsMayVary(Operation *op) {
return isa<triton::JoinOp, triton::SplitOp, triton::ReshapeOp,
return isa<triton::JoinOp, triton::SplitOp, triton::ReshapeOp, triton::CatOp,
triton::TransOp>(op);
}

Expand Down
12 changes: 12 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,18 @@ SmallVector<unsigned> orderPerDimImpl(const LinearLayout &ll,
return order.takeVector();
}

bool isExpensiveCat(CatOp cat, Attribute targetEncoding) {
  // A cat is considered expensive when re-encoding its result to
  // `targetEncoding` would shrink the number of elements held per thread:
  // such a conversion has to go through shared memory anyway, so folding the
  // cat into it buys nothing.
  RankedTensorType resultType = cat.getType();
  auto currentElemsPerThread = gpu::getTotalElemsPerThread(resultType);
  auto targetElemsPerThread =
      gpu::getTotalElemsPerThread(targetEncoding, resultType.getShape());
  return targetElemsPerThread < currentElemsPerThread;
}

static LogicalResult
verifyLayoutOrder(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<unsigned> order) {
Expand Down
10 changes: 10 additions & 0 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,16 @@ struct CanonicalizeConvertFromConvert
return success();
}

// cvt(cat) -> cat
if (auto cat = dyn_cast<CatOp>(arg)) {
if (isExpensiveCat(cat, op.getType().getEncoding()))
return failure();

rewriter.replaceOpWithNewOp<CatOp>(op, op->getResult(0).getType(),
cat.getOperands());
return success();
}

// cvt(cvt(x, type1), type2) -> cvt(x, type2)
if (auto cvt = dyn_cast<ConvertLayoutOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
Expand Down
7 changes: 7 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,8 @@ bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) {
return true;
if (isa<triton::LoadOp, triton::StoreOp>(op))
return isExpensiveLoadOrStore(op);
if (isa<triton::CatOp>(op))
return triton::gpu::isExpensiveCat(cast<triton::CatOp>(op), targetEncoding);
if (isa<triton::gpu::AsyncCopyGlobalToLocalOp, triton::AtomicRMWOp,
triton::AtomicCASOp, triton::DotOp>(op))
return true;
Expand All @@ -624,6 +626,9 @@ bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) {
}

bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
if (isa<triton::CatOp>(op))
return !triton::gpu::isExpensiveCat(cast<triton::CatOp>(op),
targetEncoding);
if (auto convert = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
if (mlir::isa<triton::gpu::NvidiaMmaEncodingAttr>(targetEncoding)) {
auto srcEncoding = convert.getSrc().getType().getEncoding();
Expand Down Expand Up @@ -927,6 +932,8 @@ LogicalResult getConvertBackwardSlice(
continue;
if (stopPropagation && stopPropagation(definingOp))
continue;
if (isa<triton::CatOp>(definingOp))
return failure();
if (auto gather = dyn_cast<GatherOp>(definingOp)) {
// Specially handle gather since its transfer function only applies
// between its index operand and result.
Expand Down
5 changes: 5 additions & 0 deletions python/src/gluon_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,11 @@ void init_gluon_ir(py::module &&m) {
*mask);
}
})
.def("create_cat",
     // Emits tt.cat with a caller-provided result type; shape and layout
     // validation is done by the Gluon front end before this is called.
     [](GluonOpBuilder &self, Value &lhs, Value &rhs,
        Type retType) -> Value {
       return self.create<triton::CatOp>(retType, lhs, rhs);
     })
.def("create_fp4_to_fp",
[](GluonOpBuilder &self, Value src, Type elemType,
int axis) -> Value {
Expand Down
12 changes: 12 additions & 0 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1551,6 +1551,18 @@ void init_triton_ir(py::module &&m) {
[](TritonOpBuilder &self, Value &arg, int axis) -> Value {
return self.create<ExpandDimsOp>(arg, axis);
})
.def("create_cat",
     // Builds a tt.cat of two rank-1 ranked tensors; the result keeps the
     // lhs element type and has shape [lhs_dim0 + rhs_dim0].
     [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
       auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
       auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
       // dyn_cast yields a null type for non-ranked-tensor operands; report
       // that as a user error instead of dereferencing a null type.
       if (!lhsType || !rhsType || lhsType.getShape().size() != 1 ||
           rhsType.getShape().size() != 1)
         throw std::invalid_argument(
             "shape not supported by cat. Expecting rank-1 inputs");
       std::vector<int64_t> shape{lhsType.getShape()[0] +
                                  rhsType.getShape()[0]};
       return self.create<CatOp>(lhsType.clone(shape), lhs, rhs);
     })
.def("create_join",
[](TritonOpBuilder &self, Value &a, Value &b) -> Value {
return self.create<JoinOp>(a, b);
Expand Down
7 changes: 7 additions & 0 deletions python/triton/experimental/gluon/language/_semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,13 @@ def histogram(self, input: TensorTy, num_bins: int, mask: TensorTy, layout) -> T
handle = self.builder.create_histogram(input.handle, num_bins, mask, layout_attr)
return self.wrap_tensor(handle, ttgl.int32, [num_bins], layout)

def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool, layout) -> TensorTy:
    """Concatenate two rank-1 tensors into one tensor with the given layout.

    :param lhs: first input tensor (rank-1).
    :param rhs: second input tensor (rank-1).
    :param can_reorder: must be True; the current implementation may reorder
        elements while concatenating.
    :param layout: destination distributed layout of the result.
    :returns: tensor of shape [lhs.shape[0] + rhs.shape[0]].
    """
    _check(layout is not None, lambda: "cat requires a destination layout")
    _check(can_reorder, lambda: "current implementation of `cat` always may reorder elements")
    # Validate both operands up front for a clear front-end error; the tt.cat
    # op itself also requires matching operand types.
    _check(len(lhs.shape) == 1, lambda: "cat requires a rank-1 input")
    _check(len(rhs.shape) == 1, lambda: "cat requires a rank-1 input")
    _check(lhs.type.scalar == rhs.type.scalar,
           lambda: "cat requires inputs with matching element types")
    ret_type = ttgl.distributed_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]], layout)
    return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle, ret_type.to_ir(self.builder)), ret_type)

def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy:
_check(isinstance(src.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {src.type!r}")
_check(isinstance(index.type, ttgl.distributed_type),
Expand Down
32 changes: 7 additions & 25 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,38 +1795,20 @@ def permute(input, *dims, _semantic=None):


@builtin
def cat(input, other, can_reorder=False, dim=0, _semantic=None):
def cat(input, other, can_reorder=False, _semantic=None):
"""
Concatenate the given blocks

:param input: The first input tensor.
:type input: Tensor
:param other: The second input tensor.
:type other: Tensor
:param can_reorder: Deprecated option. Elements are never reordered.
:type can_reorder: bool
:param dim: The dimension to concatenate along.
:type dim: int
"""
rank = len(input.shape)
assert rank == len(other.shape), f"tensors must have the same rank, got {rank} and {len(other.shape)}"
assert all(input.shape[i] == other.shape[i] for i in builtins.range(rank) if i !=
dim), f"tensor dims must match except in the concat dimension {dim}, got {input.shape} and {other.shape}"

order = list(builtins.range(rank))
order[dim], order[-1] = order[-1], order[dim]
inv_order = [order.index(i) for i in builtins.range(rank)]

a = permute(input, order, _semantic=_semantic)
b = permute(other, order, _semantic=_semantic)

leading = a.shape[:-1]
a = reshape(a, (math.prod(leading), a.shape[-1]), _semantic=_semantic)
b = reshape(b, (math.prod(leading), b.shape[-1]), _semantic=_semantic)

c = join(a, b, _semantic=_semantic)
c = reshape(c, leading + [a.shape[-1] + b.shape[-1]], _semantic=_semantic)
return permute(c, inv_order, _semantic=_semantic)
:param can_reorder: Compiler hint. If true, the compiler is
allowed to reorder elements while concatenating inputs. Only use if the
order does not matter (e.g., result is only used in reduction ops).
Current implementation of `cat` supports only can_reorder=True.
"""
return _semantic.cat(input, other, can_reorder)


@builtin
Expand Down
6 changes: 6 additions & 0 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,12 @@ def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
ret_ty = tl.block_type(input.type.scalar, dst_shape)
return self.tensor(self.builder.create_expand_dims(input.handle, axis), ret_ty)

def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool) -> TensorTy:
    """Concatenate two rank-1 tensors along their only dimension.

    :param lhs: first input tensor (rank-1).
    :param rhs: second input tensor (rank-1).
    :param can_reorder: must be True; the current implementation may reorder
        elements while concatenating.
    :returns: tensor of shape [lhs.shape[0] + rhs.shape[0]].
    """
    assert can_reorder, "current implementation of `cat` always may reorder elements"
    # Check both operands, not just lhs, so a bad rhs fails here with a clear
    # message rather than in the IR builder.
    assert len(lhs.shape) == 1, "cat requires rank-1 inputs"
    assert len(rhs.shape) == 1, "cat requires rank-1 inputs"
    ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
    return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle), ret_type)
Comment thread
ThomasRaoux marked this conversation as resolved.

def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
a, b = self.broadcast_impl_value(a, b)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,30 @@ def tl_trans(value, *dims, _semantic=None):
return value.trans(*dims, _semantic=_semantic)


@ttgl._core.builtin
def cat(input, other, can_reorder=False, layout=None, _semantic=None):
    """
    Concatenate two tensors into a single tensor.

    Args:
        input (tensor): The first input tensor.
        other (tensor): The second input tensor.
        can_reorder (bool): Compiler hint. If true, the compiler is allowed to reorder elements while concatenating inputs. Only use if the order does not matter (e.g., result is only used in reduction ops). Current implementation of `cat` supports only can_reorder=True.
        layout (DistributedLayout): The destination layout of the output tensor.

    Returns:
        tensor: The concatenated tensor.
    """
    unwrap = ttgl._core._unwrap_if_constexpr
    return _semantic.cat(input, other, unwrap(can_reorder), unwrap(layout))


@gluon.jit
def tl_cat(lhs, rhs, can_reorder=False):
    # tl.cat-style wrapper for Gluon: concatenates two rank-1 tensors,
    # materializing the result in a default blocked layout sized for the
    # combined shape.
    return cat(lhs, rhs, can_reorder, layout=default_blocked_layout([lhs.shape[0] + rhs.shape[0]], ttgl.num_warps()))


@gluon.jit
def reset_to_default_layout(value):
ty: ttgl.constexpr = value.type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [2, 2, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
#blocked3 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// COMMON-LABEL: join_cat_transitive_nonneg
tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
Expand All @@ -306,9 +304,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
%3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1>
%4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1>
%5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked>
%tmp = tt.join %5, %5 : tensor<4x2xi32, #blocked> -> tensor<4x2x2xi32, #blocked2>
%tmp1 = tt.reshape %tmp : tensor<4x2x2xi32, #blocked2> -> tensor<8x2xi32, #blocked3>
%6 = ttg.convert_layout %tmp1 : tensor<8x2xi32, #blocked3> -> tensor<8x2xi32, #blocked>
%6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked>
%7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked>
%zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked>
%ones = arith.constant dense<1> : tensor<8x1xi32, #blocked>
Expand Down
6 changes: 1 addition & 5 deletions test/TritonGPU/amd/amd-convert-buffer-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [2, 2, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
#blocked3 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// COMMON-LABEL: join_cat_transitive_nonneg
tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
Expand All @@ -345,9 +343,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
%3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1>
%4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1>
%5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked>
%tmp = tt.join %5, %5 : tensor<4x2xi32, #blocked> -> tensor<4x2x2xi32, #blocked2>
%tmp1 = tt.reshape %tmp : tensor<4x2x2xi32, #blocked2> -> tensor<8x2xi32, #blocked3>
%6 = ttg.convert_layout %tmp1 : tensor<8x2xi32, #blocked3> -> tensor<8x2xi32, #blocked>
%6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked>
%7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked>
%zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked>
%ones = arith.constant dense<1> : tensor<8x1xi32, #blocked>
Expand Down
3 changes: 1 addition & 2 deletions test/TritonGPU/amd/amd-range-analysis.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1245,8 +1245,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
%5 = tt.join %3, %4 : tensor<4xi32> -> tensor<4x2xi32>
// expected-remark@+2 {{unsigned : [0, 7] signed : [0, 7]}}
// expected-remark@+1 {{non-neg}}
%tmp = tt.join %5, %5 : tensor<4x2xi32> -> tensor<4x2x2xi32>
%6 = tt.reshape %tmp : tensor<4x2x2xi32> -> tensor<8x2xi32>
%6 = tt.cat %5, %5 : tensor<4x2xi32> -> tensor<8x2xi32>
// expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}}
// expected-remark@+1 {{non-neg}}
%7 = arith.addi %2, %6 : tensor<8x2xi32>
Expand Down
Loading
Loading