diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 6309c37788f0c..999726340aaed 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -2216,7 +2216,7 @@ llvm::Value *CodeGenFunction::EmitToMemory(llvm::Value *Value, QualType Ty) { if (auto *AtomicTy = Ty->getAs()) Ty = AtomicTy->getValueType(); - if (Ty->isExtVectorBoolType()) { + if (Ty->isExtVectorBoolType() || Ty->isConstantMatrixBoolType()) { llvm::Type *StoreTy = convertTypeForLoadStore(Ty, Value->getType()); if (StoreTy->isVectorTy() && StoreTy->getScalarSizeInBits() > Value->getType()->getScalarSizeInBits()) diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp index 4239552d1299e..0e1131d586433 100644 --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -107,8 +107,7 @@ llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T) { llvm::Type *IRElemTy = ConvertType(MT->getElementType()); if (Context.getLangOpts().HLSL && T->isConstantMatrixBoolType()) IRElemTy = ConvertTypeForMem(Context.BoolTy); - return llvm::ArrayType::get(IRElemTy, - MT->getNumRows() * MT->getNumColumns()); + return llvm::ArrayType::get(IRElemTy, MT->getNumElementsFlattened()); } llvm::Type *R = ConvertType(T); @@ -180,6 +179,16 @@ llvm::Type *CodeGenTypes::convertTypeForLoadStore(QualType T, return llvm::IntegerType::get(getLLVMContext(), (unsigned)Context.getTypeSize(T)); + if (T->isConstantMatrixBoolType()) { + // Matrices are loaded and stored as a whole, as a single flattened vector. + // Therefore we construct a FixedVectorType here rather than returning + // ConvertTypeForMem(T), which would return an ArrayType. 
+ const Type *Ty = Context.getCanonicalType(T).getTypePtr(); + const ConstantMatrixType *MT = cast(Ty); + llvm::Type *IRElemTy = ConvertTypeForMem(MT->getElementType()); + return llvm::FixedVectorType::get(IRElemTy, MT->getNumElementsFlattened()); + } + if (T->isExtVectorBoolType()) return ConvertTypeForMem(T); diff --git a/clang/test/CodeGenHLSL/BoolMatrix.hlsl b/clang/test/CodeGenHLSL/BoolMatrix.hlsl index 71186f775b241..05c9ad4b926e6 100644 --- a/clang/test/CodeGenHLSL/BoolMatrix.hlsl +++ b/clang/test/CodeGenHLSL/BoolMatrix.hlsl @@ -12,7 +12,7 @@ struct S { // CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[RETVAL:%.*]] = alloca i1, align 4 // CHECK-NEXT: [[B:%.*]] = alloca [4 x i32], align 4 -// CHECK-NEXT: store <4 x i1> splat (i1 true), ptr [[B]], align 4 +// CHECK-NEXT: store <4 x i32> splat (i32 1), ptr [[B]], align 4 // CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[B]], align 4 // CHECK-NEXT: [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 0 // CHECK-NEXT: store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4 @@ -40,11 +40,12 @@ bool fn1() { // CHECK-NEXT: [[VECINIT2:%.*]] = insertelement <4 x i1> [[VECINIT]], i1 [[LOADEDV1]], i32 1 // CHECK-NEXT: [[VECINIT3:%.*]] = insertelement <4 x i1> [[VECINIT2]], i1 true, i32 2 // CHECK-NEXT: [[VECINIT4:%.*]] = insertelement <4 x i1> [[VECINIT3]], i1 false, i32 3 -// CHECK-NEXT: store <4 x i1> [[VECINIT4]], ptr [[A]], align 4 -// CHECK-NEXT: [[TMP2:%.*]] = load <4 x i32>, ptr [[A]], align 4 -// CHECK-NEXT: store <4 x i32> [[TMP2]], ptr [[RETVAL]], align 4 -// CHECK-NEXT: [[TMP3:%.*]] = load <4 x i1>, ptr [[RETVAL]], align 4 -// CHECK-NEXT: ret <4 x i1> [[TMP3]] +// CHECK-NEXT: [[TMP2:%.*]] = zext <4 x i1> [[VECINIT4]] to <4 x i32> +// CHECK-NEXT: store <4 x i32> [[TMP2]], ptr [[A]], align 4 +// CHECK-NEXT: [[TMP3:%.*]] = load <4 x i32>, ptr [[A]], align 4 +// CHECK-NEXT: store <4 x i32> [[TMP3]], ptr [[RETVAL]], align 4 +// CHECK-NEXT: [[TMP4:%.*]] = load <4 x i1>, ptr [[RETVAL]], align 4 +// CHECK-NEXT: ret <4 x 
i1> [[TMP4]] // bool2x2 fn2(bool V) { bool2x2 A = {V, true, V, false}; @@ -57,7 +58,7 @@ bool2x2 fn2(bool V) { // CHECK-NEXT: [[RETVAL:%.*]] = alloca i1, align 4 // CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1 // CHECK-NEXT: [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0 -// CHECK-NEXT: store <4 x i1> , ptr [[BM]], align 1 +// CHECK-NEXT: store <4 x i32> , ptr [[BM]], align 1 // CHECK-NEXT: [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1 // CHECK-NEXT: store float 1.000000e+00, ptr [[F]], align 1 // CHECK-NEXT: [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0 @@ -77,9 +78,9 @@ bool fn3() { // CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[RETVAL:%.*]] = alloca i1, align 4 // CHECK-NEXT: [[ARR:%.*]] = alloca [2 x [4 x i32]], align 4 -// CHECK-NEXT: store <4 x i1> splat (i1 true), ptr [[ARR]], align 4 +// CHECK-NEXT: store <4 x i32> splat (i32 1), ptr [[ARR]], align 4 // CHECK-NEXT: [[ARRAYINIT_ELEMENT:%.*]] = getelementptr inbounds [4 x i32], ptr [[ARR]], i32 1 -// CHECK-NEXT: store <4 x i1> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], align 4 +// CHECK-NEXT: store <4 x i32> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], align 4 // CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [2 x [4 x i32]], ptr [[ARR]], i32 0, i32 0 // CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[ARRAYIDX]], align 4 // CHECK-NEXT: [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 1 @@ -96,7 +97,7 @@ bool fn4() { // CHECK-SAME: ) #[[ATTR0]] { // CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[M:%.*]] = alloca [4 x i32], align 4 -// CHECK-NEXT: store <4 x i1> splat (i1 true), ptr [[M]], align 4 +// CHECK-NEXT: store <4 x i32> splat (i32 1), ptr [[M]], align 4 // CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[M]], align 4 // CHECK-NEXT: [[MATINS:%.*]] = insertelement <4 x i32> [[TMP0]], i32 0, i32 3 // CHECK-NEXT: store <4 x i32> [[MATINS]], ptr [[M]], align 4 @@ -114,7 +115,7 @@ void 
fn5() { // CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1 // CHECK-NEXT: store i32 0, ptr [[V]], align 4 // CHECK-NEXT: [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0 -// CHECK-NEXT: store <4 x i1> , ptr [[BM]], align 1 +// CHECK-NEXT: store <4 x i32> , ptr [[BM]], align 1 // CHECK-NEXT: [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1 // CHECK-NEXT: store float 1.000000e+00, ptr [[F]], align 1 // CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[V]], align 4 @@ -136,9 +137,9 @@ void fn6() { // CHECK-SAME: ) #[[ATTR0]] { // CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[ARR:%.*]] = alloca [2 x [4 x i32]], align 4 -// CHECK-NEXT: store <4 x i1> splat (i1 true), ptr [[ARR]], align 4 +// CHECK-NEXT: store <4 x i32> splat (i32 1), ptr [[ARR]], align 4 // CHECK-NEXT: [[ARRAYINIT_ELEMENT:%.*]] = getelementptr inbounds [4 x i32], ptr [[ARR]], i32 1 -// CHECK-NEXT: store <4 x i1> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], align 4 +// CHECK-NEXT: store <4 x i32> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], align 4 // CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [2 x [4 x i32]], ptr [[ARR]], i32 0, i32 0 // CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[ARRAYIDX]], align 4 // CHECK-NEXT: [[MATINS:%.*]] = insertelement <4 x i32> [[TMP0]], i32 0, i32 1 @@ -149,3 +150,19 @@ void fn7() { bool2x2 Arr[2] = {{true,true,true,true}, {false,false,false,false}}; Arr[0][1][0] = false; } + +// CHECK-LABEL: define hidden noundef <16 x i1> @_Z3fn8u11matrix_typeILm4ELm4EbE( +// CHECK-SAME: <16 x i1> noundef [[M:%.*]]) #[[ATTR0]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[RETVAL:%.*]] = alloca <16 x i1>, align 4 +// CHECK-NEXT: [[M_ADDR:%.*]] = alloca [16 x i32], align 4 +// CHECK-NEXT: [[TMP0:%.*]] = zext <16 x i1> [[M]] to <16 x i32> +// CHECK-NEXT: store <16 x i32> [[TMP0]], ptr [[M_ADDR]], align 4 +// CHECK-NEXT: [[TMP1:%.*]] = load <16 x i32>, ptr [[M_ADDR]], align 4 +// CHECK-NEXT: store <16 x i32> [[TMP1]], 
ptr [[RETVAL]], align 4 +// CHECK-NEXT: [[TMP2:%.*]] = load <16 x i1>, ptr [[RETVAL]], align 4 +// CHECK-NEXT: ret <16 x i1> [[TMP2]] +// +bool4x4 fn8(bool4x4 m) { + return m; +}