Skip to content

Commit

Permalink
ARM64-SVE: gathervector (#103159)
Browse files Browse the repository at this point in the history
* ARM64-SVE: gathervector

* HW_Flag_LowMaskedOperation

* Remove HW_Flag_BaseTypeFromFirstArg and add asserts

* Add invalid testing

* Add conditional select testing

* Use LSL_N for 64bit address variants

* Remove 32bit address base variants from API

* re-add commented out 32bit APIs

* Add comment on commented out methods

* Replace triple quotes with doubles
  • Loading branch information
a74nh authored Jun 12, 2024
1 parent 332fbb4 commit 3e8b786
Show file tree
Hide file tree
Showing 9 changed files with 1,443 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1607,6 +1607,12 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
assert(varTypeIsSIMD(op2->TypeGet()));
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(getBaseJitTypeOfSIMDType(sigReader.op2ClsHnd));
}
#elif defined(TARGET_ARM64)
if (intrinsic == NI_Sve_GatherVector)
{
assert(varTypeIsSIMD(op3->TypeGet()));
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(getBaseJitTypeOfSIMDType(sigReader.op3ClsHnd));
}
#endif
break;
}
Expand Down
36 changes: 36 additions & 0 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1845,6 +1845,42 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
break;
}

case NI_Sve_GatherVector:
{
// Code generation for the SVE gather load. Two operand shapes share this intrinsic id;
// they are distinguished by whether the second operand is a SIMD vector (a vector of
// addresses) or not (a scalar base address followed by a vector of indices).
if (!varTypeIsSIMD(intrin.op2->gtType))
{
// GatherVector(Vector<T> mask, T* address, Vector<T2> indices)

assert(intrin.numOperands == 3);
// The size of the intrinsic's base type selects which addressing form is emitted below.
emitAttr baseSize = emitActualTypeSize(intrin.baseType);

if (baseSize == EA_8BYTE)
{
// Index is multiplied by 8
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, op3Reg, opt,
INS_SCALABLE_OPTS_LSL_N);
}
else
{
// Index is sign or zero extended to 64bits, then multiplied by 4
assert(baseSize == EA_4BYTE);
// `opt` is overridden here: an unsigned auxiliary type selects zero extension (UXTW),
// anything else selects sign extension (SXTW).
// NOTE(review): the auxiliary type is presumably the element type of the index vector,
// recorded when the intrinsic is imported - confirm against impHWIntrinsic.
opt = varTypeIsUnsigned(node->GetAuxiliaryType()) ? INS_OPTS_SCALABLE_S_UXTW
: INS_OPTS_SCALABLE_S_SXTW;
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, op3Reg, opt,
INS_SCALABLE_OPTS_MOD_N);
}
}
else
{
// GatherVector(Vector<T> mask, Vector<T2> addresses)

assert(intrin.numOperands == 2);
// The vector operand supplies the addresses directly; the immediate offset is always 0.
GetEmitter()->emitIns_R_R_R_I(ins, emitSize, targetReg, op1Reg, op2Reg, 0, opt);
}

break;
}

