Skip to content

Commit b976f0a

Browse files
MinetaSvondele
authored and committed
Move DotProd code into optimized affine layer
This patch moves the DotProd code into the propagation function which has sequential access optimization. To prove the speedup, the comparison is done without the sparse layer. With the sparse layer the effect is marginal (GCC 0.3%, LLVM/Clang 0.1%). For both tests, binary is compiled with GCC 14.1. Each test had 50 runs. Sparse layer included: ``` speedup = +0.0030 P(speedup > 0) = 1.0000 ``` Sparse layer excluded: ``` speedup = +0.0561 P(speedup > 0) = 1.0000 ``` closes #5520 No functional change
1 parent 8e560c4 commit b976f0a

File tree

1 file changed

+28
-30
lines changed

1 file changed

+28
-30
lines changed

src/nnue/layers/affine_transform.h

+28-30
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,26 @@
3939

4040
namespace Stockfish::Eval::NNUE::Layers {
4141

42+
#if defined(USE_SSSE3) || defined(USE_NEON_DOTPROD)
43+
#define ENABLE_SEQ_OPT
44+
#endif
45+
4246
// Fallback implementation for older/other architectures.
4347
// Requires the input to be padded to at least 16 values.
44-
#if !defined(USE_SSSE3)
48+
#ifndef ENABLE_SEQ_OPT
49+
4550
template<IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions>
4651
static void affine_transform_non_ssse3(std::int32_t* output,
4752
const std::int8_t* weights,
4853
const std::int32_t* biases,
4954
const std::uint8_t* input) {
50-
#if defined(USE_SSE2) || defined(USE_NEON_DOTPROD) || defined(USE_NEON)
55+
#if defined(USE_SSE2) || defined(USE_NEON)
5156
#if defined(USE_SSE2)
5257
// At least a multiple of 16, with SSE2.
5358
constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
5459
const __m128i Zeros = _mm_setzero_si128();
5560
const auto inputVector = reinterpret_cast<const __m128i*>(input);
5661

57-
#elif defined(USE_NEON_DOTPROD)
58-
constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
59-
const auto inputVector = reinterpret_cast<const int8x16_t*>(input);
60-
6162
#elif defined(USE_NEON)
6263
constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
6364
const auto inputVector = reinterpret_cast<const int8x8_t*>(input);
@@ -91,16 +92,8 @@ static void affine_transform_non_ssse3(std::int32_t* output,
9192
sum = _mm_add_epi32(sum, sum_second_32);
9293
output[i] = _mm_cvtsi128_si32(sum);
9394

94-
#elif defined(USE_NEON_DOTPROD)
95-
int32x4_t sum = {biases[i]};
96-
const auto row = reinterpret_cast<const int8x16_t*>(&weights[offset]);
97-
for (IndexType j = 0; j < NumChunks; ++j)
98-
{
99-
sum = vdotq_s32(sum, inputVector[j], row[j]);
100-
}
101-
output[i] = vaddvq_s32(sum);
102-
10395
#elif defined(USE_NEON)
96+
10497
int32x4_t sum = {biases[i]};
10598
const auto row = reinterpret_cast<const int8x8_t*>(&weights[offset]);
10699
for (IndexType j = 0; j < NumChunks; ++j)
@@ -127,7 +120,8 @@ static void affine_transform_non_ssse3(std::int32_t* output,
127120
}
128121
#endif
129122
}
130-
#endif
123+
124+
#endif // !ENABLE_SEQ_OPT
131125

132126
template<IndexType InDims, IndexType OutDims>
133127
class AffineTransform {
@@ -162,7 +156,7 @@ class AffineTransform {
162156
}
163157

164158
static constexpr IndexType get_weight_index(IndexType i) {
165-
#if defined(USE_SSSE3)
159+
#ifdef ENABLE_SEQ_OPT
166160
return get_weight_index_scrambled(i);
167161
#else
168162
return i;
@@ -190,29 +184,28 @@ class AffineTransform {
190184
// Forward propagation
191185
void propagate(const InputType* input, OutputType* output) const {
192186

193-
#if defined(USE_SSSE3)
187+
#ifdef ENABLE_SEQ_OPT
194188

195189
if constexpr (OutputDimensions > 1)
196190
{
197-
198191
#if defined(USE_AVX512)
199192
using vec_t = __m512i;
200-
#define vec_setzero _mm512_setzero_si512
201193
#define vec_set_32 _mm512_set1_epi32
202194
#define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
203-
#define vec_hadd Simd::m512_hadd
204195
#elif defined(USE_AVX2)
205196
using vec_t = __m256i;
206-
#define vec_setzero _mm256_setzero_si256
207197
#define vec_set_32 _mm256_set1_epi32
208198
#define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
209-
#define vec_hadd Simd::m256_hadd
210199
#elif defined(USE_SSSE3)
211200
using vec_t = __m128i;
212-
#define vec_setzero _mm_setzero_si128
213201
#define vec_set_32 _mm_set1_epi32
214202
#define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
215-
#define vec_hadd Simd::m128_hadd
203+
#elif defined(USE_NEON_DOTPROD)
204+
using vec_t = int32x4_t;
205+
#define vec_set_32 vdupq_n_s32
206+
#define vec_add_dpbusd_32(acc, a, b) \
207+
Simd::dotprod_m128_add_dpbusd_epi32(acc, vreinterpretq_s8_s32(a), \
208+
vreinterpretq_s8_s32(b))
216209
#endif
217210

218211
static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType);
@@ -242,28 +235,33 @@ class AffineTransform {
242235
for (IndexType k = 0; k < NumRegs; ++k)
243236
outptr[k] = acc[k];
244237

245-
#undef vec_setzero
246238
#undef vec_set_32
247239
#undef vec_add_dpbusd_32
248-
#undef vec_hadd
249240
}
250241
else if constexpr (OutputDimensions == 1)
251242
{
252-
253243
// We cannot use AVX512 for the last layer because there are only 32 inputs
254244
// and the buffer is not padded to 64 elements.
255245
#if defined(USE_AVX2)
256246
using vec_t = __m256i;
257-
#define vec_setzero _mm256_setzero_si256
247+
#define vec_setzero() _mm256_setzero_si256()
258248
#define vec_set_32 _mm256_set1_epi32
259249
#define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
260250
#define vec_hadd Simd::m256_hadd
261251
#elif defined(USE_SSSE3)
262252
using vec_t = __m128i;
263-
#define vec_setzero _mm_setzero_si128
253+
#define vec_setzero() _mm_setzero_si128()
264254
#define vec_set_32 _mm_set1_epi32
265255
#define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
266256
#define vec_hadd Simd::m128_hadd
257+
#elif defined(USE_NEON_DOTPROD)
258+
using vec_t = int32x4_t;
259+
#define vec_setzero() vdupq_n_s32(0)
260+
#define vec_set_32 vdupq_n_s32
261+
#define vec_add_dpbusd_32(acc, a, b) \
262+
Simd::dotprod_m128_add_dpbusd_epi32(acc, vreinterpretq_s8_s32(a), \
263+
vreinterpretq_s8_s32(b))
264+
#define vec_hadd Simd::neon_m128_hadd
267265
#endif
268266

269267
const auto inputVector = reinterpret_cast<const vec_t*>(input);

0 commit comments

Comments (0)