Skip to content

Commit

Permalink
Small cleanup of nnue_feature_transformer.h
Browse files Browse the repository at this point in the history
Passed Non-regression STC:
LLR: 2.94 (-2.94,2.94) <-1.75,0.25>
Total: 285760 W: 73716 L: 73768 D: 138276
Ptnml(0-2): 777, 30775, 79851, 30677, 800
https://tests.stockfishchess.org/tests/view/676f78681a2f267f205485aa

closes #5745

No functional change
  • Loading branch information
xu-shawn authored and Disservin committed Jan 12, 2025
1 parent d1a1ff4 commit c47e6fc
Showing 1 changed file with 120 additions and 115 deletions.
235 changes: 120 additions & 115 deletions src/nnue/nnue_feature_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ using psqt_vec_t = int32x4_t;
#endif


// Compute optimal SIMD register count for feature transformer accumulation.
template<IndexType TransformedFeatureWidth, IndexType HalfDimensions>
class SIMDTiling {
#ifdef VECTOR

// Compute optimal SIMD register count for feature transformer accumulation.

// We use __m* types as template arguments, which causes GCC to emit warnings
// about losing some attribute information. This is irrelevant to us as we
// only take their size, so the following pragma are harmless.
Expand All @@ -158,33 +158,47 @@ using psqt_vec_t = int32x4_t;
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif

template<typename SIMDRegisterType, typename LaneType, int NumLanes, int MaxRegisters>
static constexpr int BestRegisterCount() {
#define RegisterSize sizeof(SIMDRegisterType)
#define LaneSize sizeof(LaneType)

static_assert(RegisterSize >= LaneSize);
static_assert(MaxRegisters <= NumRegistersSIMD);
static_assert(MaxRegisters > 0);
static_assert(NumRegistersSIMD > 0);
static_assert(RegisterSize % LaneSize == 0);
static_assert((NumLanes * LaneSize) % RegisterSize == 0);

const int ideal = (NumLanes * LaneSize) / RegisterSize;
if (ideal <= MaxRegisters)
return ideal;

// Look for the largest divisor of the ideal register count that is smaller than MaxRegisters
for (int divisor = MaxRegisters; divisor > 1; --divisor)
if (ideal % divisor == 0)
return divisor;

return 1;
}
template<typename SIMDRegisterType, typename LaneType, int NumLanes, int MaxRegisters>
static constexpr int BestRegisterCount() {
constexpr std::size_t RegisterSize = sizeof(SIMDRegisterType);
constexpr std::size_t LaneSize = sizeof(LaneType);

static_assert(RegisterSize >= LaneSize);
static_assert(MaxRegisters <= NumRegistersSIMD);
static_assert(MaxRegisters > 0);
static_assert(NumRegistersSIMD > 0);
static_assert(RegisterSize % LaneSize == 0);
static_assert((NumLanes * LaneSize) % RegisterSize == 0);

const int ideal = (NumLanes * LaneSize) / RegisterSize;
if (ideal <= MaxRegisters)
return ideal;

// Look for the largest divisor of the ideal register count that is smaller than MaxRegisters
for (int divisor = MaxRegisters; divisor > 1; --divisor)
if (ideal % divisor == 0)
return divisor;

return 1;
}

#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

public:
static constexpr int NumRegs =
BestRegisterCount<vec_t, WeightType, TransformedFeatureWidth, NumRegistersSIMD>();
static constexpr int NumPsqtRegs =
BestRegisterCount<psqt_vec_t, PSQTWeightType, PSQTBuckets, NumRegistersSIMD>();

static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2;
static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4;

static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions");
static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets");
#endif
};


