diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td index e899db84bb615..89b4d25b37ba6 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td @@ -341,6 +341,44 @@ def CIR_ConstVectorAttr : CIR_Attr<"ConstVector", "const_vector", [ let genVerifyDecl = 1; } +//===----------------------------------------------------------------------===// +// ConstRecordAttr +//===----------------------------------------------------------------------===// + +def CIR_ConstRecordAttr : CIR_Attr<"ConstRecord", "const_record", [ + TypedAttrInterface +]> { + let summary = "Represents a constant record"; + let description = [{ + Effectively supports "struct-like" constants. It's must be built from + an `mlir::ArrayAttr` instance where each element is a typed attribute + (`mlir::TypedAttribute`). + + Example: + ``` + cir.global external @rgb2 = #cir.const_record<{0 : i8, + 5 : i64, #cir.null : !cir.ptr + }> : !cir.record<"", i8, i64, !cir.ptr> + ``` + }]; + + let parameters = (ins AttributeSelfTypeParameter<"">:$type, + "mlir::ArrayAttr":$members); + + let builders = [ + AttrBuilderWithInferredContext<(ins "cir::RecordType":$type, + "mlir::ArrayAttr":$members), [{ + return $_get(type.getContext(), type, members); + }]> + ]; + + let assemblyFormat = [{ + `<` custom($members) `>` + }]; + + let genVerifyDecl = 1; +} + //===----------------------------------------------------------------------===// // ConstPtrAttr //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h index c1088c4cd0821..724b46d7c9719 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h +++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h @@ -60,6 +60,23 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy { trailingZerosNum); } + cir::ConstRecordAttr getAnonConstRecord(mlir::ArrayAttr arrayAttr, + bool packed = false, + bool padded = false, + mlir::Type ty = {}) { + llvm::SmallVector members; + for (auto &f : arrayAttr) { + auto ta = mlir::cast(f); + members.push_back(ta.getType()); + } + + if (!ty) + ty = getAnonRecordTy(members, packed, padded); + + auto sTy = mlir::cast(ty); + return cir::ConstRecordAttr::get(sTy, arrayAttr); + } + std::string getUniqueAnonRecordName() { return getUniqueRecordName("anon"); } std::string getUniqueRecordName(const std::string &baseName) { diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp index 2fbf69d5d01f4..07a0e186d0cd2 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp @@ -285,7 +285,7 @@ emitArrayConstant(CIRGenModule &cgm, mlir::Type desiredType, mlir::Type commonElementType, unsigned arrayBound, SmallVectorImpl &elements, mlir::TypedAttr filler) { - const CIRGenBuilderTy &builder = cgm.getBuilder(); + CIRGenBuilderTy &builder = cgm.getBuilder(); unsigned nonzeroLength = arrayBound; if (elements.size() < nonzeroLength && builder.isNullValue(filler)) @@ -306,6 +306,33 @@ emitArrayConstant(CIRGenModule &cgm, mlir::Type desiredType, if (trailingZeroes >= 8) { assert(elements.size() >= nonzeroLength && "missing initializer for non-zero element"); + + if (commonElementType && nonzeroLength >= 8) { + // If all the elements had the same type up to the trailing zeroes and + // there are eight or more nonzero elements, emit a struct of two arrays + // (the nonzero data and the zeroinitializer). + SmallVector eles; + eles.reserve(nonzeroLength); + for (const auto &element : elements) + eles.push_back(element); + auto initial = cir::ConstArrayAttr::get( + cir::ArrayType::get(commonElementType, nonzeroLength), + mlir::ArrayAttr::get(builder.getContext(), eles)); + elements.resize(2); + elements[0] = initial; + } else { + // Otherwise, emit a struct with individual elements for each nonzero + // initializer, followed by a zeroinitializer array filler. + elements.resize(nonzeroLength + 1); + } + + mlir::Type fillerType = + commonElementType + ? commonElementType + : mlir::cast(desiredType).getElementType(); + fillerType = cir::ArrayType::get(fillerType, trailingZeroes); + elements.back() = cir::ZeroAttr::get(fillerType); + commonElementType = nullptr; } else if (elements.size() != arrayBound) { elements.resize(arrayBound, filler); @@ -325,8 +352,13 @@ emitArrayConstant(CIRGenModule &cgm, mlir::Type desiredType, mlir::ArrayAttr::get(builder.getContext(), eles)); } - cgm.errorNYI("array with different type elements"); - return {}; + SmallVector eles; + eles.reserve(elements.size()); + for (auto const &element : elements) + eles.push_back(element); + + auto arrAttr = mlir::ArrayAttr::get(builder.getContext(), eles); + return builder.getAnonConstRecord(arrAttr, /*isPacked=*/true); } //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp index c039bdcd7f6a1..5f53a6335f37d 100644 --- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp @@ -15,6 +15,14 @@ #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" +//===-----------------------------------------------------------------===// +// RecordMembers +//===-----------------------------------------------------------------===// + +static void printRecordMembers(mlir::AsmPrinter &p, mlir::ArrayAttr members); +static mlir::ParseResult parseRecordMembers(mlir::AsmParser &parser, + mlir::ArrayAttr &members); + //===-----------------------------------------------------------------===// // IntLiteral //===-----------------------------------------------------------------===// @@ -68,6 +76,61 @@ void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { llvm_unreachable("unexpected CIR type kind"); } +static void printRecordMembers(mlir::AsmPrinter &printer, + mlir::ArrayAttr members) { + printer << '{'; + llvm::interleaveComma(members, printer); + printer << '}'; +} + +static ParseResult parseRecordMembers(mlir::AsmParser &parser, + mlir::ArrayAttr &members) { + llvm::SmallVector elts; + + auto delimiter = AsmParser::Delimiter::Braces; + auto result = parser.parseCommaSeparatedList(delimiter, [&]() { + mlir::TypedAttr attr; + if (parser.parseAttribute(attr).failed()) + return mlir::failure(); + elts.push_back(attr); + return mlir::success(); + }); + + if (result.failed()) + return mlir::failure(); + + members = mlir::ArrayAttr::get(parser.getContext(), elts); + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// ConstRecordAttr definitions +//===----------------------------------------------------------------------===// + +LogicalResult +ConstRecordAttr::verify(function_ref emitError, + mlir::Type type, ArrayAttr members) { + auto sTy = mlir::dyn_cast_if_present(type); + if (!sTy) + return emitError() << "expected !cir.record type"; + + if (sTy.getMembers().size() != members.size()) + return emitError() << "number of elements must match"; + + unsigned attrIdx = 0; + for (auto &member : sTy.getMembers()) { + auto m = mlir::cast(members[attrIdx]); + if (member != m.getType()) + return emitError() << "element at index " << attrIdx << " has type " + << m.getType() + << " but the expected type for this element is " + << member; + attrIdx++; + } + + return success(); +} + //===----------------------------------------------------------------------===// // OptInfoAttr definitions //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 220927601f74e..72feee8709dc4 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -341,8 +341,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, } if (mlir::isa( - attrType)) + cir::ConstComplexAttr, cir::ConstRecordAttr, + cir::GlobalViewAttr, cir::PoisonAttr>(attrType)) return success(); assert(isa(attrType) && "What else could we be looking at here?"); diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 9ab7178e9ab12..badd6de814bd7 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -201,8 +201,8 @@ class CIRAttrToValue { mlir::Value visit(mlir::Attribute attr) { return llvm::TypeSwitch(attr) .Case( + cir::ConstArrayAttr, cir::ConstRecordAttr, cir::ConstVectorAttr, + cir::ConstPtrAttr, cir::GlobalViewAttr, cir::ZeroAttr>( [&](auto attrT) { return visitCirAttr(attrT); }) .Default([&](auto attrT) { return mlir::Value(); }); } @@ -212,6 +212,7 @@ class CIRAttrToValue { mlir::Value visitCirAttr(cir::ConstComplexAttr complexAttr); mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr); mlir::Value visitCirAttr(cir::ConstArrayAttr attr); + mlir::Value visitCirAttr(cir::ConstRecordAttr attr); mlir::Value visitCirAttr(cir::ConstVectorAttr attr); mlir::Value visitCirAttr(cir::GlobalViewAttr attr); mlir::Value visitCirAttr(cir::ZeroAttr attr); @@ -386,6 +387,21 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) { return result; } +/// ConstRecord visitor. +mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstRecordAttr constRecord) { + const mlir::Type llvmTy = converter->convertType(constRecord.getType()); + const mlir::Location loc = parentOp->getLoc(); + mlir::Value result = rewriter.create(loc, llvmTy); + + // Iteratively lower each constant element of the record. + for (auto [idx, elt] : llvm::enumerate(constRecord.getMembers())) { + mlir::Value init = visit(elt); + result = rewriter.create(loc, result, init, idx); + } + + return result; +} + /// ConstVectorAttr visitor. mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) { const mlir::Type llvmTy = converter->convertType(attr.getType()); @@ -1286,6 +1302,11 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( rewriter.eraseOp(op); return mlir::success(); } + } else if (const auto recordAttr = + mlir::dyn_cast(op.getValue())) { + auto initVal = lowerCirAttrAsValue(op, recordAttr, rewriter, typeConverter); + rewriter.replaceOp(op, initVal); + return mlir::success(); } else if (const auto vecTy = mlir::dyn_cast(op.getType())) { rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter, getTypeConverter())); @@ -1527,9 +1548,9 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal( cir::GlobalOp op, mlir::Attribute init, mlir::ConversionPatternRewriter &rewriter) const { // TODO: Generalize this handling when more types are needed here. - assert( - (isa(init))); + assert((isa(init))); // TODO(cir): once LLVM's dialect has proper equivalent attributes this // should be updated. For now, we use a custom op to initialize globals @@ -1582,8 +1603,9 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( return mlir::failure(); } } else if (mlir::isa(init.value())) { + cir::ConstRecordAttr, cir::ConstPtrAttr, + cir::ConstComplexAttr, cir::GlobalViewAttr, + cir::ZeroAttr>(init.value())) { // TODO(cir): once LLVM's dialect has proper equivalent attributes this // should be updated. For now, we use a custom op to initialize globals // to the appropriate value. diff --git a/clang/test/CIR/CodeGen/array.cpp b/clang/test/CIR/CodeGen/array.cpp index 60028af4b3161..a643de2d26189 100644 --- a/clang/test/CIR/CodeGen/array.cpp +++ b/clang/test/CIR/CodeGen/array.cpp @@ -45,9 +45,9 @@ int dd[3][2] = {{1, 2}, {3, 4}, {5, 6}}; // OGCG: [i32 3, i32 4], [2 x i32] [i32 5, i32 6]] int e[10] = {1, 2}; -// CIR: cir.global external @e = #cir.const_array<[#cir.int<1> : !s32i, #cir.int<2> : !s32i], trailing_zeros> : !cir.array +// CIR: cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array}> : !rec_anon_struct -// LLVM: @e = global [10 x i32] [i32 1, i32 2, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0] +// LLVM: @e = global <{ i32, i32, [8 x i32] }> <{ i32 1, i32 2, [8 x i32] zeroinitializer }> // OGCG: @e = global <{ i32, i32, [8 x i32] }> <{ i32 1, i32 2, [8 x i32] zeroinitializer }> @@ -58,6 +58,28 @@ int f[5] = {1, 2}; // OGCG: @f = global [5 x i32] [i32 1, i32 2, i32 0, i32 0, i32 0] +int g[16] = {1, 2, 3, 4, 5, 6, 7, 8}; +// CIR: cir.global external @g = #cir.const_record<{ +// CIR-SAME: #cir.const_array<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, +// CIR-SAME: #cir.int<3> : !s32i, #cir.int<4> : !s32i, +// CIR-SAME: #cir.int<5> : !s32i, #cir.int<6> : !s32i, +// CIR-SAME: #cir.int<7> : !s32i, #cir.int<8> : !s32i]> +// CIR-SAME: : !cir.array, +// CIR-SAME: #cir.zero : !cir.array}> : !rec_anon_struct1 + +// LLVM: @g = global <{ [8 x i32], [8 x i32] }> +// LLVM-SAME: <{ [8 x i32] +// LLVM-SAME: [i32 1, i32 2, i32 3, i32 4, +// LLVM-SAME: i32 5, i32 6, i32 7, i32 8], +// LLVM-SAME: [8 x i32] zeroinitializer }> + +// OGCG: @g = global <{ [8 x i32], [8 x i32] }> +// OGCG-SAME: <{ [8 x i32] +// OGCG-SAME: [i32 1, i32 2, i32 3, i32 4, +// OGCG-SAME: i32 5, i32 6, i32 7, i32 8], +// OGCG-SAME: [8 x i32] zeroinitializer }> + + extern int b[10]; // CIR: cir.global "private" external @b : !cir.array // LLVM: @b = external global [10 x i32] diff --git a/clang/test/CIR/IR/invalid-const-record.cir b/clang/test/CIR/IR/invalid-const-record.cir new file mode 100644 index 0000000000000..37d7789d4bd45 --- /dev/null +++ b/clang/test/CIR/IR/invalid-const-record.cir @@ -0,0 +1,23 @@ +// RUN: cir-opt %s -verify-diagnostics -split-input-file + +!s32i = !cir.int +!rec_anon_struct = !cir.record}> + +// expected-error @below {{expected !cir.record type}} +cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array}> : !cir.ptr + +// ----- + +!s32i = !cir.int +!rec_anon_struct = !cir.record}> + +// expected-error @below {{number of elements must match}} +cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.zero : !cir.array}> : !rec_anon_struct + +// ----- + +!s32i = !cir.int +!rec_anon_struct = !cir.record}> + +// expected-error @below {{element at index 1 has type '!cir.float' but the expected type for this element is '!cir.int'}} +cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.fp<2.000000e+00> : !cir.float, #cir.zero : !cir.array}> : !rec_anon_struct diff --git a/clang/test/CIR/IR/struct.cir b/clang/test/CIR/IR/struct.cir index 85f475f643ee5..33f2e9860c5cb 100644 --- a/clang/test/CIR/IR/struct.cir +++ b/clang/test/CIR/IR/struct.cir @@ -13,8 +13,9 @@ // CHECK-DAG: !rec_S = !cir.record // CHECK-DAG: !rec_U = !cir.record -!rec_anon_struct = !cir.record x 5>}> -!rec_anon_struct1 = !cir.record, !cir.ptr, !cir.ptr}> +!rec_anon_struct = !cir.record}> +!rec_anon_struct1 = !cir.record x 5>}> +!rec_anon_struct2 = !cir.record, !cir.ptr, !cir.ptr}> !rec_S1 = !cir.record !rec_Sc = !cir.record @@ -42,18 +43,22 @@ !rec_Node = !cir.record>}> // CHECK-DAG: !cir.record>}> + + module { cir.global external @p1 = #cir.ptr : !cir.ptr cir.global external @p2 = #cir.ptr : !cir.ptr cir.global external @p3 = #cir.ptr : !cir.ptr + cir.global external @arr = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array}> : !rec_anon_struct // CHECK: cir.global external @p1 = #cir.ptr : !cir.ptr // CHECK: cir.global external @p2 = #cir.ptr : !cir.ptr // CHECK: cir.global external @p3 = #cir.ptr : !cir.ptr +// CHECK: cir.global external @arr = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array}> : !rec_anon_struct // Dummy function to use types and force them to be printed. cir.func @useTypes(%arg0: !rec_Node, - %arg1: !rec_anon_struct1, - %arg2: !rec_anon_struct, + %arg1: !rec_anon_struct2, + %arg2: !rec_anon_struct1, %arg3: !rec_S1, %arg4: !rec_Ac, %arg5: !rec_P1, diff --git a/clang/test/CIR/Lowering/array.cpp b/clang/test/CIR/Lowering/array.cpp index 82d803a6b5aa2..40ad986b7fdfa 100644 --- a/clang/test/CIR/Lowering/array.cpp +++ b/clang/test/CIR/Lowering/array.cpp @@ -19,7 +19,7 @@ int dd[3][2] = {{1, 2}, {3, 4}, {5, 6}}; // CHECK: [i32 3, i32 4], [2 x i32] [i32 5, i32 6]] int e[10] = {1, 2}; -// CHECK: @e = global [10 x i32] [i32 1, i32 2, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0] +// CHECK: @e = global <{ i32, i32, [8 x i32] }> <{ i32 1, i32 2, [8 x i32] zeroinitializer }> int f[5] = {1, 2}; // CHECK: @f = global [5 x i32] [i32 1, i32 2, i32 0, i32 0, i32 0]