-
-
Notifications
You must be signed in to change notification settings - Fork 887
Add AVX2 versions of CombinedShannonEntropy #1848
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
bab85d4
cc430cc
ed8bd61
b1df6a9
32b97f4
0fc3ce7
93f06bb
265be5f
110ff3d
5403fbd
f4fe9ba
9c95389
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| // Licensed under the Apache License, Version 2.0. | ||
|
|
||
| using System; | ||
| using System.Numerics; | ||
| using System.Runtime.CompilerServices; | ||
| using System.Runtime.InteropServices; | ||
| using SixLabors.ImageSharp.Memory; | ||
|
|
@@ -759,28 +760,184 @@ public static void BundleColorMap(Span<byte> row, int width, int xBits, Span<uin | |
| /// <returns>Shannon entropy.</returns> | ||
| public static float CombinedShannonEntropy(Span<int> x, Span<int> y) | ||
| { | ||
| double retVal = 0.0d; | ||
| uint sumX = 0, sumXY = 0; | ||
| for (int i = 0; i < 256; i++) | ||
| #if SUPPORTS_RUNTIME_INTRINSICS | ||
| if (Avx2.IsSupported) | ||
| { | ||
| uint xi = (uint)x[i]; | ||
| if (xi != 0) | ||
| double retVal = 0.0d; | ||
| Span<int> tmp = stackalloc int[8]; | ||
| ref int xRef = ref MemoryMarshal.GetReference(x); | ||
| ref int yRef = ref MemoryMarshal.GetReference(y); | ||
| Vector256<int> sumXY256 = Vector256<int>.Zero; | ||
| Vector256<int> sumX256 = Vector256<int>.Zero; | ||
| ref int tmpRef = ref MemoryMarshal.GetReference(tmp); | ||
brianpopow marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| for (nint i = 0; i < 256; i += 8) | ||
| { | ||
| uint xy = xi + (uint)y[i]; | ||
| sumX += xi; | ||
| retVal -= FastSLog2(xi); | ||
| sumXY += xy; | ||
| retVal -= FastSLog2(xy); | ||
| Vector256<int> xVec = Unsafe.As<int, Vector256<int>>(ref Unsafe.Add(ref xRef, i)); | ||
| Vector256<int> yVec = Unsafe.As<int, Vector256<int>>(ref Unsafe.Add(ref yRef, i)); | ||
|
|
||
| // Check if any X is non-zero: this actually provides a speedup as X is usually sparse. | ||
| int mask = Avx2.MoveMask(Avx2.CompareEqual(xVec, Vector256<int>.Zero).AsByte()); | ||
| if (mask != -1) | ||
| { | ||
| Vector256<int> xy256 = Avx2.Add(xVec, yVec); | ||
| sumXY256 = Avx2.Add(sumXY256, xy256); | ||
| sumX256 = Avx2.Add(sumX256, xVec); | ||
|
|
||
| // Analyze the different X + Y. | ||
| Unsafe.As<int, Vector256<int>>(ref tmpRef) = xy256; | ||
| if (tmpRef != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)tmpRef); | ||
| if (Unsafe.Add(ref xRef, i) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i)); | ||
| } | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref tmpRef, 1) != 0) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: I have tried to put those repeating if statements into their own method calls, but profiling has shown that this actually makes it slower, even with aggressive inlining.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Couldn't it be a loop? It looks incremental 0-7 to me. However, my money says there's something clever that can be done with masking here to determine if each element != 0 and apply the diff as a single operation. I'm sure I've seen similar before.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SIMD? You can check each element in a vector without any if-checks at all. You can also remove the if-checks by precalculating log2 for each case and simply multiplying those log2 vectors by the comparison mask values. This may or may not be faster than the if-checks - it depends on the FastSLog2 implementation.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The original SSE2 version of this uses macros. I tried to keep it as similar as possible.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess a loop could further degrade the branch predictor history. Masking is still possible, but that would require a serious rewrite of the AVX branch.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I honestly have no idea how WebP works, but as far as I understand, the checked value is actually pretty random. Adding a for-loop on top of it may confuse the branch predictor with yet another stable (always true for 8 iterations) if-check - that is most likely what causes the performance drop. All of the extra if-checks can be removed via SIMD masks, but it is an open question whether that would be faster than the current implementation.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had a previous SSE2 version which used masks (see commit ed8bd61), but it was not better than the current approach. Maybe an AVX version of that could be better, but I am not sure it can really beat the current one.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's actually easier to describe my proposal in code than in human language. I'll probably try to implement it after this gets merged.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @br3aker thanks, but my advice would be to only do it if you think it would be easy for you and not too much work. As I said before, I am really unsure how much we can gain from this. |
||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 1)); | ||
| if (Unsafe.Add(ref xRef, i + 1) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 1)); | ||
| } | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref tmpRef, 2) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 2)); | ||
| if (Unsafe.Add(ref xRef, i + 2) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 2)); | ||
| } | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref tmpRef, 3) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 3)); | ||
| if (Unsafe.Add(ref xRef, i + 3) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 3)); | ||
| } | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref tmpRef, 4) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 4)); | ||
| if (Unsafe.Add(ref xRef, i + 4) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 4)); | ||
| } | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref tmpRef, 5) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 5)); | ||
| if (Unsafe.Add(ref xRef, i + 5) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 5)); | ||
| } | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref tmpRef, 6) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 6)); | ||
| if (Unsafe.Add(ref xRef, i + 6) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 6)); | ||
| } | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref tmpRef, 7) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref tmpRef, 7)); | ||
| if (Unsafe.Add(ref xRef, i + 7) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref xRef, i + 7)); | ||
| } | ||
| } | ||
| } | ||
| else | ||
| { | ||
| // X is fully 0, so only deal with Y. | ||
| sumXY256 = Avx2.Add(sumXY256, yVec); | ||
|
|
||
| if (Unsafe.Add(ref yRef, i) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i)); | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref yRef, i + 1) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 1)); | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref yRef, i + 2) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 2)); | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref yRef, i + 3) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 3)); | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref yRef, i + 4) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 4)); | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref yRef, i + 5) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 5)); | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref yRef, i + 6) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 6)); | ||
| } | ||
|
|
||
| if (Unsafe.Add(ref yRef, i + 7) != 0) | ||
| { | ||
| retVal -= FastSLog2((uint)Unsafe.Add(ref yRef, i + 7)); | ||
| } | ||
| } | ||
| } | ||
| else if (y[i] != 0) | ||
|
|
||
| // Sum up sumX256 to get sumX and sum up sumXY256 to get sumXY. | ||
| int sumX = Numerics.ReduceSum(sumX256); | ||
| int sumXY = Numerics.ReduceSum(sumXY256); | ||
|
|
||
| retVal += FastSLog2((uint)sumX) + FastSLog2((uint)sumXY); | ||
|
|
||
| return (float)retVal; | ||
| } | ||
| else | ||
| #endif | ||
| { | ||
| double retVal = 0.0d; | ||
| uint sumX = 0, sumXY = 0; | ||
| for (int i = 0; i < 256; i++) | ||
| { | ||
| sumXY += (uint)y[i]; | ||
| retVal -= FastSLog2((uint)y[i]); | ||
| uint xi = (uint)x[i]; | ||
| if (xi != 0) | ||
| { | ||
| uint xy = xi + (uint)y[i]; | ||
| sumX += xi; | ||
| retVal -= FastSLog2(xi); | ||
| sumXY += xy; | ||
| retVal -= FastSLog2(xy); | ||
| } | ||
| else if (y[i] != 0) | ||
| { | ||
| sumXY += (uint)y[i]; | ||
| retVal -= FastSLog2((uint)y[i]); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| retVal += FastSLog2(sumX) + FastSLog2(sumXY); | ||
| return (float)retVal; | ||
| retVal += FastSLog2(sumX) + FastSLog2(sumXY); | ||
| return (float)retVal; | ||
| } | ||
| } | ||
|
|
||
| [MethodImpl(InliningOptions.ShortMethod)] | ||
|
|
@@ -836,6 +993,7 @@ public static void ColorCodeToMultipliers(uint colorCode, ref Vp8LMultipliers m) | |
| private static float FastSLog2Slow(uint v) | ||
| { | ||
| DebugGuard.MustBeGreaterThanOrEqualTo<uint>(v, LogLookupIdxMax, nameof(v)); | ||
|
|
||
| if (v < ApproxLogWithCorrectionMax) | ||
| { | ||
| int logCnt = 0; | ||
|
|
@@ -865,7 +1023,7 @@ private static float FastSLog2Slow(uint v) | |
|
|
||
| private static float FastLog2Slow(uint v) | ||
| { | ||
| Guard.MustBeGreaterThanOrEqualTo(v, LogLookupIdxMax, nameof(v)); | ||
| DebugGuard.MustBeGreaterThanOrEqualTo<uint>(v, LogLookupIdxMax, nameof(v)); | ||
|
|
||
| if (v < ApproxLogWithCorrectionMax) | ||
| { | ||
|
|
||

Uh oh!
There was an error while loading. Please reload this page.