// Input feature converter
Expand All @@ -196,17 +210,7 @@ class FeatureTransformer {
static constexpr IndexType HalfDimensions = TransformedFeatureDimensions;

private:
#ifdef VECTOR
static constexpr int NumRegs =
BestRegisterCount<vec_t, WeightType, TransformedFeatureDimensions, NumRegistersSIMD>();
static constexpr int NumPsqtRegs =
BestRegisterCount<psqt_vec_t, PSQTWeightType, PSQTBuckets, NumRegistersSIMD>();

static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2;
static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4;
static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions");
static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets");
#endif
using Tiling = SIMDTiling<TransformedFeatureDimensions, HalfDimensions>;

public:
// Output type
Expand Down Expand Up @@ -478,8 +482,8 @@ class FeatureTransformer {
#ifdef VECTOR
// Gcc-10.2 unnecessarily spills AVX2 registers if this array
// is defined in the VECTOR code below, once in each branch.
vec_t acc[NumRegs];
psqt_vec_t psqt[NumPsqtRegs];
vec_t acc[Tiling::NumRegs];
psqt_vec_t psqt[Tiling::NumPsqtRegs];
#endif

const Square ksq = pos.square<KING>(Perspective);
Expand All @@ -504,14 +508,14 @@ class FeatureTransformer {
#ifdef VECTOR
if ((removed.size() == 1 || removed.size() == 2) && added.size() == 1)
{
auto accIn =
auto* accIn =
reinterpret_cast<const vec_t*>(&(computed->*accPtr).accumulation[Perspective][0]);
auto accOut = reinterpret_cast<vec_t*>(&(next->*accPtr).accumulation[Perspective][0]);
auto* accOut = reinterpret_cast<vec_t*>(&(next->*accPtr).accumulation[Perspective][0]);

const IndexType offsetR0 = HalfDimensions * removed[0];
auto columnR0 = reinterpret_cast<const vec_t*>(&weights[offsetR0]);
auto* columnR0 = reinterpret_cast<const vec_t*>(&weights[offsetR0]);
const IndexType offsetA = HalfDimensions * added[0];
auto columnA = reinterpret_cast<const vec_t*>(&weights[offsetA]);
auto* columnA = reinterpret_cast<const vec_t*>(&weights[offsetA]);

if (removed.size() == 1)
{
Expand All @@ -521,22 +525,22 @@ class FeatureTransformer {
else
{
const IndexType offsetR1 = HalfDimensions * removed[1];
auto columnR1 = reinterpret_cast<const vec_t*>(&weights[offsetR1]);
auto* columnR1 = reinterpret_cast<const vec_t*>(&weights[offsetR1]);

for (IndexType i = 0; i < HalfDimensions * sizeof(WeightType) / sizeof(vec_t); ++i)
accOut[i] = vec_sub_16(vec_add_16(accIn[i], columnA[i]),
vec_add_16(columnR0[i], columnR1[i]));
}

auto accPsqtIn = reinterpret_cast<const psqt_vec_t*>(
auto* accPsqtIn = reinterpret_cast<const psqt_vec_t*>(
&(computed->*accPtr).psqtAccumulation[Perspective][0]);
auto accPsqtOut =
auto* accPsqtOut =
reinterpret_cast<psqt_vec_t*>(&(next->*accPtr).psqtAccumulation[Perspective][0]);

const IndexType offsetPsqtR0 = PSQTBuckets * removed[0];
auto columnPsqtR0 = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtR0]);
auto* columnPsqtR0 = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtR0]);
const IndexType offsetPsqtA = PSQTBuckets * added[0];
auto columnPsqtA = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtA]);
auto* columnPsqtA = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtA]);

if (removed.size() == 1)
{
Expand All @@ -548,7 +552,8 @@ class FeatureTransformer {
else
{
const IndexType offsetPsqtR1 = PSQTBuckets * removed[1];
auto columnPsqtR1 = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtR1]);
auto* columnPsqtR1 =
reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtR1]);

