Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ builtin.module attributes { transform.with_named_sequence } {
}

// CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch
// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00>
// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x32xf32> -> vector<2x1x4x1x4x1xf32>
// CHECK: %[[A_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x8xf16> -> vector<2x1x1x1x1x4xf16>
// CHECK: %[[C_SLICE0:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x4x1xf32> from vector<2x1x4x1x4x1xf32
Expand All @@ -241,15 +240,18 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK: %[[C0_CAST:.+]] = vector.shape_cast %[[C_SLICE0]] : vector<4x1x4x1xf32> to vector<16xf32>
// CHECK: %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %{{.+}} + %[[C0_CAST]]
// CHECK: %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<16xf32> to vector<4x1x4x1xf32>
// CHECK: %[[C0_INS:.+]] = vector.insert %[[R0_CAST]], %[[INIT]] [0, 0] : vector<4x1x4x1xf32> into vector<2x1x4x1x4x1xf32>
// CHECK: %[[C_SLICE1:.+]] = vector.extract %[[C_SIMT]][1, 0] : vector<4x1x4x1xf32> from vector<2x1x4x1x4x1xf32>
// CHECK: %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
// CHECK: %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
// CHECK: %[[C1_CAST:.+]] = vector.shape_cast %[[C_SLICE1]] : vector<4x1x4x1xf32> to vector<16xf32>
// CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %{{.+}} + %[[C1_CAST]]
// CHECK: %[[R1_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x4x1xf32>
// CHECK: %[[C1_INS:.+]] = vector.insert %[[R1_CAST]], %[[C0_INS]] [1, 0] : vector<4x1x4x1xf32> into vector<2x1x4x1x4x1xf32>
// CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[C1_INS]] : vector<2x1x4x1x4x1xf32> -> vector<64x32xf32>
// CHECK: %[[R0:.+]]:16 = vector.to_elements %[[R0_CAST]] : vector<4x1x4x1xf32>
// CHECK: %[[R1:.+]]:16 = vector.to_elements %[[R1_CAST]] : vector<4x1x4x1xf32>
// CHECK: %[[INS:.+]] = vector.from_elements
// CHECK-SAME: %[[R0]]#0, %[[R0]]#1, %[[R0]]#2, %[[R0]]#3, %[[R0]]#4, %[[R0]]#5, %[[R0]]#6, %[[R0]]#7, %[[R0]]#8, %[[R0]]#9, %[[R0]]#10, %[[R0]]#11, %[[R0]]#12, %[[R0]]#13, %[[R0]]#14, %[[R0]]#15
// CHECK-SAME: %[[R1]]#0, %[[R1]]#1, %[[R1]]#2, %[[R1]]#3, %[[R1]]#4, %[[R1]]#5, %[[R1]]#6, %[[R1]]#7, %[[R1]]#8, %[[R1]]#9, %[[R1]]#10, %[[R1]]#11, %[[R1]]#12, %[[R1]]#13, %[[R1]]#14, %[[R1]]#15
// CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[INS]] : vector<2x1x4x1x4x1xf32> -> vector<64x32xf32>
// CHECK: return %[[R]]

// -----
Expand Down Expand Up @@ -403,28 +405,23 @@ builtin.module attributes { transform.with_named_sequence } {
}
}

// CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch_order
// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<2x3x4x1x4x1xf32>
// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x96xf32> -> vector<2x3x4x1x4x1xf32>
// CHECK: vector.extract %[[C_SIMT]][0, 0]
// CHECK: amdgpu.mfma
// CHECK: %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0]
// CHECK: vector.extract %[[C_SIMT]][0, 1]
// CHECK: amdgpu.mfma
// CHECK: %[[INS1:.+]] = vector.insert %{{.+}}, %[[INS0]] [0, 1]
// CHECK: vector.extract %[[C_SIMT]][0, 2]
// CHECK: amdgpu.mfma
// CHECK: %[[INS2:.+]] = vector.insert %{{.+}}, %[[INS1]] [0, 2]
// CHECK: vector.extract %[[C_SIMT]][1, 0]
// CHECK: amdgpu.mfma
// CHECK: %[[INS3:.+]] = vector.insert %{{.+}}, %[[INS2]] [1, 0]
// CHECK: vector.extract %[[C_SIMT]][1, 1]
// CHECK: amdgpu.mfma
// CHECK: %[[INS4:.+]] = vector.insert %{{.+}}, %[[INS3]] [1, 1]
// CHECK: vector.extract %[[C_SIMT]][1, 2]
// CHECK: amdgpu.mfma
// CHECK: %[[INS5:.+]] = vector.insert %{{.+}}, %[[INS4]] [1, 2]
// CHECK: iree_vector_ext.to_simd %[[INS5]]
// CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch_order
// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x96xf32> -> vector<2x3x4x1x4x1xf32>
// CHECK: vector.extract %[[C_SIMT]][0, 0]
// CHECK: amdgpu.mfma
// CHECK: vector.extract %[[C_SIMT]][0, 1]
// CHECK: amdgpu.mfma
// CHECK: vector.extract %[[C_SIMT]][0, 2]
// CHECK: amdgpu.mfma
// CHECK: vector.extract %[[C_SIMT]][1, 0]
// CHECK: amdgpu.mfma
// CHECK: vector.extract %[[C_SIMT]][1, 1]
// CHECK: amdgpu.mfma
// CHECK: vector.extract %[[C_SIMT]][1, 2]
// CHECK: amdgpu.mfma
// CHECK-COUNT-6: vector.to_elements {{.*}} : vector<4x1x4x1xf32>
// CHECK: %[[INS:.+]] = vector.from_elements
// CHECK: iree_vector_ext.to_simd %[[INS]]

// -----

Expand Down Expand Up @@ -495,15 +492,17 @@ builtin.module attributes { transform.with_named_sequence } {
}

// CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mmt
// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x2x4x1x4x1xf32>
// CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x8xf16> -> vector<2x1x1x1x1x4xf16>
// CHECK: vector.extract %[[B_SIMT]][0, 0]
// CHECK: amdgpu.mfma
// CHECK: %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0]
// CHECK: vector.extract %[[B_SIMT]][1, 0]
// CHECK: amdgpu.mfma
// CHECK: %[[INS1:.+]] = vector.insert %17, %[[INS0]] [0, 1]
// CHECK: iree_vector_ext.to_simd %[[INS1]] : vector<1x2x4x1x4x1xf32> -> vector<32x64xf32>
// CHECK: %[[R0:.+]]:16 = vector.to_elements %{{.+}} : vector<4x1x4x1xf32>
// CHECK: %[[R1:.+]]:16 = vector.to_elements %{{.+}} : vector<4x1x4x1xf32>
// CHECK: %[[INS:.+]] = vector.from_elements
// CHECK-SAME: %[[R0]]#0, %[[R0]]#1, %[[R0]]#2, %[[R0]]#3, %[[R0]]#4, %[[R0]]#5, %[[R0]]#6, %[[R0]]#7, %[[R0]]#8, %[[R0]]#9, %[[R0]]#10, %[[R0]]#11, %[[R0]]#12, %[[R0]]#13, %[[R0]]#14, %[[R0]]#15
// CHECK-SAME: %[[R1]]#0, %[[R1]]#1, %[[R1]]#2, %[[R1]]#3, %[[R1]]#4, %[[R1]]#5, %[[R1]]#6, %[[R1]]#7, %[[R1]]#8, %[[R1]]#9, %[[R1]]#10, %[[R1]]#11, %[[R1]]#12, %[[R1]]#13, %[[R1]]#14, %[[R1]]#15
// CHECK: iree_vector_ext.to_simd %[[INS]] : vector<1x2x4x1x4x1xf32> -> vector<32x64xf32>

// -----

