[CIR] Upstream extract op for VectorType#138413
Conversation
|
@llvm/pr-subscribers-clang @llvm/pr-subscribers-clangir Author: Amr Hesham (AmrDeveloper) Changes: This change adds an extract op for VectorType. Issue #136487. Full diff: https://github.com/llvm/llvm-project/pull/138413.diff 7 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 422c89c4f9391..b2121dee8d8b3 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1976,4 +1976,28 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// VecExtractOp
+//===----------------------------------------------------------------------===//
+
+def VecExtractOp : CIR_Op<"vec.extract", [Pure,
+ TypesMatchWith<"type of 'result' matches element type of 'vec'", "vec",
+ "result", "cast<VectorType>($_self).getElementType()">]> {
+
+ let summary = "Extract one element from a vector object";
+ let description = [{
+ The `cir.vec.extract` operation extracts the element at the given index
+ from a vector object.
+ }];
+
+ let arguments = (ins CIR_VectorType:$vec, CIR_AnyFundamentalIntType:$index);
+ let results = (outs CIR_AnyType:$result);
+
+ let assemblyFormat = [{
+ $vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec))
+ }];
+
+ let hasVerifier = 0;
+}
+
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 8ead6e793b4c8..a59b87cb9241b 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -161,8 +161,11 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
mlir::Value VisitArraySubscriptExpr(ArraySubscriptExpr *e) {
if (e->getBase()->getType()->isVectorType()) {
assert(!cir::MissingFeatures::scalableVectors());
- cgf.getCIRGenModule().errorNYI("VisitArraySubscriptExpr: VectorType");
- return {};
+
+ const mlir::Location loc = cgf.getLoc(e->getSourceRange());
+ const mlir::Value vecValue = Visit(e->getBase());
+ const mlir::Value indexValue = Visit(e->getIdx());
+ return cgf.builder.create<cir::VecExtractOp>(loc, vecValue, indexValue);
}
// Just load the lvalue formed by the subscript expression.
return emitLoadOfLValue(e);
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 6137adb1e9936..66f29f8f6cdd0 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1600,7 +1600,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMStackRestoreOpLowering,
CIRToLLVMTrapOpLowering,
CIRToLLVMUnaryOpLowering,
- CIRToLLVMVecCreateOpLowering
+ CIRToLLVMVecCreateOpLowering,
+ CIRToLLVMVecExtractOpLowering
// clang-format on
>(converter, patterns.getContext());
@@ -1709,6 +1710,14 @@ mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
return mlir::success();
}
+mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite(
+ cir::VecExtractOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>(
+ op, adaptor.getVec(), adaptor.getIndex());
+ return mlir::success();
+}
+
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index f248ea31e7844..026505ea31b4c 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -303,6 +303,16 @@ class CIRToLLVMVecCreateOpLowering
mlir::ConversionPatternRewriter &) const override;
};
+class CIRToLLVMVecExtractOpLowering
+ : public mlir::OpConversionPattern<cir::VecExtractOp> {
+public:
+ using mlir::OpConversionPattern<cir::VecExtractOp>::OpConversionPattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::VecExtractOp op, OpAdaptor,
+ mlir::ConversionPatternRewriter &) const override;
+};
+
} // namespace direct
} // namespace cir
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index a3880a944de1f..aeeaf655cad18 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -109,3 +109,36 @@ void foo2(vi4 p) {}
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
+
+void foo3() {
+ vi4 a = { 1, 2, 3, 4 };
+ int e = a[1];
+}
+
+// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[IDX:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP]][%[[IDX]] : !s32i] : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
+
+// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[INIT:.*]] = alloca i32, i64 1, align 4
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
+// LLVM: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[INIT:.*]] = alloca i32, align 4
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
+// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 76a85eab52380..9c85ed4a9e216 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -96,3 +96,36 @@ void foo2(vi4 p) {}
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
+
+void foo3() {
+ vi4 a = { 1, 2, 3, 4 };
+ int e = a[1];
+}
+
+// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[IDX:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP]][%[[IDX]] : !s32i] : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
+
+// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[INIT:.*]] = alloca i32, i64 1, align 4
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
+// LLVM: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[INIT:.*]] = alloca i32, align 4
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
+// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
diff --git a/clang/test/CIR/IR/vector.cir b/clang/test/CIR/IR/vector.cir
index d2612a7310ad0..aeb268e84c71c 100644
--- a/clang/test/CIR/IR/vector.cir
+++ b/clang/test/CIR/IR/vector.cir
@@ -65,4 +65,36 @@ cir.func @local_vector_create_test() {
// CHECK: cir.return
// CHECK: }
+cir.func @vector_extract_element_test() {
+ %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
+ %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
+ %2 = cir.const #cir.int<1> : !s32i
+ %3 = cir.const #cir.int<2> : !s32i
+ %4 = cir.const #cir.int<3> : !s32i
+ %5 = cir.const #cir.int<4> : !s32i
+ %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+ cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+ %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+ %8 = cir.const #cir.int<1> : !s32i
+ %9 = cir.vec.extract %7[%8 : !s32i] : !cir.vector<4 x !s32i>
+ cir.store %9, %1 : !s32i, !cir.ptr<!s32i>
+ cir.return
+}
+
+// CHECK: cir.func @vector_extract_element_test() {
+// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
+// CHECK: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
+// CHECK: %2 = cir.const #cir.int<1> : !s32i
+// CHECK: %3 = cir.const #cir.int<2> : !s32i
+// CHECK: %4 = cir.const #cir.int<3> : !s32i
+// CHECK: %5 = cir.const #cir.int<4> : !s32i
+// CHECK: %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CHECK: cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK: %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CHECK: %8 = cir.const #cir.int<1> : !s32i
+// CHECK: %9 = cir.vec.extract %7[%8 : !s32i] : !cir.vector<4 x !s32i>
+// CHECK: cir.store %9, %1 : !s32i, !cir.ptr<!s32i>
+// CHECK: cir.return
+// CHECK: }
+
}
|
| $vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec)) | ||
| }]; | ||
|
|
||
| let hasVerifier = 0; |
There was a problem hiding this comment.
nit: this is implicit
| let hasVerifier = 0; |
| let summary = "Extract one element from a vector object"; | ||
| let description = [{ | ||
| The `cir.vec.extract` operation extracts the element at the given index | ||
| from a vector object. |
There was a problem hiding this comment.
nit: can we add some example here?
72b4aac to
80bd84c
Compare
bcardosolopes
left a comment
There was a problem hiding this comment.
Overall looks good. While you are here, can you please implement a folder for this operation? It should kick in if both the index and the input vector are constants.
@bcardosolopes I implemented it as in this snippet. But I am wondering: is there a case where codegen will perform the extract op directly on a ConstVec, rather than on the result of a load or get_global? I see a similar implementation in the MLIR Vector dialect 🤔 I will try to come up with a test case for it. |
andykaylor
left a comment
There was a problem hiding this comment.
Looks good to me, with a small request.
| void foo3() { | ||
| vi4 a = { 1, 2, 3, 4 }; | ||
| int e = a[1]; | ||
| } |
There was a problem hiding this comment.
Can you add a test where the index of the element being extracted is a variable?
I think there is no need for dyn_casts here; those gets should already return |
As far as I understand, they will not always return ConstVecAttr and IntAttr: if the index or the vector comes from a parameter or another variable, it is not a constant in that case. And even if the vector is constant, I think we will use a load or get_global when we use it in the extract, rather than passing it directly like a const int 🤔. I am not sure whether there is a case that will trigger that fold. |
Sorry my bad, you are right :) |
|
@bcardosolopes, I think we can finish this PR, and I will continue searching for a test case for folding without blocking the upstreaming of other vector ops. Also, I think the same idea can be applied to arrays. What do you think? |
no problemo, thanks! |
bcardosolopes
left a comment
There was a problem hiding this comment.
LGTM; the folder is going to be implemented in a follow-up PR.
* main: (420 commits) [AArch64] Merge scaled and unscaled narrow zero stores (llvm#136705) [RISCV] One last migration to getInsertSubvector [nfc] [flang][OpenMP] Update `do concurrent` mapping pass to use `fir.do_concurrent` op (llvm#138489) [MLIR][LLVM] Fix llvm.mlir.global mismatching print and parser order (llvm#138986) [lld][NFC] Fix minor typo in docs (llvm#138898) [RISCV] Migrate getConstant indexed insert/extract subvector to new API (llvm#139111) GlobalISel: Translate minimumnum and maximumnum (llvm#139106) [MemProf] Simplify unittest save and restore of options (llvm#139117) [BOLT][AArch64] Patch functions targeted by optional relocs (llvm#138750) [Coverage] Support -fprofile-list for cold function coverage (llvm#136333) Remove unused forward decl (llvm#139108) [AMDGPU][NFC] Get rid of OPW constants. (llvm#139074) [CIR] Upstream extract op for VectorType (llvm#138413) [mlir][xegpu] Handle scalar uniform ops in SIMT distribution. (llvm#138593) [GlobalISel][AMDGPU] Fix handling of v2i128 type for AND, OR, XOR (llvm#138574) AMDGPU][True16][CodeGen] FP_Round f64 to f16 in true16 (llvm#128911) Reland [Clang] Deprecate `__is_trivially_relocatable` (llvm#139061) [HLSL][NFC] Stricter Overload Tests (clamp,max,min,pow) (llvm#138993) [MLIR] Fixing the memref linearization size computation for non-packed memref (llvm#138922) [TableGen][NFC] Use early exit to simplify large block in emitAction. (llvm#138220) ...
This change adds an extract op for VectorType.
Issue #136487