for (std::size_t i = 0;
i < PSQTBuckets * sizeof(PSQTWeightType) / sizeof(psqt_vec_t); ++i)
Expand All @@ -559,69 +564,69 @@ class FeatureTransformer {
}
else
{
for (IndexType i = 0; i < HalfDimensions / TileHeight; ++i)
for (IndexType i = 0; i < HalfDimensions / Tiling::TileHeight; ++i)
{
// Load accumulator
auto accTileIn = reinterpret_cast<const vec_t*>(
&(computed->*accPtr).accumulation[Perspective][i * TileHeight]);
for (IndexType j = 0; j < NumRegs; ++j)
auto* accTileIn = reinterpret_cast<const vec_t*>(
&(computed->*accPtr).accumulation[Perspective][i * Tiling::TileHeight]);
for (IndexType j = 0; j < Tiling::NumRegs; ++j)
acc[j] = vec_load(&accTileIn[j]);

// Difference calculation for the deactivated features
for (const auto index : removed)
{
const IndexType offset = HalfDimensions * index + i * TileHeight;
auto column = reinterpret_cast<const vec_t*>(&weights[offset]);
for (IndexType j = 0; j < NumRegs; ++j)
const IndexType offset = HalfDimensions * index + i * Tiling::TileHeight;
auto* column = reinterpret_cast<const vec_t*>(&weights[offset]);
for (IndexType j = 0; j < Tiling::NumRegs; ++j)
acc[j] = vec_sub_16(acc[j], column[j]);
}

// Difference calculation for the activated features
for (const auto index : added)
{
const IndexType offset = HalfDimensions * index + i * TileHeight;
auto column = reinterpret_cast<const vec_t*>(&weights[offset]);
for (IndexType j = 0; j < NumRegs; ++j)
const IndexType offset = HalfDimensions * index + i * Tiling::TileHeight;
auto* column = reinterpret_cast<const vec_t*>(&weights[offset]);
for (IndexType j = 0; j < Tiling::NumRegs; ++j)
acc[j] = vec_add_16(acc[j], column[j]);
}

// Store accumulator
auto accTileOut = reinterpret_cast<vec_t*>(
&(next->*accPtr).accumulation[Perspective][i * TileHeight]);
for (IndexType j = 0; j < NumRegs; ++j)
auto* accTileOut = reinterpret_cast<vec_t*>(
&(next->*accPtr).accumulation[Perspective][i * Tiling::TileHeight]);
for (IndexType j = 0; j < Tiling::NumRegs; ++j)
vec_store(&accTileOut[j], acc[j]);
}

for (IndexType i = 0; i < PSQTBuckets / PsqtTileHeight; ++i)
for (IndexType i = 0; i < PSQTBuckets / Tiling::PsqtTileHeight; ++i)
{
// Load accumulator
auto accTilePsqtIn = reinterpret_cast<const psqt_vec_t*>(
&(computed->*accPtr).psqtAccumulation[Perspective][i * PsqtTileHeight]);
for (std::size_t j = 0; j < NumPsqtRegs; ++j)
auto* accTilePsqtIn = reinterpret_cast<const psqt_vec_t*>(
&(computed->*accPtr).psqtAccumulation[Perspective][i * Tiling::PsqtTileHeight]);
for (std::size_t j = 0; j < Tiling::NumPsqtRegs; ++j)
psqt[j] = vec_load_psqt(&accTilePsqtIn[j]);

// Difference calculation for the deactivated features
for (const auto index : removed)
{
const IndexType offset = PSQTBuckets * index + i * PsqtTileHeight;
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
for (std::size_t j = 0; j < NumPsqtRegs; ++j)
const IndexType offset = PSQTBuckets * index + i * Tiling::PsqtTileHeight;
auto* columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
for (std::size_t j = 0; j < Tiling::NumPsqtRegs; ++j)
psqt[j] = vec_sub_psqt_32(psqt[j], columnPsqt[j]);
}

// Difference calculation for the activated features
for (const auto index : added)
{
const IndexType offset = PSQTBuckets * index + i * PsqtTileHeight;
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
for (std::size_t j = 0; j < NumPsqtRegs; ++j)
const IndexType offset = PSQTBuckets * index + i * Tiling::PsqtTileHeight;
auto* columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
for (std::size_t j = 0; j < Tiling::NumPsqtRegs; ++j)
psqt[j] = vec_add_psqt_32(psqt[j], columnPsqt[j]);
}

// Store accumulator
auto accTilePsqtOut = reinterpret_cast<psqt_vec_t*>(
&(next->*accPtr).psqtAccumulation[Perspective][i * PsqtTileHeight]);
for (std::size_t j = 0; j < NumPsqtRegs; ++j)
auto* accTilePsqtOut = reinterpret_cast<psqt_vec_t*>(
&(next->*accPtr).psqtAccumulation[Perspective][i * Tiling::PsqtTileHeight]);
for (std::size_t j = 0; j < Tiling::NumPsqtRegs; ++j)
vec_store_psqt(&accTilePsqtOut[j], psqt[j]);
}
}
Expand Down Expand Up @@ -700,88 +705,88 @@ class FeatureTransformer {
accumulator.computed[Perspective] = true;

#ifdef VECTOR
vec_t acc[NumRegs];
psqt_vec_t psqt[NumPsqtRegs];
vec_t acc[Tiling::NumRegs];
psqt_vec_t psqt[Tiling::NumPsqtRegs];

for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j)
for (IndexType j = 0; j < HalfDimensions / Tiling::TileHeight; ++j)
{
auto accTile =
reinterpret_cast<vec_t*>(&accumulator.accumulation[Perspective][j * TileHeight]);
auto entryTile = reinterpret_cast<vec_t*>(&entry.accumulation[j * TileHeight]);
auto* accTile = reinterpret_cast<vec_t*>(
&accumulator.accumulation[Perspective][j * Tiling::TileHeight]);
auto* entryTile = reinterpret_cast<vec_t*>(&entry.accumulation[j * Tiling::TileHeight]);

for (IndexType k = 0; k < NumRegs; ++k)
for (IndexType k = 0; k < Tiling::NumRegs; ++k)
acc[k] = entryTile[k];

int i = 0;
for (; i < int(std::min(removed.size(), added.size())); ++i)
std::size_t i = 0;
for (; i < std::min(removed.size(), added.size()); ++i)
{
IndexType indexR = removed[i];
const IndexType offsetR = HalfDimensions * indexR + j * TileHeight;
auto columnR = reinterpret_cast<const vec_t*>(&weights[offsetR]);
const IndexType offsetR = HalfDimensions * indexR + j * Tiling::TileHeight;
auto* columnR = reinterpret_cast<const vec_t*>(&weights[offsetR]);
IndexType indexA = added[i];
const IndexType offsetA = HalfDimensions * indexA + j * TileHeight;
auto columnA = reinterpret_cast<const vec_t*>(&weights[offsetA]);
const IndexType offsetA = HalfDimensions * indexA + j * Tiling::TileHeight;
auto* columnA = reinterpret_cast<const vec_t*>(&weights[offsetA]);

for (unsigned k = 0; k < NumRegs; ++k)
for (IndexType k = 0; k < Tiling::NumRegs; ++k)
acc[k] = vec_add_16(acc[k], vec_sub_16(columnA[k], columnR[k]));
}
for (; i < int(removed.size()); ++i)
for (; i < removed.size(); ++i)
{
IndexType index = removed[i];
const IndexType offset = HalfDimensions * index + j * TileHeight;
auto column = reinterpret_cast<const vec_t*>(&weights[offset]);
const IndexType offset = HalfDimensions * index + j * Tiling::TileHeight;
auto* column = reinterpret_cast<const vec_t*>(&weights[offset]);

for (unsigned k = 0; k < NumRegs; ++k)
for (IndexType k = 0; k < Tiling::NumRegs; ++k)
acc[k] = vec_sub_16(acc[k], column[k]);
}
for (; i < int(added.size()); ++i)
for (; i < added.size(); ++i)
{
IndexType index = added[i];
const IndexType offset = HalfDimensions * index + j * TileHeight;
auto column = reinterpret_cast<const vec_t*>(&weights[offset]);
const IndexType offset = HalfDimensions * index + j * Tiling::TileHeight;
auto* column = reinterpret_cast<const vec_t*>(&weights[offset]);

for (unsigned k = 0; k < NumRegs; ++k)
for (IndexType k = 0; k < Tiling::NumRegs; ++k)
acc[k] = vec_add_16(acc[k], column[k]);
}

for (IndexType k = 0; k < NumRegs; k++)
for (IndexType k = 0; k < Tiling::NumRegs; k++)
vec_store(&entryTile[k], acc[k]);
for (IndexType k = 0; k < NumRegs; k++)
for (IndexType k = 0; k < Tiling::NumRegs; k++)
vec_store(&accTile[k], acc[k]);
}

