diff --git a/src/coreclr/jit/gentree.cpp b/src/coreclr/jit/gentree.cpp
index 3bde090b88d7a5..78ed00c13041cf 100644
--- a/src/coreclr/jit/gentree.cpp
+++ b/src/coreclr/jit/gentree.cpp
@@ -33525,28 +33525,63 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
             case NI_Vector512_ConditionalSelect:
 #elif defined(TARGET_ARM64)
             case NI_AdvSimd_BitwiseSelect:
-            case NI_Sve_ConditionalSelect:
 #endif
             {
                 assert(!varTypeIsMask(retType));
+                assert(!varTypeIsMask(op1));
 
                 if (cnsNode != op1)
                 {
                     break;
                 }
 
-#if defined(TARGET_ARM64)
-                if (ni == NI_Sve_ConditionalSelect)
+                if (op1->IsVectorAllBitsSet())
                 {
-                    assert(!op1->IsVectorAllBitsSet() && !op1->IsVectorZero());
+                    if ((op3->gtFlags & GTF_SIDE_EFFECT) != 0)
+                    {
+                        // op3 has side effects, this would require us to append a new statement
+                        // to ensure that it isn't lost, which isn't safe to do from the general
+                        // purpose handler here. We'll recognize this and mark it in VN instead
+                        break;
+                    }
+
+                    // op3 has no side effects, so we can return op2 directly
+                    return op2;
                 }
-                else
+
+                if (op1->IsVectorZero())
                 {
-                    assert(!op1->IsTrueMask(simdBaseType) && !op1->IsMaskZero());
+                    return gtWrapWithSideEffects(op3, op2, GTF_ALL_EFFECT);
+                }
+
+                if (op2->IsCnsVec() && op3->IsCnsVec())
+                {
+                    // op2 = op2 & op1
+                    op2->AsVecCon()->EvaluateBinaryInPlace(GT_AND, false, simdBaseType, op1->AsVecCon());
+
+                    // op3 = op2 & ~op1
+                    op3->AsVecCon()->EvaluateBinaryInPlace(GT_AND_NOT, false, simdBaseType, op1->AsVecCon());
+
+                    // op2 = op2 | op3
+                    op2->AsVecCon()->EvaluateBinaryInPlace(GT_OR, false, simdBaseType, op3->AsVecCon());
+
+                    resultNode = op2;
+                }
+                break;
+            }
+
+#if defined(TARGET_ARM64)
+            case NI_Sve_ConditionalSelect:
+            {
+                assert(!varTypeIsMask(retType));
+                assert(varTypeIsMask(op1));
+
+                if (cnsNode != op1)
+                {
+                    break;
                 }
-#endif
 
-                if (op1->IsVectorAllBitsSet() || op1->IsTrueMask(simdBaseType))
+                if (op1->IsTrueMask(simdBaseType))
                 {
                     if ((op3->gtFlags & GTF_SIDE_EFFECT) != 0)
                     {
@@ -33560,18 +33595,30 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
                     return op2;
                 }
 
-                if (op1->IsVectorZero() || op1->IsMaskZero())
+                if (op1->IsMaskZero())
                 {
                     return gtWrapWithSideEffects(op3, op2, GTF_ALL_EFFECT);
                 }
 
                 if (op2->IsCnsVec() && op3->IsCnsVec())
                 {
+                    assert(op2->gtType == TYP_SIMD16);
+                    assert(op3->gtType == TYP_SIMD16);
+
+                    simd16_t op1SimdVal;
+                    EvaluateSimdCvtMaskToVector(simdBaseType, &op1SimdVal, op1->AsMskCon()->gtSimdMaskVal);
+
                     // op2 = op2 & op1
-                    op2->AsVecCon()->EvaluateBinaryInPlace(GT_AND, false, simdBaseType, op1->AsVecCon());
+                    simd16_t result = {};
+                    EvaluateBinarySimd(GT_AND, false, simdBaseType, &result, op2->AsVecCon()->gtSimd16Val,
+                                       op1SimdVal);
+                    op2->AsVecCon()->gtSimd16Val = result;
 
                     // op3 = op2 & ~op1
-                    op3->AsVecCon()->EvaluateBinaryInPlace(GT_AND_NOT, false, simdBaseType, op1->AsVecCon());
+                    result = {};
+                    EvaluateBinarySimd(GT_AND_NOT, false, simdBaseType, &result, op3->AsVecCon()->gtSimd16Val,
+                                       op1SimdVal);
+                    op3->AsVecCon()->gtSimd16Val = result;
 
                     // op2 = op2 | op3
                     op2->AsVecCon()->EvaluateBinaryInPlace(GT_OR, false, simdBaseType, op3->AsVecCon());
@@ -33580,6 +33627,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
                 }
                 break;
             }
+#endif // TARGET_ARM64
 
             default:
             {
diff --git a/src/coreclr/jit/lower.cpp b/src/coreclr/jit/lower.cpp
index 9395f1c83889a5..8fe43d05484567 100644
--- a/src/coreclr/jit/lower.cpp
+++ b/src/coreclr/jit/lower.cpp
@@ -789,6 +789,11 @@ GenTree* Lowering::LowerNode(GenTree* node)
             LowerReturnSuspend(node);
             break;
 
+#if defined(FEATURE_HW_INTRINSICS) && defined(TARGET_ARM64)
+        case GT_CNS_MSK:
+            return LowerCnsMask(node->AsMskCon());
+#endif // FEATURE_HW_INTRINSICS && TARGET_ARM64
+
         default:
             break;
     }
diff --git a/src/coreclr/jit/lower.h b/src/coreclr/jit/lower.h
index b012d8cacb269a..49d74503e0593b 100644
--- a/src/coreclr/jit/lower.h
+++ b/src/coreclr/jit/lower.h
@@ -451,11 +451,12 @@ class Lowering final : public Phase
     GenTree* TryLowerXorOpToGetMaskUpToLowestSetBit(GenTreeOp* xorNode);
     void LowerBswapOp(GenTreeOp* node);
 #elif defined(TARGET_ARM64)
-    bool IsValidConstForMovImm(GenTreeHWIntrinsic* node);
-    void LowerHWIntrinsicFusedMultiplyAddScalar(GenTreeHWIntrinsic* node);
-    void LowerModPow2(GenTree* node);
-    bool TryLowerAddForPossibleContainment(GenTreeOp* node, GenTree** next);
-    void StoreFFRValue(GenTreeHWIntrinsic* node);
+    bool     IsValidConstForMovImm(GenTreeHWIntrinsic* node);
+    void     LowerHWIntrinsicFusedMultiplyAddScalar(GenTreeHWIntrinsic* node);
+    void     LowerModPow2(GenTree* node);
+    GenTree* LowerCnsMask(GenTreeMskCon* mask);
+    bool     TryLowerAddForPossibleContainment(GenTreeOp* node, GenTree** next);
+    void     StoreFFRValue(GenTreeHWIntrinsic* node);
 #endif // !TARGET_XARCH && !TARGET_ARM64
 
     GenTree* InsertNewSimdCreateScalarUnsafeNode(var_types type, GenTree* op1,
diff --git a/src/coreclr/jit/lowerarmarch.cpp b/src/coreclr/jit/lowerarmarch.cpp
index c949410a3fdf9f..7a3d1eeb57e12d 100644
--- a/src/coreclr/jit/lowerarmarch.cpp
+++ b/src/coreclr/jit/lowerarmarch.cpp
@@ -1134,6 +1134,77 @@ void Lowering::LowerModPow2(GenTree* node)
     ContainCheckNode(mod);
 }
 
+//------------------------------------------------------------------------
+// LowerCnsMask: Lower GT_CNS_MSK. Ensure the mask matches a known pattern.
+// If not then lower to a constant vector.
+//
+// Arguments:
+//    mask - the node to lower
+//
+GenTree* Lowering::LowerCnsMask(GenTreeMskCon* mask)
+{
+    // Try every type until a match is found
+
+    if (mask->IsZero())
+    {
+        return mask->gtNext;
+    }
+
+    if (EvaluateSimdMaskToPattern(TYP_BYTE, mask->gtSimdMaskVal) != SveMaskPatternNone)
+    {
+        return mask->gtNext;
+    }
+
+    if (EvaluateSimdMaskToPattern(TYP_SHORT, mask->gtSimdMaskVal) != SveMaskPatternNone)
+    {
+        return mask->gtNext;
+    }
+
+    if (EvaluateSimdMaskToPattern(TYP_INT, mask->gtSimdMaskVal) != SveMaskPatternNone)
+    {
+        return mask->gtNext;
+    }
+
+    if (EvaluateSimdMaskToPattern(TYP_LONG, mask->gtSimdMaskVal) != SveMaskPatternNone)
+    {
+        return mask->gtNext;
+    }
+
+    // Not a valid pattern, so cannot be created using ptrue/pfalse. Instead the mask will require
+    // loading from memory. There is no way to load to a predicate from memory using a PC relative
+    // address, so instead use a constant vector plus conversion to mask. Using basetype byte will
+    // ensure every entry in the mask is converted.
+
+    LABELEDDISPTREERANGE("lowering cns mask to cns vector (before)", BlockRange(), mask);
+
+    // Create a vector constant
+    GenTreeVecCon* vecCon = comp->gtNewVconNode(TYP_SIMD16);
+    EvaluateSimdCvtMaskToVector(TYP_BYTE, &vecCon->gtSimdVal, mask->gtSimdMaskVal);
+    BlockRange().InsertBefore(mask, vecCon);
+
+    // Convert the vector constant to a mask
+    GenTree* convertedVec = comp->gtNewSimdCvtVectorToMaskNode(TYP_MASK, vecCon, CORINFO_TYPE_BYTE, 16);
+    BlockRange().InsertBefore(mask, convertedVec->AsHWIntrinsic()->Op(1));
+    BlockRange().InsertBefore(mask, convertedVec);
+
+    // Update use
+    LIR::Use use;
+    if (BlockRange().TryGetUse(mask, &use))
+    {
+        use.ReplaceWith(convertedVec);
+    }
+    else
+    {
+        convertedVec->SetUnusedValue();
+    }
+
+    BlockRange().Remove(mask);
+
+    LABELEDDISPTREERANGE("lowering cns mask to cns vector (after)", BlockRange(), vecCon);
+
+    return vecCon->gtNext;
+}
+
 const int POST_INDEXED_ADDRESSING_MAX_DISTANCE = 16;
 
 //------------------------------------------------------------------------
diff --git a/src/coreclr/jit/simd.h b/src/coreclr/jit/simd.h
index 9841bdeb38c93c..f6da9993f90d45 100644
--- a/src/coreclr/jit/simd.h
+++ b/src/coreclr/jit/simd.h
@@ -1598,9 +1598,8 @@ void EvaluateSimdCvtVectorToMask(simdmask_t* result, TSimd arg0)
     uint32_t count = sizeof(TSimd) / sizeof(TBase);
     uint64_t mask = 0;
 
-    TBase significantBit = 1;
 #if defined(TARGET_XARCH)
-    significantBit = static_cast<TBase>(1) << ((sizeof(TBase) * 8) - 1);
+    TBase MostSignificantBit = static_cast<TBase>(1) << ((sizeof(TBase) * 8) - 1);
 #endif
 
     for (uint32_t i = 0; i < count; i++)
@@ -1608,25 +1607,23 @@ void EvaluateSimdCvtVectorToMask(simdmask_t* result, TSimd arg0)
         TBase input0;
         memcpy(&input0, &arg0.u8[i * sizeof(TBase)], sizeof(TBase));
 
-        if ((input0 & significantBit) != 0)
-        {
 #if defined(TARGET_XARCH)
-            // For xarch we have count sequential bits to write
-            // depending on if the corresponding the input element
-            // has its most significant bit set
-
+        // For xarch we have count sequential bits to write depending on if the
+        // corresponding the input element has its most significant bit set
+        if ((input0 & MostSignificantBit) != 0)
+        {
             mask |= static_cast<uint64_t>(1) << i;
+        }
 #elif defined(TARGET_ARM64)
-            // For Arm64 we have count total bits to write, but
-            // they are sizeof(TBase) bits apart. We set
-            // depending on if the corresponding input element
-            // has its least significant bit set
-
+        // For Arm64 we have count total bits to write, but they are sizeof(TBase) bits
+        // apart. We set depending on if the corresponding input element is non zero
+        if (input0 != 0)
+        {
             mask |= static_cast<uint64_t>(1) << (i * sizeof(TBase));
+        }
 #else
-            unreached();
+        unreached();
 #endif
-        }
     }
 
     memcpy(&result->u8[0], &mask, sizeof(uint64_t));
diff --git a/src/coreclr/jit/valuenum.cpp b/src/coreclr/jit/valuenum.cpp
index 79f596806cfb1a..52ac0df5d4cb71 100644
--- a/src/coreclr/jit/valuenum.cpp
+++ b/src/coreclr/jit/valuenum.cpp
@@ -9145,6 +9145,30 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunTernary(
                 {
                     // (y & x) | (z & ~x)
 
+#if defined(TARGET_ARM64)
+                    if (ni == NI_Sve_ConditionalSelect)
+                    {
+                        assert(TypeOfVN(arg0VN) == TYP_MASK);
+                        assert(type == TYP_SIMD16);
+
+                        ValueNum maskVNSimd = EvaluateSimdCvtMaskToVector(this, type, baseType, arg0VN);
+                        simd16_t maskVal = ::GetConstantSimd16(this, baseType, maskVNSimd);
+
+                        simd16_t arg1 = ::GetConstantSimd16(this, baseType, arg1VN);
+                        simd16_t arg2 = ::GetConstantSimd16(this, baseType, arg2VN);
+
+                        simd16_t result = {};
+                        EvaluateBinarySimd(GT_AND, false, baseType, &result, arg1, maskVal);
+                        ValueNum trueVN = VNForSimd16Con(result);
+
+                        result = {};
+                        EvaluateBinarySimd(GT_AND_NOT, false, baseType, &result, arg2, maskVal);
+                        ValueNum falseVN = VNForSimd16Con(result);
+
+                        return EvaluateBinarySimd(this, GT_OR, false, type, baseType, trueVN, falseVN);
+                    }
+#endif // TARGET_ARM64
+
                     ValueNum trueVN = EvaluateBinarySimd(this, GT_AND, false, type, baseType, arg1VN, arg0VN);
                     ValueNum falseVN = EvaluateBinarySimd(this, GT_AND_NOT, false, type, baseType, arg2VN, arg0VN);
diff --git a/src/tests/JIT/opt/SVE/ConditionalSelectConstants.cs b/src/tests/JIT/opt/SVE/ConditionalSelectConstants.cs
new file mode 100644
index 00000000000000..78f7a31c8f2b44
--- /dev/null
+++ b/src/tests/JIT/opt/SVE/ConditionalSelectConstants.cs
@@ -0,0 +1,150 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Numerics;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Runtime.Intrinsics;
+using System.Runtime.Intrinsics.Arm;
+using Xunit;
+
+public class ConditionalSelectConstants
+{
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    [Fact]
+    public static int TestConditionalSelectConstants()
+    {
+        bool fail = false;
+
+        if (Sve.IsSupported)
+        {
+            var r1 = Sve.AddAcross(ConditionalSelect1CC());
+            Console.WriteLine(r1[0]);
+            if (r1[0] != 15)
+            {
+                fail = true;
+            }
+
+            var r2 = Sve.AddAcross(ConditionalSelect1FT());
+            Console.WriteLine(r2[0]);
+            if (r2[0] != -3)
+            {
+                fail = true;
+            }
+
+            var r3 = Sve.AddAcross(ConditionalSelect16TF());
+            Console.WriteLine(r3[0]);
+            if (r3[0] != 4080)
+            {
+                fail = true;
+            }
+
+            var r4 = Sve.AddAcross(ConditionalSelect2CT());
+            Console.WriteLine(r4[0]);
+            if (r4[0] != 16)
+            {
+                fail = true;
+            }
+
+            var r5 = ConditionalSelectConsts();
+            Console.WriteLine(r5);
+            if (r5 != 5)
+            {
+                fail = true;
+            }
+
+            var r6 = ConditionalSelectConstsNoMaskPattern();
+            Console.WriteLine(r6);
+            if (r6 != false)
+            {
+                fail = true;
+            }
+
+            var r7 = ConditionalSelectConstsNoMaskPattern2();
+            Console.WriteLine(r7);
+            if (r7 != 0)
+            {
+                fail = true;
+            }
+        }
+
+        if (fail)
+        {
+            return 101;
+        }
+        return 100;
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static Vector<int> ConditionalSelect1CC()
+    {
+        return Sve.ConditionalSelect(
+            Sve.CreateTrueMaskInt32(SveMaskPattern.VectorCount1),
+            Vector.Create(3),
+            Vector.Create(4)
+        );
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static Vector<int> ConditionalSelect1FT()
+    {
+        return Sve.ConditionalSelect(
+            Sve.CreateTrueMaskInt32(SveMaskPattern.VectorCount1),
+            Sve.CreateFalseMaskInt32(),
+            Sve.CreateTrueMaskInt32()
+        );
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static Vector<byte> ConditionalSelect16TF()
+    {
+        return Sve.ConditionalSelect(
+            Sve.CreateTrueMaskByte(SveMaskPattern.VectorCount16),
+            Sve.CreateTrueMaskByte(),
+            Sve.CreateFalseMaskByte()
+        );
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static Vector<int> ConditionalSelect2CT()
+    {
+        return Sve.ConditionalSelect(
+            Sve.CreateTrueMaskInt32(SveMaskPattern.VectorCount2),
+            Vector.Create(9),
+            Sve.CreateTrueMaskInt32()
+        );
+    }
+
+    // Valuenum will optimise away the ConditionalSelect.
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    static sbyte ConditionalSelectConsts()
+    {
+        var vec = Sve.ConditionalSelect(Vector128.CreateScalar((sbyte)49).AsVector(),
+                                        Vector128.CreateScalar((sbyte)0).AsVector(),
+                                        Vector.Create<sbyte>(107));
+        return Sve.ConditionalExtractLastActiveElement(Vector128.CreateScalar((sbyte)0).AsVector(), 5, vec);
+    }
+
+    // Valuenum will optimise away the ConditionalSelect.
+    // vr0 is a constant mask, but is not a known pattern.
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    public static bool ConditionalSelectConstsNoMaskPattern()
+    {
+        var vr2 = Vector128.CreateScalar(5653592783208606001L).AsVector();
+        var vr3 = Vector128.CreateScalar(6475288982576452694L).AsVector();
+        var vr4 = Vector.Create<long>(1);
+        var vr0 = Sve.ConditionalSelect(vr2, vr3, vr4);
+        var vr5 = Vector128.CreateScalar((long)0).AsVector();
+        return Sve.TestFirstTrue(vr0, vr5);
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    public static int ConditionalSelectConstsNoMaskPattern2()
+    {
+        var vr0 = Sve.ConditionalSelect(Vector128.CreateScalar((short)1).AsVector(), Vector.Create<short>(0), Vector.Create<short>(1));
+        return Sve.ConditionalExtractAfterLastActiveElement(vr0, 1, Vector.Create<short>(0));
+    }
+}
diff --git a/src/tests/JIT/opt/SVE/ConditionalSelectConstants.csproj b/src/tests/JIT/opt/SVE/ConditionalSelectConstants.csproj
new file mode 100644
index 00000000000000..610299ffff1175
--- /dev/null
+++ b/src/tests/JIT/opt/SVE/ConditionalSelectConstants.csproj
@@ -0,0 +1,11 @@
+<Project Sdk="Microsoft.NET.Sdk">
+  <PropertyGroup>
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
+    <Optimize>True</Optimize>
+  </PropertyGroup>
+
+  <ItemGroup>
+    <Compile Include="$(MSBuildProjectName).cs" />
+  </ItemGroup>
+
+</Project>
diff --git a/src/tests/JIT/opt/SVE/ConstantVectors.cs b/src/tests/JIT/opt/SVE/ConstantVectors.cs
new file mode 100644
index 00000000000000..4a0f269bcc0f58
--- /dev/null
+++ b/src/tests/JIT/opt/SVE/ConstantVectors.cs
@@ -0,0 +1,43 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Numerics;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Runtime.Intrinsics;
+using System.Runtime.Intrinsics.Arm;
+using Xunit;
+
+public class ConstantVectors
+{
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    [Fact]
+    public static int TestConstantVectors()
+    {
+        bool fail = false;
+
+        if (Sve.IsSupported)
+        {
+            var r1 = SaturatingDecrementByActiveElementCountConst();
+            Console.WriteLine(r1);
+            if (r1 != 18446744073709551615)
+            {
+                fail = true;
+            }
+        }
+
+        if (fail)
+        {
+            return 101;
+        }
+        return 100;
+    }
+
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    public static ulong SaturatingDecrementByActiveElementCountConst()
+    {
+        var vr5 = Vector128.CreateScalar(14610804860246336108UL).AsVector();
+        return Sve.SaturatingDecrementByActiveElementCount(0UL, vr5);
+    }
+}
diff --git a/src/tests/JIT/opt/SVE/ConstantVectors.csproj b/src/tests/JIT/opt/SVE/ConstantVectors.csproj
new file mode 100644
index 00000000000000..610299ffff1175
--- /dev/null
+++ b/src/tests/JIT/opt/SVE/ConstantVectors.csproj
@@ -0,0 +1,11 @@
+<Project Sdk="Microsoft.NET.Sdk">
+  <PropertyGroup>
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
+    <Optimize>True</Optimize>
+  </PropertyGroup>
+
+  <ItemGroup>
+    <Compile Include="$(MSBuildProjectName).cs" />
+  </ItemGroup>
+
+</Project>
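For reference, both the gentree.cpp folder and the valuenum.cpp evaluator in this change rely on the bitwise-select identity noted in the source: result = (onTrue & mask) | (onFalse & ~mask). The following minimal C# sketch (illustrative only, not part of the patch; the names BitwiseSelectSketch and Select are hypothetical, and plain 64-bit integers stand in for SVE vector lanes) shows that identity and the two degenerate cases the folding special-cases: an all-true mask folds to onTrue and an all-false mask folds to onFalse.

using System;

class BitwiseSelectSketch
{
    // One 64-bit lane of a bitwise select: result = (onTrue & mask) | (onFalse & ~mask).
    static long Select(long mask, long onTrue, long onFalse)
        => (onTrue & mask) | (onFalse & ~mask);

    static void Main()
    {
        Console.WriteLine(Select(-1L, 3L, 4L)); // all-true mask: folds to onTrue, prints 3
        Console.WriteLine(Select(0L, 3L, 4L));  // all-false mask: folds to onFalse, prints 4
        Console.WriteLine(Select(0x00000000FFFFFFFFL, 3L, 4L)); // partial mask blends bits: prints 3
    }
}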