Skip to content

Commit

Permalink
Improve IndexOfAnyAsciiSearcher ARM throughput (#78739)
Browse files Browse the repository at this point in the history
* Improve IndexOfAnyAsciiSearcher ARM throughput

Also solves an edge-case bug for byte overloads on ARM64 where we'd get false positives (negatives for Except) when haystack values were exactly 128 above a value in the needle

* Remove type name asserts
  • Loading branch information
MihaZupan authored Nov 23, 2022
1 parent 8020db5 commit 2b87d85
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 10 deletions.
44 changes: 44 additions & 0 deletions src/libraries/System.Memory/tests/Span/IndexOfAny.byte.cs
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,50 @@ static int IndexOfAnyReferenceImpl(ReadOnlySpan<byte> searchSpace, ReadOnlySpan<
}
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public static void AsciiNeedle_ProperlyHandlesEdgeCases_Byte(bool needleContainsZero)
{
    // ASCII needles need special handling so that non-ASCII haystack values are filtered out of the results.
    // The theory runs once with a needle that contains 0 and once with a needle that does not.
    ReadOnlySpan<byte> needleBytes = "AEIOU!"u8;
    if (needleContainsZero)
    {
        needleBytes = "AEIOU\0"u8;
    }

    IndexOfAnyValues<byte> needle = IndexOfAnyValues.Create(needleBytes);

    // Haystack that alternates between a needle value ('A') and a non-needle value ('a').
    ReadOnlySpan<byte> alternating = "AaAaAaAaAaAa"u8;
    Assert.Equal(0, alternating.IndexOfAny(needle));
    Assert.Equal(1, alternating.IndexOfAnyExcept(needle));
    Assert.Equal(10, alternating.LastIndexOfAny(needle));
    Assert.Equal(11, alternating.LastIndexOfAnyExcept(needle));

    // Haystack with embedded zeroes: the trailing '\0' (index 8) only counts as a match when the needle contains 0.
    int expectedLastMatch = needleContainsZero ? 8 : 6;
    int expectedLastNonMatch = needleContainsZero ? 7 : 8;

    ReadOnlySpan<byte> zeroTerminated = "Aa\0Aa\0Aa\0"u8;
    Assert.Equal(0, zeroTerminated.IndexOfAny(needle));
    Assert.Equal(1, zeroTerminated.IndexOfAnyExcept(needle));
    Assert.Equal(expectedLastMatch, zeroTerminated.LastIndexOfAny(needle));
    Assert.Equal(expectedLastNonMatch, zeroTerminated.LastIndexOfAnyExcept(needle));

    // Fill a haystack with values that sit exactly 128 above a needle value; none of them may match.
    Span<byte> shifted = new byte[100];
    int lastIndex = shifted.Length - 1;

    for (int pos = 0; pos < shifted.Length; pos++)
    {
        byte needleValue = needleBytes[pos % needleBytes.Length];
        shifted[pos] = (byte)(needleValue + 128);
    }

    Assert.Equal(-1, shifted.IndexOfAny(needle));
    Assert.Equal(0, shifted.IndexOfAnyExcept(needle));
    Assert.Equal(-1, shifted.LastIndexOfAny(needle));
    Assert.Equal(lastIndex, shifted.LastIndexOfAnyExcept(needle));

    // Put real needle values back at every third position (0, 3, ..., 99), so the final element matches again.
    for (int pos = 0; pos < shifted.Length; pos += 3)
    {
        shifted[pos] = needleBytes[pos % needleBytes.Length];
    }

    Assert.Equal(0, shifted.IndexOfAny(needle));
    Assert.Equal(1, shifted.IndexOfAnyExcept(needle));
    Assert.Equal(lastIndex, shifted.LastIndexOfAny(needle));
    Assert.Equal(lastIndex - 1, shifted.LastIndexOfAnyExcept(needle));
}

private static int IndexOf(Span<byte> span, byte value)
{
int index = span.IndexOf(value);
Expand Down
66 changes: 66 additions & 0 deletions src/libraries/System.Memory/tests/Span/IndexOfAny.char.cs
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,72 @@ static int IndexOfAnyReferenceImpl(ReadOnlySpan<char> searchSpace, ReadOnlySpan<
}
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public static void AsciiNeedle_ProperlyHandlesEdgeCases_Char(bool needleContainsZero)
{
    // ASCII needles need special handling so that non-ASCII haystack values are filtered out of the results.
    // The theory runs once with a needle that contains 0 and once with a needle that does not.
    ReadOnlySpan<char> needleChars = "AEIOU!";
    if (needleContainsZero)
    {
        needleChars = "AEIOU\0";
    }

    IndexOfAnyValues<char> needle = IndexOfAnyValues.Create(needleChars);

    // Haystack that alternates between a needle value ('A') and a non-needle value ('a').
    ReadOnlySpan<char> alternating = "AaAaAaAaAaAa";
    Assert.Equal(0, alternating.IndexOfAny(needle));
    Assert.Equal(1, alternating.IndexOfAnyExcept(needle));
    Assert.Equal(10, alternating.LastIndexOfAny(needle));
    Assert.Equal(11, alternating.LastIndexOfAnyExcept(needle));

    // Haystack with embedded zeroes: the trailing '\0' (index 8) only counts as a match when the needle contains 0.
    int expectedLastMatch = needleContainsZero ? 8 : 6;
    int expectedLastNonMatch = needleContainsZero ? 7 : 8;

    ReadOnlySpan<char> zeroTerminated = "Aa\0Aa\0Aa\0";
    Assert.Equal(0, zeroTerminated.IndexOfAny(needle));
    Assert.Equal(1, zeroTerminated.IndexOfAnyExcept(needle));
    Assert.Equal(expectedLastMatch, zeroTerminated.LastIndexOfAny(needle));
    Assert.Equal(expectedLastNonMatch, zeroTerminated.LastIndexOfAnyExcept(needle));

    Span<char> shifted = new char[100];

    // Round 1: every value sits exactly 128 above a needle value; none of them may match.
    for (int pos = 0; pos < shifted.Length; pos++)
    {
        char needleValue = needleChars[pos % needleChars.Length];
        shifted[pos] = (char)(needleValue + 128);
    }

    AssertNothingMatches(shifted, needle);

    // Put real needle values back at every third position (0, 3, ..., 99), so the final element matches again.
    for (int pos = 0; pos < shifted.Length; pos += 3)
    {
        shifted[pos] = needleChars[pos % needleChars.Length];
    }

    AssertEveryThirdMatches(shifted, needle);

    // Round 2: with chars, the low byte may equal a needle value while the high byte is non-zero.
    // Such values must not match either.
    for (int pos = 0; pos < shifted.Length; pos++)
    {
        int highByte = pos + 1;
        char needleValue = needleChars[pos % needleChars.Length];
        shifted[pos] = (char)((highByte * 256) + needleValue);
    }

    AssertNothingMatches(shifted, needle);

    // Put real needle values back at every third position once more.
    for (int pos = 0; pos < shifted.Length; pos += 3)
    {
        shifted[pos] = needleChars[pos % needleChars.Length];
    }

    AssertEveryThirdMatches(shifted, needle);

    // Expectations for a haystack in which no element is part of the needle.
    static void AssertNothingMatches(Span<char> haystack, IndexOfAnyValues<char> values)
    {
        Assert.Equal(-1, haystack.IndexOfAny(values));
        Assert.Equal(0, haystack.IndexOfAnyExcept(values));
        Assert.Equal(-1, haystack.LastIndexOfAny(values));
        Assert.Equal(haystack.Length - 1, haystack.LastIndexOfAnyExcept(values));
    }

    // Expectations for a haystack whose first and last elements match and whose
    // second and second-to-last elements do not.
    static void AssertEveryThirdMatches(Span<char> haystack, IndexOfAnyValues<char> values)
    {
        Assert.Equal(0, haystack.IndexOfAny(values));
        Assert.Equal(1, haystack.IndexOfAnyExcept(values));
        Assert.Equal(haystack.Length - 1, haystack.LastIndexOfAny(values));
        Assert.Equal(haystack.Length - 2, haystack.LastIndexOfAnyExcept(values));
    }
}

private static int IndexOf(Span<char> span, char value)
{
int index = span.IndexOf(value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,26 +490,23 @@ private static Vector128<byte> IndexOfAnyLookup<TNegator, TOptimizations>(Vector
// X86: Downcast every character using saturation.
// - Values <= 32767 result in min(value, 255).
// - Values > 32767 result in 0. Because of this we must do more work to handle needles that contain 0.
// ARM64: Take the low byte of each character.
// - All values result in (value & 0xFF).
// ARM64: Do narrowing saturation over unsigned values.
// - All values result in min(value, 255)
Vector128<byte> source = Sse2.IsSupported
? Sse2.PackUnsignedSaturate(source0, source1)
: AdvSimd.Arm64.UnzipEven(source0.AsByte(), source1.AsByte());
: AdvSimd.ExtractNarrowingSaturateUpper(AdvSimd.ExtractNarrowingSaturateLower(source0.AsUInt16()), source1.AsUInt16());

Vector128<byte> result = IndexOfAnyLookupCore(source, bitmapLookup);

// On ARM64, we ignored the high byte of every character when packing (see above).
// The 'result' can therefore contain false positives - e.g. 0x141 would match 0x41 ('A').
// On X86, PackUnsignedSaturate resulted in values becoming 0 for inputs above 32767.
// Any value above 32767 would therefore match against 0. If 0 is present in the needle, we must clear the false positives.
// In both cases, we can correct the result by clearing any bits that matched with a non-ascii source character.
if (AdvSimd.Arm64.IsSupported || TOptimizations.NeedleContainsZero)
if (TOptimizations.NeedleContainsZero)
{
Debug.Assert(Sse2.IsSupported);
Vector128<short> ascii0 = Vector128.LessThan(source0.AsUInt16(), Vector128.Create((ushort)128)).AsInt16();
Vector128<short> ascii1 = Vector128.LessThan(source1.AsUInt16(), Vector128.Create((ushort)128)).AsInt16();
Vector128<byte> ascii = Sse2.IsSupported
? Sse2.PackSignedSaturate(ascii0, ascii1).AsByte()
: AdvSimd.Arm64.UnzipEven(ascii0.AsByte(), ascii1.AsByte());
Vector128<byte> ascii = Sse2.PackSignedSaturate(ascii0, ascii1).AsByte();
result &= ascii;
}

Expand Down Expand Up @@ -542,7 +539,13 @@ private static Vector128<byte> IndexOfAnyLookupCore(Vector128<byte> source, Vect
? source
: source & Vector128.Create((byte)0xF);

Vector128<byte> highNibbles = Vector128.ShiftRightLogical(source.AsInt32(), 4).AsByte() & Vector128.Create((byte)0xF);
// On ARM, we have an instruction for an arithmetic right shift of 1-byte signed values.
// The shift will map values above 127 to values above 15 (0xF8-0xFF), which the shuffle will then map to 0.
// This is how we exclude non-ASCII values from results on ARM.
// On X86, use a 4-byte value shift with AND 15 to emulate a 1-byte value logical shift.
Vector128<byte> highNibbles = AdvSimd.IsSupported
? AdvSimd.ShiftRightArithmetic(source.AsSByte(), 4).AsByte()
: Sse2.ShiftRightLogical(source.AsInt32(), 4).AsByte() & Vector128.Create((byte)0xF);

// The bitmapLookup represents a 8x16 table of bits, indicating whether a character is present in the needle.
// Lookup the rows via the lower nibble and the column via the higher nibble.
Expand Down

0 comments on commit 2b87d85

Please sign in to comment.