case NI_Sve_ReverseElement:
// Use non-predicated version explicitly
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt, INS_SCALABLE_OPTS_UNPREDICATED);
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ HARDWARE_INTRINSIC(Sve, FusedMultiplyAddNegated,
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtract, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_FmaIntrinsic|HW_Flag_LowVectorOperation)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractNegated, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fnmls, INS_sve_fnmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, GatherVector, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ld1w, INS_sve_ld1w, INS_sve_ld1d, INS_sve_ld1d, INS_sve_ld1w, INS_sve_ld1d}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, GetActiveElementCount, -1, 2, true, {INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp, INS_sve_cntp}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_ExplicitMaskedOperation)
HARDWARE_INTRINSIC(Sve, LeadingSignCount, -1, -1, false, {INS_sve_cls, INS_invalid, INS_sve_cls, INS_invalid, INS_sve_cls, INS_invalid, INS_sve_cls, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, LeadingZeroCount, -1, -1, false, {INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_sve_clz, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,120 @@ internal Arm64() { }
public static unsafe Vector<float> FusedMultiplySubtractNegated(Vector<float> minuend, Vector<float> left, Vector<float> right) { throw new PlatformNotSupportedException(); }


/// Unextended load

/// <summary>
/// svfloat64_t svld1_gather_[s64]index[_f64](svbool_t pg, const float64_t *base, svint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at
/// <paramref name="address"/> plus that lane's index scaled by 8 (LSL #3); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<double> GatherVector(Vector<double> mask, double* address, Vector<long> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svfloat64_t svld1_gather[_u64base]_f64(svbool_t pg, svuint64_t bases)
/// LD1D Zresult.D, Pg/Z, [Zbases.D, #0]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at the address
/// held in the corresponding lane of <paramref name="addresses"/> (immediate offset #0); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<double> GatherVector(Vector<double> mask, Vector<ulong> addresses) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svfloat64_t svld1_gather_[u64]index[_f64](svbool_t pg, const float64_t *base, svuint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at
/// <paramref name="address"/> plus that lane's index scaled by 8 (LSL #3); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<double> GatherVector(Vector<double> mask, double* address, Vector<ulong> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint32_t svld1_gather_[s32]index[_s32](svbool_t pg, const int32_t *base, svint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, SXTW #2]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 32-bit element at
/// <paramref name="address"/> plus that lane's index, sign-extended (SXTW) and scaled by 4 (#2); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<int> GatherVector(Vector<int> mask, int* address, Vector<int> indices) { throw new PlatformNotSupportedException(); }

// <summary>
// svint32_t svld1_gather[_u32base]_s32(svbool_t pg, svuint32_t bases)
// LD1W Zresult.S, Pg/Z, [Zbases.S, #0]
// </summary>
// Removed as per #103297
// public static unsafe Vector<int> GatherVector(Vector<int> mask, Vector<uint> addresses) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint32_t svld1_gather_[u32]index[_s32](svbool_t pg, const int32_t *base, svuint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, UXTW #2]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 32-bit element at
/// <paramref name="address"/> plus that lane's index, zero-extended (UXTW) and scaled by 4 (#2); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<int> GatherVector(Vector<int> mask, int* address, Vector<uint> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint64_t svld1_gather_[s64]index[_s64](svbool_t pg, const int64_t *base, svint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at
/// <paramref name="address"/> plus that lane's index scaled by 8 (LSL #3); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<long> GatherVector(Vector<long> mask, long* address, Vector<long> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint64_t svld1_gather[_u64base]_s64(svbool_t pg, svuint64_t bases)
/// LD1D Zresult.D, Pg/Z, [Zbases.D, #0]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at the address
/// held in the corresponding lane of <paramref name="addresses"/> (immediate offset #0); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<long> GatherVector(Vector<long> mask, Vector<ulong> addresses) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint64_t svld1_gather_[u64]index[_s64](svbool_t pg, const int64_t *base, svuint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at
/// <paramref name="address"/> plus that lane's index scaled by 8 (LSL #3); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<long> GatherVector(Vector<long> mask, long* address, Vector<ulong> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svfloat32_t svld1_gather_[s32]index[_f32](svbool_t pg, const float32_t *base, svint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, SXTW #2]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 32-bit element at
/// <paramref name="address"/> plus that lane's index, sign-extended (SXTW) and scaled by 4 (#2); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<float> GatherVector(Vector<float> mask, float* address, Vector<int> indices) { throw new PlatformNotSupportedException(); }

// <summary>
// svfloat32_t svld1_gather[_u32base]_f32(svbool_t pg, svuint32_t bases)
// LD1W Zresult.S, Pg/Z, [Zbases.S, #0]
// </summary>
// Removed as per #103297
// public static unsafe Vector<float> GatherVector(Vector<float> mask, Vector<uint> addresses) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svfloat32_t svld1_gather_[u32]index[_f32](svbool_t pg, const float32_t *base, svuint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, UXTW #2]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 32-bit element at
/// <paramref name="address"/> plus that lane's index, zero-extended (UXTW) and scaled by 4 (#2); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<float> GatherVector(Vector<float> mask, float* address, Vector<uint> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint32_t svld1_gather_[s32]index[_u32](svbool_t pg, const uint32_t *base, svint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, SXTW #2]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 32-bit element at
/// <paramref name="address"/> plus that lane's index, sign-extended (SXTW) and scaled by 4 (#2); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<uint> GatherVector(Vector<uint> mask, uint* address, Vector<int> indices) { throw new PlatformNotSupportedException(); }

// <summary>
// svuint32_t svld1_gather[_u32base]_u32(svbool_t pg, svuint32_t bases)
// LD1W Zresult.S, Pg/Z, [Zbases.S, #0]
// </summary>
// Removed as per #103297
// public static unsafe Vector<uint> GatherVector(Vector<uint> mask, Vector<uint> addresses) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint32_t svld1_gather_[u32]index[_u32](svbool_t pg, const uint32_t *base, svuint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, UXTW #2]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 32-bit element at
/// <paramref name="address"/> plus that lane's index, zero-extended (UXTW) and scaled by 4 (#2); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<uint> GatherVector(Vector<uint> mask, uint* address, Vector<uint> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint64_t svld1_gather_[s64]index[_u64](svbool_t pg, const uint64_t *base, svint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at
/// <paramref name="address"/> plus that lane's index scaled by 8 (LSL #3); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<ulong> GatherVector(Vector<ulong> mask, ulong* address, Vector<long> indices) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint64_t svld1_gather[_u64base]_u64(svbool_t pg, svuint64_t bases)
/// LD1D Zresult.D, Pg/Z, [Zbases.D, #0]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at the address
/// held in the corresponding lane of <paramref name="addresses"/> (immediate offset #0); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<ulong> GatherVector(Vector<ulong> mask, Vector<ulong> addresses) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint64_t svld1_gather_[u64]index[_u64](svbool_t pg, const uint64_t *base, svuint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at
/// <paramref name="address"/> plus that lane's index scaled by 8 (LSL #3); "/Z" denotes zeroing predication.
/// </remarks>
/// <exception cref="PlatformNotSupportedException">Always thrown by this stub implementation.</exception>
public static unsafe Vector<ulong> GatherVector(Vector<ulong> mask, ulong* address, Vector<ulong> indices) { throw new PlatformNotSupportedException(); }


/// Count set predicate bits

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,120 @@ internal Arm64() { }
public static unsafe Vector<float> FusedMultiplySubtractNegated(Vector<float> minuend, Vector<float> left, Vector<float> right) => FusedMultiplySubtractNegated(minuend, left, right);


/// Unextended load

/// <summary>
/// svfloat64_t svld1_gather_[s64]index[_f64](svbool_t pg, const float64_t *base, svint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at
/// <paramref name="address"/> plus that lane's index scaled by 8 (LSL #3); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<double> GatherVector(Vector<double> mask, double* address, Vector<long> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svfloat64_t svld1_gather[_u64base]_f64(svbool_t pg, svuint64_t bases)
/// LD1D Zresult.D, Pg/Z, [Zbases.D, #0]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at the address
/// held in the corresponding lane of <paramref name="addresses"/> (immediate offset #0); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<double> GatherVector(Vector<double> mask, Vector<ulong> addresses) => GatherVector(mask, addresses);

/// <summary>
/// svfloat64_t svld1_gather_[u64]index[_f64](svbool_t pg, const float64_t *base, svuint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at
/// <paramref name="address"/> plus that lane's index scaled by 8 (LSL #3); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<double> GatherVector(Vector<double> mask, double* address, Vector<ulong> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svint32_t svld1_gather_[s32]index[_s32](svbool_t pg, const int32_t *base, svint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, SXTW #2]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 32-bit element at
/// <paramref name="address"/> plus that lane's index, sign-extended (SXTW) and scaled by 4 (#2); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<int> GatherVector(Vector<int> mask, int* address, Vector<int> indices) => GatherVector(mask, address, indices);

// <summary>
// svint32_t svld1_gather[_u32base]_s32(svbool_t pg, svuint32_t bases)
// LD1W Zresult.S, Pg/Z, [Zbases.S, #0]
// </summary>
// Removed as per #103297
// public static unsafe Vector<int> GatherVector(Vector<int> mask, Vector<uint> addresses) => GatherVector(mask, addresses);

/// <summary>
/// svint32_t svld1_gather_[u32]index[_s32](svbool_t pg, const int32_t *base, svuint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, UXTW #2]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 32-bit element at
/// <paramref name="address"/> plus that lane's index, zero-extended (UXTW) and scaled by 4 (#2); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<int> GatherVector(Vector<int> mask, int* address, Vector<uint> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svint64_t svld1_gather_[s64]index[_s64](svbool_t pg, const int64_t *base, svint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at
/// <paramref name="address"/> plus that lane's index scaled by 8 (LSL #3); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<long> GatherVector(Vector<long> mask, long* address, Vector<long> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svint64_t svld1_gather[_u64base]_s64(svbool_t pg, svuint64_t bases)
/// LD1D Zresult.D, Pg/Z, [Zbases.D, #0]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at the address
/// held in the corresponding lane of <paramref name="addresses"/> (immediate offset #0); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<long> GatherVector(Vector<long> mask, Vector<ulong> addresses) => GatherVector(mask, addresses);

/// <summary>
/// svint64_t svld1_gather_[u64]index[_s64](svbool_t pg, const int64_t *base, svuint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at
/// <paramref name="address"/> plus that lane's index scaled by 8 (LSL #3); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<long> GatherVector(Vector<long> mask, long* address, Vector<ulong> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svfloat32_t svld1_gather_[s32]index[_f32](svbool_t pg, const float32_t *base, svint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, SXTW #2]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 32-bit element at
/// <paramref name="address"/> plus that lane's index, sign-extended (SXTW) and scaled by 4 (#2); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<float> GatherVector(Vector<float> mask, float* address, Vector<int> indices) => GatherVector(mask, address, indices);

// <summary>
// svfloat32_t svld1_gather[_u32base]_f32(svbool_t pg, svuint32_t bases)
// LD1W Zresult.S, Pg/Z, [Zbases.S, #0]
// </summary>
// Removed as per #103297
// public static unsafe Vector<float> GatherVector(Vector<float> mask, Vector<uint> addresses) => GatherVector(mask, addresses);

/// <summary>
/// svfloat32_t svld1_gather_[u32]index[_f32](svbool_t pg, const float32_t *base, svuint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, UXTW #2]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 32-bit element at
/// <paramref name="address"/> plus that lane's index, zero-extended (UXTW) and scaled by 4 (#2); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<float> GatherVector(Vector<float> mask, float* address, Vector<uint> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svuint32_t svld1_gather_[s32]index[_u32](svbool_t pg, const uint32_t *base, svint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, SXTW #2]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 32-bit element at
/// <paramref name="address"/> plus that lane's index, sign-extended (SXTW) and scaled by 4 (#2); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<uint> GatherVector(Vector<uint> mask, uint* address, Vector<int> indices) => GatherVector(mask, address, indices);

// <summary>
// svuint32_t svld1_gather[_u32base]_u32(svbool_t pg, svuint32_t bases)
// LD1W Zresult.S, Pg/Z, [Zbases.S, #0]
// </summary>
// Removed as per #103297
// public static unsafe Vector<uint> GatherVector(Vector<uint> mask, Vector<uint> addresses) => GatherVector(mask, addresses);

/// <summary>
/// svuint32_t svld1_gather_[u32]index[_u32](svbool_t pg, const uint32_t *base, svuint32_t indices)
/// LD1W Zresult.S, Pg/Z, [Xbase, Zindices.S, UXTW #2]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 32-bit element at
/// <paramref name="address"/> plus that lane's index, zero-extended (UXTW) and scaled by 4 (#2); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<uint> GatherVector(Vector<uint> mask, uint* address, Vector<uint> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svuint64_t svld1_gather_[s64]index[_u64](svbool_t pg, const uint64_t *base, svint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at
/// <paramref name="address"/> plus that lane's index scaled by 8 (LSL #3); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<ulong> GatherVector(Vector<ulong> mask, ulong* address, Vector<long> indices) => GatherVector(mask, address, indices);

/// <summary>
/// svuint64_t svld1_gather[_u64base]_u64(svbool_t pg, svuint64_t bases)
/// LD1D Zresult.D, Pg/Z, [Zbases.D, #0]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at the address
/// held in the corresponding lane of <paramref name="addresses"/> (immediate offset #0); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<ulong> GatherVector(Vector<ulong> mask, Vector<ulong> addresses) => GatherVector(mask, addresses);

/// <summary>
/// svuint64_t svld1_gather_[u64]index[_u64](svbool_t pg, const uint64_t *base, svuint64_t indices)
/// LD1D Zresult.D, Pg/Z, [Xbase, Zindices.D, LSL #3]
/// </summary>
/// <remarks>
/// For each lane enabled by <paramref name="mask"/>, the instruction loads the 64-bit element at
/// <paramref name="address"/> plus that lane's index scaled by 8 (LSL #3); "/Z" denotes zeroing predication.
/// The body calls this same overload; presumably the JIT replaces that call with the instruction above (TODO confirm).
/// </remarks>
public static unsafe Vector<ulong> GatherVector(Vector<ulong> mask, ulong* address, Vector<ulong> indices) => GatherVector(mask, address, indices);


/// Count set predicate bits

/// <summary>
Expand Down
Loading

0 comments on commit 3e8b786

Please sign in to comment.