for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j)
for (IndexType j = 0; j < PSQTBuckets / Tiling::PsqtTileHeight; ++j)
{
auto accTilePsqt = reinterpret_cast<psqt_vec_t*>(
&accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]);
auto entryTilePsqt =
reinterpret_cast<psqt_vec_t*>(&entry.psqtAccumulation[j * PsqtTileHeight]);
auto* accTilePsqt = reinterpret_cast<psqt_vec_t*>(
&accumulator.psqtAccumulation[Perspective][j * Tiling::PsqtTileHeight]);
auto* entryTilePsqt =
reinterpret_cast<psqt_vec_t*>(&entry.psqtAccumulation[j * Tiling::PsqtTileHeight]);

for (std::size_t k = 0; k < NumPsqtRegs; ++k)
for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)
psqt[k] = entryTilePsqt[k];

for (int i = 0; i < int(removed.size()); ++i)
for (std::size_t i = 0; i < removed.size(); ++i)
{
IndexType index = removed[i];
const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
const IndexType offset = PSQTBuckets * index + j * Tiling::PsqtTileHeight;
auto* columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);

for (std::size_t k = 0; k < NumPsqtRegs; ++k)
for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)
psqt[k] = vec_sub_psqt_32(psqt[k], columnPsqt[k]);
}
for (int i = 0; i < int(added.size()); ++i)
for (std::size_t i = 0; i < added.size(); ++i)
{
IndexType index = added[i];
const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
const IndexType offset = PSQTBuckets * index + j * Tiling::PsqtTileHeight;
auto* columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);

for (std::size_t k = 0; k < NumPsqtRegs; ++k)
for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)
psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]);
}

for (std::size_t k = 0; k < NumPsqtRegs; ++k)
for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)
vec_store_psqt(&entryTilePsqt[k], psqt[k]);
for (std::size_t k = 0; k < NumPsqtRegs; ++k)
for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)
vec_store_psqt(&accTilePsqt[k], psqt[k]);
}

Expand Down

0 comments on commit c47e6fc

Please sign in to comment.