Expand Down Expand Up @@ -838,6 +837,7 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK: %[[B_CAST_1:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x1x1x8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
// CHECK: %[[MFMA_1:.*]] = amdgpu.mfma %[[A_CAST_1]] * %[[B_CAST_1]] + %[[MFMA_0]]
// CHECK-SAME: {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none
// CHECK: %[[MFMA_1_CAST:.*]] = vector.shape_cast %[[MFMA_1]] : vector<4xf32> to vector<1x1x4x1xf32>
// CHECK: %[[B_CAST_2:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x1x1x8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
// CHECK: %[[C_CAST_1:.+]] = vector.shape_cast %{{.+}} : vector<1x1x4x1xf32> to vector<4xf32>
// CHECK: %[[MFMA_2:.*]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST_2]] + %[[C_CAST_1]]
Expand All @@ -846,6 +846,10 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK: %[[MFMA_3:.*]] = amdgpu.mfma %[[A_CAST_1]] * %[[B_CAST_3]] + %[[MFMA_2]]
// CHECK-SAME: {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none
// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA_3]] : vector<4xf32> to vector<1x1x4x1xf32>
// CHECK: %[[B_OUT:.*]] = vector.insert %[[R_CAST]]
// CHECK: %[[R0:.+]]:4 = vector.to_elements %[[MFMA_1_CAST]] : vector<1x1x4x1xf32>
// CHECK: %[[R1:.+]]:4 = vector.to_elements %[[R_CAST]] : vector<1x1x4x1xf32>
// CHECK: %[[B_OUT:.+]] = vector.from_elements
// CHECK-SAME: %[[R0]]#0, %[[R0]]#1, %[[R0]]#2, %[[R0]]#3
// CHECK-SAME: %[[R1]]#0, %[[R1]]#1, %[[R1]]#2, %[[R1]]#3
// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x2x1x1x4x1xf32> -> vector<32x32xf32>
// CHECK: return %[[R_SIMD]]
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,12 @@ builtin.module attributes { transform.with_named_sequence } {
}

// CHECK-LABEL: func @inter_subgroup_reduction
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<0.000000e+00> : vector<2xf32>
// Local reduction
// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<2x1x1x1x1x4xf32> to vector<2x1x1xf32>
// Thread reduction
// CHECK: %[[THREAD_RED0:.+]] = gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
// CHECK: %[[THREAD_RED1:.+]] = vector.insert %[[THREAD_RED0]], %[[CST1]] [0] : f32 into vector<2xf32>
// CHECK: %[[THREAD_RED2:.+]] = gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
// CHECK: %[[THREAD_RED3:.+]] = vector.insert %[[THREAD_RED2]], %[[THREAD_RED1]] [1] : f32 into vector<2xf32>
// CHECK: %[[THREAD_RED3:.+]] = vector.from_elements %[[THREAD_RED0]], %[[THREAD_RED2]] : vector<2xf32>
// CHECK: %[[THREAD_RED4:.+]] = vector.shape_cast %[[THREAD_RED3]] : vector<2xf32> to vector<2x1x1xf32>
// Subgroup reduction
// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<32x2xf32, #gpu.address_space<workgroup>>
Expand All @@ -177,11 +175,10 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
// CHECK-DAG: %[[DISTR0:.+]] = vector.extract %[[SG_READ0]][0, 0] : f32 from vector<1x1xf32>
// CHECK-DAG: %[[RED0:.+]] = gpu.subgroup_reduce maximumf %[[DISTR0]] cluster(size = 2, stride = 16) : (f32) -> f32
// CHECK-DAG: %[[INS0:.+]] = vector.insert %[[RED0]], %[[CST1]] [0] : f32 into vector<2xf32>
// CHECK-DAG: %[[DISTR1:.+]] = vector.extract %[[SG_READ1]][0, 0] : f32 from vector<1x1xf32>
// CHECK-DAG: %[[RED1:.+]] = gpu.subgroup_reduce maximumf %[[DISTR1]] cluster(size = 2, stride = 16) : (f32) -> f32
// CHECK-DAG: %[[INS1:.+]] = vector.insert %[[RED1]], %[[INS0]] [1] : f32 into vector<2xf32>
// CHECK-DAG: %[[CAST:.+]] = vector.shape_cast %[[INS1]] : vector<2xf32> to vector<2x1x1xf32>
// CHECK-DAG: %[[INS:.+]] = vector.from_elements %[[RED0]], %[[RED1]] : vector<2xf32>
// CHECK-DAG: %[[CAST:.+]] = vector.shape_cast %[[INS]] : vector<2xf32> to vector<2x1x1xf32>
// CHECK-DAG: arith.maximumf %[[CAST]], %[[ACC]] : vector<2x1x1xf32>

// -----
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,7 @@ void ConvertToLLVMPass::runOnOperation() {
patterns, /*force32BitVectorIndices=*/false);
vector::populateVectorMaskOpLoweringPatterns(patterns);
vector::populateVectorShapeCastLoweringPatterns(patterns);
vector::populateVectorFromElementsLoweringPatterns(patterns);
// TODO: doubtful that the "default" does what one want here, it is likely
// better to use shuffle.
vector::populateVectorTransposeLoweringPatterns(
Expand Down Expand Up @@ -1079,6 +1080,7 @@ void ConvertToLLVMPass::runOnOperation() {
vector::populateVectorStepLoweringPatterns(patterns);
populateVectorToLLVMConversionPatterns(typeConverter, patterns,
reassociateFpReductions);
vector::populateVectorFromElementsLoweringPatterns(patterns);
ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
vector::populateVectorTransferLoweringPatterns(patterns,
/*maxTransferRank=*/1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ module {

// CHECK-LABEL: func.func @aligned_generic_pack
// CHECK: %[[IN_0:.+]] = vector.broadcast %{{.+}} : vector<16xf32> to vector<16x16xf32>
// CHECK-COUNT-15: %{{.+}} = vector.insert {{.+}} : vector<16xf32> into vector<16x16xf32>
// CHECK: %[[IN_1:.+]] = vector.insert {{.+}} : vector<16xf32> into vector<16x16xf32>
// CHECK-COUNT-16: %{{.+}} = vector.to_elements {{.+}} : vector<16xf32>
// CHECK: %[[IN_1:.+]] = vector.from_elements {{.+}} : vector<16x16xf32>
// CHECK: %[[T0:.+]] = arith.addf %[[IN_0]], %[[IN_1]] : vector<16x16xf32>
// CHECK: %[[T1:.+]] = arith.minimumf %[[T0]], %{{.+}} : vector<16x16xf32>
// CHECK: %[[T2:.+]] = arith.maximumf %[[T1]], %{{.+}} : vector<16x16xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,5 +274,5 @@ func.func @split_reduction_double_reduction_unsupported() attributes {hal.execut
}

// CHECK-LABEL: func.func @split_reduction_double_reduction_unsupported()
// CHECK: vector.insert %{{.+}}, %{{.+}} : i32 into vector<4xi32>
// CHECK: vector.from_elements %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : vector<4xi32>
// CHECK-NOT: vector.insert %{{.+}}, %{{.+}} : i32 into vector<1xi32>
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ struct ConvertToNVVMPass final
patterns, options.vectorContractLowering);
vector::populateVectorGatherLoweringPatterns(patterns);
vector::populateVectorMaskOpLoweringPatterns(patterns);
vector::populateVectorFromElementsLoweringPatterns(patterns);
// We currently always use 64 bit indices, thus ensure the bit width of
// the mask compare is consistent.
vector::populateVectorMaskMaterializationPatterns(
Expand Down
22 changes: 22 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,26 @@ static LogicalResult validateDataTypes(Operation *op,
return success();
}

/// TODO(hanchung): Delete the pattern once it is upstreamed:
/// https://github.com/llvm/llvm-project/pull/156992
struct LowerToElementsPattern : public OpRewritePattern<vector::ToElementsOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ToElementsOp op,
PatternRewriter &rewriter) const override {
VectorType vecType = op.getSource().getType();
if (vecType.getRank() == 1 || vecType.getNumScalableDims() > 0) {
return failure();
}
auto vec1DType =
VectorType::get({vecType.getNumElements()}, vecType.getElementType());
Value shapeCast = rewriter.create<vector::ShapeCastOp>(
op.getLoc(), vec1DType, op.getSource());
Comment on lines +187 to +188
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if flattening was the right call, this shouldve been an unrolling pattern. Everything else does unrolling.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's made explicit this is mostly a stopgap patch

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair enough

rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
shapeCast);
return success();
}
};

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding ROCDL equivalent.
///
Expand Down Expand Up @@ -256,6 +276,7 @@ struct ConvertToROCDLPass final
vector::populateVectorInterleaveToShufflePatterns(patterns);
vector::populateVectorContractLoweringPatterns(
patterns, options.vectorContractLowering);
vector::populateVectorFromElementsLoweringPatterns(patterns);
vector::populateVectorGatherLoweringPatterns(patterns);
vector::populateVectorMaskOpLoweringPatterns(patterns);
// We currently always use 64 bit indices, thus ensure the bit width of
Expand All @@ -269,6 +290,7 @@ struct ConvertToROCDLPass final
patterns, options.vectorTransposeLowering);
vector::populateVectorTransferLoweringPatterns(patterns);
arith::populateExpandBFloat16Patterns(patterns);
patterns.insert<LowerToElementsPattern>(&getContext());
if (failed(applyPatternsGreedily(m, std::move(patterns)))) {
return signalPassFailure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// CHECK-LABEL: func @extract_strided_slice_8_elements
func.func @extract_strided_slice_8_elements(%input: vector<8xf16>) -> vector<4xf16> {
// CHECK-COUNT-4: vector.extract
// CHECK-COUNT-4: vector.insert
// CHECK: vector.from_elements
%0 = vector.extract_strided_slice %input {offsets = [1], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
return %0: vector<4xf16>
}
Expand All @@ -22,9 +22,8 @@ func.func @extract_strided_slice_4_elements(%input: vector<4xf16>) -> vector<2xf
// CHECK-LABEL: func @bitcast_16_elements
func.func @bitcast_16_elements(%input: vector<16xi8>) -> vector<4xi32> {
// CHECK-DAG: %[[CST_I32:.*]] = arith.constant dense<0> : vector<4xi32>
// CHECK-DAG: arith.constant dense<0> : vector<4xi8>
// CHECK-COUNT-4: vector.extract
// CHECK-COUNT-4: vector.insert
// CHECK: vector.from_elements
// CHECK: vector.bitcast %{{.*}} : vector<4xi8> to vector<1xi32>
// CHECK: vector.insert_strided_slice {{.*}}, %[[CST_I32]]
// CHECK-COUNT-3: vector.bitcast
Expand All @@ -41,28 +40,22 @@ func.func @bitcast_extract_extend_0(%input: vector<1xi32>) -> vector<4xi32> {
return %extend : vector<4xi32>
}


// CHECK-LABEL: func @bitcast_extract_extend_0
// CHECK-SAME: (%[[INPUT:.+]]: vector<1xi32>)
// CHECK-DAG: %[[ZERO:.+]] = arith.constant dense<0> : vector<4xi32>
// CHECK-DAG: %[[MASK:.+]] = arith.constant 15 : i32
// CHECK-DAG: %[[OFF1:.+]] = arith.constant 4 : i32
// CHECK-DAG: %[[OFF2:.+]] = arith.constant 8 : i32
// CHECK-DAG: %[[OFF3:.+]] = arith.constant 12 : i32
// CHECK: %[[BASE:.+]] = vector.extract %[[INPUT]][0] : i32 from vector<1xi32>
// CHECK: %[[AND0:.+]] = arith.andi %[[BASE]], %[[MASK]] : i32
// CHECK: %[[INS0:.+]] = vector.insert %[[AND0]], %[[ZERO]] [0]
// CHECK: %[[SHR1:.+]] = arith.shrui %[[BASE]], %[[OFF1]] : i32
// CHECK: %[[AND1:.+]] = arith.andi %[[SHR1]], %[[MASK]] : i32
// CHECK: %[[INS1:.+]] = vector.insert %[[AND1]], %[[INS0]] [1]
// CHECK: %[[SHR2:.+]] = arith.shrui %[[BASE]], %[[OFF2]] : i32
// CHECK: %[[AND2:.+]] = arith.andi %[[SHR2]], %[[MASK]] : i32
// CHECK: %[[INS2:.+]] = vector.insert %[[AND2]], %[[INS1]] [2]
// CHECK: %[[SHR3:.+]] = arith.shrui %[[BASE]], %[[OFF3]] : i32
// CHECK: %[[AND3:.+]] = arith.andi %[[SHR3]], %[[MASK]] : i32
// CHECK: %[[INS3:.+]] = vector.insert %[[AND3]], %[[INS2]] [3]
// CHECK: return %[[INS3]] : vector<4xi32>

// CHECK: %[[RES:.+]] = vector.from_elements %[[AND0]], %[[AND1]], %[[AND2]], %[[AND3]] : vector<4xi32>
// CHECK: return %[[RES]] : vector<4xi32>

// -----

Expand All @@ -75,7 +68,6 @@ func.func @bitcast_extract_extend_1(%input: vector<4xi32>) -> vector<4xi32> {

// CHECK-LABEL: func.func @bitcast_extract_extend_1
// CHECK-SAME: (%[[INPUT:.+]]: vector<4xi32>)
// CHECK-DAG: %[[ZERO:.+]] = arith.constant dense<0> : vector<4xi32>
// CHECK-DAG: %[[MASK:.+]] = arith.constant 15 : i32
// CHECK-DAG: %[[OFF0:.+]] = arith.constant 16 : i32
// CHECK-DAG: %[[OFF1:.+]] = arith.constant 20 : i32
Expand All @@ -84,14 +76,11 @@ func.func @bitcast_extract_extend_1(%input: vector<4xi32>) -> vector<4xi32> {
// CHECK: %[[BASE:.+]] = vector.extract %[[INPUT]][2] : i32 from vector<4xi32>
// CHECK: %[[SHR0:.+]] = arith.shrui %[[BASE]], %[[OFF0]] : i32
// CHECK: %[[AND0:.+]] = arith.andi %[[SHR0]], %[[MASK]] : i32
// CHECK: %[[INS0:.+]] = vector.insert %[[AND0]], %[[ZERO]] [0]
// CHECK: %[[SHR1:.+]] = arith.shrui %[[BASE]], %[[OFF1]] : i32
// CHECK: %[[AND1:.+]] = arith.andi %[[SHR1]], %[[MASK]] : i32
// CHECK: %[[INS1:.+]] = vector.insert %[[AND1]], %[[INS0]] [1]
// CHECK: %[[SHR2:.+]] = arith.shrui %[[BASE]], %[[OFF2]] : i32
// CHECK: %[[AND2:.+]] = arith.andi %[[SHR2]], %[[MASK]] : i32
// CHECK: %[[INS2:.+]] = vector.insert %[[AND2]], %[[INS1]] [2]
// CHECK: %[[SHR3:.+]] = arith.shrui %[[BASE]], %[[OFF3]] : i32
// CHECK: %[[AND3:.+]] = arith.andi %[[SHR3]], %[[MASK]] : i32
// CHECK: %[[INS3:.+]] = vector.insert %[[AND3]], %[[INS2]] [3]
// CHECK: return %[[INS3]] : vector<4xi32>
// CHECK: %[[RES:.+]] = vector.from_elements %[[AND0]], %[[AND1]], %[[AND2]], %[[AND3]] : vector<4xi32>
// CHECK: return %[[RES]] : vector<4xi32>
Loading
Loading