diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index 6192cd51e32..b047f62c409 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -146,10 +146,10 @@ using psqt_vec_t = int32x4_t; #endif +// Compute optimal SIMD register count for feature transformer accumulation. +template +class SIMDTiling { #ifdef VECTOR - - // Compute optimal SIMD register count for feature transformer accumulation. - // We use __m* types as template arguments, which causes GCC to emit warnings // about losing some attribute information. This is irrelevant to us as we // only take their size, so the following pragma are harmless. @@ -158,33 +158,47 @@ using psqt_vec_t = int32x4_t; #pragma GCC diagnostic ignored "-Wignored-attributes" #endif -template -static constexpr int BestRegisterCount() { - #define RegisterSize sizeof(SIMDRegisterType) - #define LaneSize sizeof(LaneType) - - static_assert(RegisterSize >= LaneSize); - static_assert(MaxRegisters <= NumRegistersSIMD); - static_assert(MaxRegisters > 0); - static_assert(NumRegistersSIMD > 0); - static_assert(RegisterSize % LaneSize == 0); - static_assert((NumLanes * LaneSize) % RegisterSize == 0); - - const int ideal = (NumLanes * LaneSize) / RegisterSize; - if (ideal <= MaxRegisters) - return ideal; - - // Look for the largest divisor of the ideal register count that is smaller than MaxRegisters - for (int divisor = MaxRegisters; divisor > 1; --divisor) - if (ideal % divisor == 0) - return divisor; - - return 1; -} + template + static constexpr int BestRegisterCount() { + constexpr std::size_t RegisterSize = sizeof(SIMDRegisterType); + constexpr std::size_t LaneSize = sizeof(LaneType); + + static_assert(RegisterSize >= LaneSize); + static_assert(MaxRegisters <= NumRegistersSIMD); + static_assert(MaxRegisters > 0); + static_assert(NumRegistersSIMD > 0); + static_assert(RegisterSize % LaneSize == 0); + static_assert((NumLanes * LaneSize) % RegisterSize == 0); + + const int ideal = (NumLanes * LaneSize) / RegisterSize; + if (ideal <= MaxRegisters) + return ideal; + + // Look for the largest divisor of the ideal register count that is smaller than MaxRegisters + for (int divisor = MaxRegisters; divisor > 1; --divisor) + if (ideal % divisor == 0) + return divisor; + + return 1; + } + #if defined(__GNUC__) #pragma GCC diagnostic pop #endif + + public: + static constexpr int NumRegs = + BestRegisterCount(); + static constexpr int NumPsqtRegs = + BestRegisterCount(); + + static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2; + static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4; + + static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions"); + static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets"); #endif +}; // Input feature converter @@ -196,17 +210,7 @@ class FeatureTransformer { static constexpr IndexType HalfDimensions = TransformedFeatureDimensions; private: -#ifdef VECTOR - static constexpr int NumRegs = - BestRegisterCount(); - static constexpr int NumPsqtRegs = - BestRegisterCount(); - - static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2; - static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4; - static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions"); - static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets"); -#endif + using Tiling = SIMDTiling; public: // Output type @@ -478,8 +482,8 @@ class FeatureTransformer { #ifdef VECTOR // Gcc-10.2 unnecessarily spills AVX2 registers if this array // is defined in the VECTOR code below, once in each branch. - vec_t acc[NumRegs]; - psqt_vec_t psqt[NumPsqtRegs]; + vec_t acc[Tiling::NumRegs]; + psqt_vec_t psqt[Tiling::NumPsqtRegs]; #endif const Square ksq = pos.square(Perspective); @@ -504,14 +508,14 @@ class FeatureTransformer { #ifdef VECTOR if ((removed.size() == 1 || removed.size() == 2) && added.size() == 1) { - auto accIn = + auto* accIn = reinterpret_cast(&(computed->*accPtr).accumulation[Perspective][0]); - auto accOut = reinterpret_cast(&(next->*accPtr).accumulation[Perspective][0]); + auto* accOut = reinterpret_cast(&(next->*accPtr).accumulation[Perspective][0]); const IndexType offsetR0 = HalfDimensions * removed[0]; - auto columnR0 = reinterpret_cast(&weights[offsetR0]); + auto* columnR0 = reinterpret_cast(&weights[offsetR0]); const IndexType offsetA = HalfDimensions * added[0]; - auto columnA = reinterpret_cast(&weights[offsetA]); + auto* columnA = reinterpret_cast(&weights[offsetA]); if (removed.size() == 1) { @@ -521,22 +525,22 @@ class FeatureTransformer { else { const IndexType offsetR1 = HalfDimensions * removed[1]; - auto columnR1 = reinterpret_cast(&weights[offsetR1]); + auto* columnR1 = reinterpret_cast(&weights[offsetR1]); for (IndexType i = 0; i < HalfDimensions * sizeof(WeightType) / sizeof(vec_t); ++i) accOut[i] = vec_sub_16(vec_add_16(accIn[i], columnA[i]), vec_add_16(columnR0[i], columnR1[i])); } - auto accPsqtIn = reinterpret_cast( + auto* accPsqtIn = reinterpret_cast( &(computed->*accPtr).psqtAccumulation[Perspective][0]); - auto accPsqtOut = + auto* accPsqtOut = reinterpret_cast(&(next->*accPtr).psqtAccumulation[Perspective][0]); const IndexType offsetPsqtR0 = PSQTBuckets * removed[0]; - auto columnPsqtR0 = reinterpret_cast(&psqtWeights[offsetPsqtR0]); + auto* columnPsqtR0 = reinterpret_cast(&psqtWeights[offsetPsqtR0]); const IndexType offsetPsqtA = PSQTBuckets * added[0]; - auto columnPsqtA = reinterpret_cast(&psqtWeights[offsetPsqtA]); + auto* columnPsqtA = reinterpret_cast(&psqtWeights[offsetPsqtA]); if (removed.size() == 1) { @@ -548,7 +552,8 @@ class FeatureTransformer { else { const IndexType offsetPsqtR1 = PSQTBuckets * removed[1]; - auto columnPsqtR1 = reinterpret_cast(&psqtWeights[offsetPsqtR1]); + auto* columnPsqtR1 = + reinterpret_cast(&psqtWeights[offsetPsqtR1]); for (std::size_t i = 0; i < PSQTBuckets * sizeof(PSQTWeightType) / sizeof(psqt_vec_t); ++i) @@ -559,69 +564,69 @@ class FeatureTransformer { } else { - for (IndexType i = 0; i < HalfDimensions / TileHeight; ++i) + for (IndexType i = 0; i < HalfDimensions / Tiling::TileHeight; ++i) { // Load accumulator - auto accTileIn = reinterpret_cast( - &(computed->*accPtr).accumulation[Perspective][i * TileHeight]); - for (IndexType j = 0; j < NumRegs; ++j) + auto* accTileIn = reinterpret_cast( + &(computed->*accPtr).accumulation[Perspective][i * Tiling::TileHeight]); + for (IndexType j = 0; j < Tiling::NumRegs; ++j) acc[j] = vec_load(&accTileIn[j]); // Difference calculation for the deactivated features for (const auto index : removed) { - const IndexType offset = HalfDimensions * index + i * TileHeight; - auto column = reinterpret_cast(&weights[offset]); - for (IndexType j = 0; j < NumRegs; ++j) + const IndexType offset = HalfDimensions * index + i * Tiling::TileHeight; + auto* column = reinterpret_cast(&weights[offset]); + for (IndexType j = 0; j < Tiling::NumRegs; ++j) acc[j] = vec_sub_16(acc[j], column[j]); } // Difference calculation for the activated features for (const auto index : added) { - const IndexType offset = HalfDimensions * index + i * TileHeight; - auto column = reinterpret_cast(&weights[offset]); - for (IndexType j = 0; j < NumRegs; ++j) + const IndexType offset = HalfDimensions * index + i * Tiling::TileHeight; + auto* column = reinterpret_cast(&weights[offset]); + for (IndexType j = 0; j < Tiling::NumRegs; ++j) acc[j] = vec_add_16(acc[j], column[j]); } // Store accumulator - auto accTileOut = reinterpret_cast( - &(next->*accPtr).accumulation[Perspective][i * TileHeight]); - for (IndexType j = 0; j < NumRegs; ++j) + auto* accTileOut = reinterpret_cast( + &(next->*accPtr).accumulation[Perspective][i * Tiling::TileHeight]); + for (IndexType j = 0; j < Tiling::NumRegs; ++j) vec_store(&accTileOut[j], acc[j]); } - for (IndexType i = 0; i < PSQTBuckets / PsqtTileHeight; ++i) + for (IndexType i = 0; i < PSQTBuckets / Tiling::PsqtTileHeight; ++i) { // Load accumulator - auto accTilePsqtIn = reinterpret_cast( - &(computed->*accPtr).psqtAccumulation[Perspective][i * PsqtTileHeight]); - for (std::size_t j = 0; j < NumPsqtRegs; ++j) + auto* accTilePsqtIn = reinterpret_cast( + &(computed->*accPtr).psqtAccumulation[Perspective][i * Tiling::PsqtTileHeight]); + for (std::size_t j = 0; j < Tiling::NumPsqtRegs; ++j) psqt[j] = vec_load_psqt(&accTilePsqtIn[j]); // Difference calculation for the deactivated features for (const auto index : removed) { - const IndexType offset = PSQTBuckets * index + i * PsqtTileHeight; - auto columnPsqt = reinterpret_cast(&psqtWeights[offset]); - for (std::size_t j = 0; j < NumPsqtRegs; ++j) + const IndexType offset = PSQTBuckets * index + i * Tiling::PsqtTileHeight; + auto* columnPsqt = reinterpret_cast(&psqtWeights[offset]); + for (std::size_t j = 0; j < Tiling::NumPsqtRegs; ++j) psqt[j] = vec_sub_psqt_32(psqt[j], columnPsqt[j]); } // Difference calculation for the activated features for (const auto index : added) { - const IndexType offset = PSQTBuckets * index + i * PsqtTileHeight; - auto columnPsqt = reinterpret_cast(&psqtWeights[offset]); - for (std::size_t j = 0; j < NumPsqtRegs; ++j) + const IndexType offset = PSQTBuckets * index + i * Tiling::PsqtTileHeight; + auto* columnPsqt = reinterpret_cast(&psqtWeights[offset]); + for (std::size_t j = 0; j < Tiling::NumPsqtRegs; ++j) psqt[j] = vec_add_psqt_32(psqt[j], columnPsqt[j]); } // Store accumulator - auto accTilePsqtOut = reinterpret_cast( - &(next->*accPtr).psqtAccumulation[Perspective][i * PsqtTileHeight]); - for (std::size_t j = 0; j < NumPsqtRegs; ++j) + auto* accTilePsqtOut = reinterpret_cast( + &(next->*accPtr).psqtAccumulation[Perspective][i * Tiling::PsqtTileHeight]); + for (std::size_t j = 0; j < Tiling::NumPsqtRegs; ++j) vec_store_psqt(&accTilePsqtOut[j], psqt[j]); } } @@ -700,88 +705,88 @@ class FeatureTransformer { accumulator.computed[Perspective] = true; #ifdef VECTOR - vec_t acc[NumRegs]; - psqt_vec_t psqt[NumPsqtRegs]; + vec_t acc[Tiling::NumRegs]; + psqt_vec_t psqt[Tiling::NumPsqtRegs]; - for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j) + for (IndexType j = 0; j < HalfDimensions / Tiling::TileHeight; ++j) { - auto accTile = - reinterpret_cast(&accumulator.accumulation[Perspective][j * TileHeight]); - auto entryTile = reinterpret_cast(&entry.accumulation[j * TileHeight]); + auto* accTile = reinterpret_cast( + &accumulator.accumulation[Perspective][j * Tiling::TileHeight]); + auto* entryTile = reinterpret_cast(&entry.accumulation[j * Tiling::TileHeight]); - for (IndexType k = 0; k < NumRegs; ++k) + for (IndexType k = 0; k < Tiling::NumRegs; ++k) acc[k] = entryTile[k]; - int i = 0; - for (; i < int(std::min(removed.size(), added.size())); ++i) + std::size_t i = 0; + for (; i < std::min(removed.size(), added.size()); ++i) { IndexType indexR = removed[i]; - const IndexType offsetR = HalfDimensions * indexR + j * TileHeight; - auto columnR = reinterpret_cast(&weights[offsetR]); + const IndexType offsetR = HalfDimensions * indexR + j * Tiling::TileHeight; + auto* columnR = reinterpret_cast(&weights[offsetR]); IndexType indexA = added[i]; - const IndexType offsetA = HalfDimensions * indexA + j * TileHeight; - auto columnA = reinterpret_cast(&weights[offsetA]); + const IndexType offsetA = HalfDimensions * indexA + j * Tiling::TileHeight; + auto* columnA = reinterpret_cast(&weights[offsetA]); - for (unsigned k = 0; k < NumRegs; ++k) + for (IndexType k = 0; k < Tiling::NumRegs; ++k) acc[k] = vec_add_16(acc[k], vec_sub_16(columnA[k], columnR[k])); } - for (; i < int(removed.size()); ++i) + for (; i < removed.size(); ++i) { IndexType index = removed[i]; - const IndexType offset = HalfDimensions * index + j * TileHeight; - auto column = reinterpret_cast(&weights[offset]); + const IndexType offset = HalfDimensions * index + j * Tiling::TileHeight; + auto* column = reinterpret_cast(&weights[offset]); - for (unsigned k = 0; k < NumRegs; ++k) + for (IndexType k = 0; k < Tiling::NumRegs; ++k) acc[k] = vec_sub_16(acc[k], column[k]); } - for (; i < int(added.size()); ++i) + for (; i < added.size(); ++i) { IndexType index = added[i]; - const IndexType offset = HalfDimensions * index + j * TileHeight; - auto column = reinterpret_cast(&weights[offset]); + const IndexType offset = HalfDimensions * index + j * Tiling::TileHeight; + auto* column = reinterpret_cast(&weights[offset]); - for (unsigned k = 0; k < NumRegs; ++k) + for (IndexType k = 0; k < Tiling::NumRegs; ++k) acc[k] = vec_add_16(acc[k], column[k]); } - for (IndexType k = 0; k < NumRegs; k++) + for (IndexType k = 0; k < Tiling::NumRegs; k++) vec_store(&entryTile[k], acc[k]); - for (IndexType k = 0; k < NumRegs; k++) + for (IndexType k = 0; k < Tiling::NumRegs; k++) vec_store(&accTile[k], acc[k]); } - for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j) + for (IndexType j = 0; j < PSQTBuckets / Tiling::PsqtTileHeight; ++j) { - auto accTilePsqt = reinterpret_cast( - &accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]); - auto entryTilePsqt = - reinterpret_cast(&entry.psqtAccumulation[j * PsqtTileHeight]); + auto* accTilePsqt = reinterpret_cast( + &accumulator.psqtAccumulation[Perspective][j * Tiling::PsqtTileHeight]); + auto* entryTilePsqt = + reinterpret_cast(&entry.psqtAccumulation[j * Tiling::PsqtTileHeight]); - for (std::size_t k = 0; k < NumPsqtRegs; ++k) + for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k) psqt[k] = entryTilePsqt[k]; - for (int i = 0; i < int(removed.size()); ++i) + for (std::size_t i = 0; i < removed.size(); ++i) { IndexType index = removed[i]; - const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight; - auto columnPsqt = reinterpret_cast(&psqtWeights[offset]); + const IndexType offset = PSQTBuckets * index + j * Tiling::PsqtTileHeight; + auto* columnPsqt = reinterpret_cast(&psqtWeights[offset]); - for (std::size_t k = 0; k < NumPsqtRegs; ++k) + for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k) psqt[k] = vec_sub_psqt_32(psqt[k], columnPsqt[k]); } - for (int i = 0; i < int(added.size()); ++i) + for (std::size_t i = 0; i < added.size(); ++i) { IndexType index = added[i]; - const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight; - auto columnPsqt = reinterpret_cast(&psqtWeights[offset]); + const IndexType offset = PSQTBuckets * index + j * Tiling::PsqtTileHeight; + auto* columnPsqt = reinterpret_cast(&psqtWeights[offset]); - for (std::size_t k = 0; k < NumPsqtRegs; ++k) + for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k) psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]); } - for (std::size_t k = 0; k < NumPsqtRegs; ++k) + for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k) vec_store_psqt(&entryTilePsqt[k], psqt[k]); - for (std::size_t k = 0; k < NumPsqtRegs; ++k) + for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k) vec_store_psqt(&accTilePsqt[k], psqt[k]); }