Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small cleanup of nnue_feature_transformer.h #5745

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 120 additions & 115 deletions src/nnue/nnue_feature_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ using psqt_vec_t = int32x4_t;
#endif


// Compute optimal SIMD register count for feature transformer accumulation.
template<IndexType TransformedFeatureWidth, IndexType HalfDimensions>
class SIMDTiling {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I dunno about making this a class; I don't really see a need to extract this here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also if VECTOR isn't defined we will just have an empty class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also if VECTOR isn't defined we will just have an empty class?

Yeah, the goal was to remove as many preprocessor directives from the main logic as possible.

#ifdef VECTOR

// Compute optimal SIMD register count for feature transformer accumulation.

// We use __m* types as template arguments, which causes GCC to emit warnings
// about losing some attribute information. This is irrelevant to us as we
// only take their size, so the following pragmas are harmless.
Expand All @@ -158,33 +158,47 @@ using psqt_vec_t = int32x4_t;
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif

// Returns how many SIMD registers of type SIMDRegisterType should be used to
// process NumLanes values of type LaneType, using at most MaxRegisters.
//
// The "ideal" count is the number of registers needed to hold all lanes at
// once. If that fits within MaxRegisters it is returned as is; otherwise the
// largest divisor of the ideal count that does not exceed MaxRegisters is
// returned, with 1 as the fallback.
template<typename SIMDRegisterType, typename LaneType, int NumLanes, int MaxRegisters>
static constexpr int BestRegisterCount() {
    // Scoped constants instead of #define: a macro defined here would stay
    // visible for the rest of the translation unit and silently rewrite any
    // later identifier named RegisterSize or LaneSize.
    constexpr auto RegisterSize = sizeof(SIMDRegisterType);
    constexpr auto LaneSize     = sizeof(LaneType);

    static_assert(RegisterSize >= LaneSize);
    static_assert(MaxRegisters <= NumRegistersSIMD);
    static_assert(MaxRegisters > 0);
    static_assert(NumRegistersSIMD > 0);
    static_assert(RegisterSize % LaneSize == 0);
    static_assert((NumLanes * LaneSize) % RegisterSize == 0);

    // Number of registers needed to hold every lane simultaneously.
    const int ideal = (NumLanes * LaneSize) / RegisterSize;
    if (ideal <= MaxRegisters)
        return ideal;

    // Look for the largest divisor of the ideal register count that does not
    // exceed MaxRegisters (MaxRegisters itself is a valid candidate).
    for (int divisor = MaxRegisters; divisor > 1; --divisor)
        if (ideal % divisor == 0)
            return divisor;

    return 1;
}
// Picks the number of SIMD registers (of type SIMDRegisterType) to use when
// processing NumLanes elements of type LaneType with a budget of MaxRegisters.
//
// The result is the register count that covers all lanes when it fits in the
// budget; otherwise it is the largest divisor of that count which is no
// greater than MaxRegisters, falling back to 1.
template<typename SIMDRegisterType, typename LaneType, int NumLanes, int MaxRegisters>
static constexpr int BestRegisterCount() {
    constexpr std::size_t registerBytes = sizeof(SIMDRegisterType);
    constexpr std::size_t laneBytes     = sizeof(LaneType);

    // Sanity checks on the register budget.
    static_assert(NumRegistersSIMD > 0);
    static_assert(MaxRegisters > 0);
    static_assert(MaxRegisters <= NumRegistersSIMD);

    // A register must hold a whole number of lanes, and the lanes must fill
    // a whole number of registers.
    static_assert(registerBytes >= laneBytes);
    static_assert(registerBytes % laneBytes == 0);
    static_assert((NumLanes * laneBytes) % registerBytes == 0);

    // Registers required to keep every lane resident at once.
    const int needed = (NumLanes * laneBytes) / registerBytes;

    // Start from the full requirement when it fits, else from the budget,
    // then walk down until the candidate divides the requirement evenly.
    // When the requirement fits, it trivially divides itself and is returned.
    int candidate = needed <= MaxRegisters ? needed : MaxRegisters;
    while (candidate > 1 && needed % candidate != 0)
        --candidate;

    return candidate;
}

#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

public:
static constexpr int NumRegs =
BestRegisterCount<vec_t, WeightType, TransformedFeatureWidth, NumRegistersSIMD>();
static constexpr int NumPsqtRegs =
BestRegisterCount<psqt_vec_t, PSQTWeightType, PSQTBuckets, NumRegistersSIMD>();

static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2;
static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4;

static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions");
static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets");
#endif
};


// Input feature converter
Expand All @@ -196,17 +210,7 @@ class FeatureTransformer {
static constexpr IndexType HalfDimensions = TransformedFeatureDimensions;

private:
#ifdef VECTOR
static constexpr int NumRegs =
BestRegisterCount<vec_t, WeightType, TransformedFeatureDimensions, NumRegistersSIMD>();
static constexpr int NumPsqtRegs =
BestRegisterCount<psqt_vec_t, PSQTWeightType, PSQTBuckets, NumRegistersSIMD>();

static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2;
static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4;
static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions");
static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets");
#endif
using Tiling = SIMDTiling<TransformedFeatureDimensions, HalfDimensions>;

public:
// Output type
Expand Down Expand Up @@ -480,8 +484,8 @@ class FeatureTransformer {
#ifdef VECTOR
// Gcc-10.2 unnecessarily spills AVX2 registers if this array
// is defined in the VECTOR code below, once in each branch.
vec_t acc[NumRegs];
psqt_vec_t psqt[NumPsqtRegs];
vec_t acc[Tiling::NumRegs];
psqt_vec_t psqt[Tiling::NumPsqtRegs];
#endif

const Square ksq = pos.square<KING>(Perspective);
Expand All @@ -506,14 +510,14 @@ class FeatureTransformer {
#ifdef VECTOR
if ((removed.size() == 1 || removed.size() == 2) && added.size() == 1)
{
auto accIn =
auto* accIn =
reinterpret_cast<const vec_t*>(&(computed->*accPtr).accumulation[Perspective][0]);
auto accOut = reinterpret_cast<vec_t*>(&(next->*accPtr).accumulation[Perspective][0]);
auto* accOut = reinterpret_cast<vec_t*>(&(next->*accPtr).accumulation[Perspective][0]);

const IndexType offsetR0 = HalfDimensions * removed[0];
auto columnR0 = reinterpret_cast<const vec_t*>(&weights[offsetR0]);
auto* columnR0 = reinterpret_cast<const vec_t*>(&weights[offsetR0]);
const IndexType offsetA = HalfDimensions * added[0];
auto columnA = reinterpret_cast<const vec_t*>(&weights[offsetA]);
auto* columnA = reinterpret_cast<const vec_t*>(&weights[offsetA]);

if (removed.size() == 1)
{
Expand All @@ -523,22 +527,22 @@ class FeatureTransformer {
else
{
const IndexType offsetR1 = HalfDimensions * removed[1];
auto columnR1 = reinterpret_cast<const vec_t*>(&weights[offsetR1]);
auto* columnR1 = reinterpret_cast<const vec_t*>(&weights[offsetR1]);

for (IndexType i = 0; i < HalfDimensions * sizeof(WeightType) / sizeof(vec_t); ++i)
accOut[i] = vec_sub_16(vec_add_16(accIn[i], columnA[i]),
vec_add_16(columnR0[i], columnR1[i]));
}

auto accPsqtIn = reinterpret_cast<const psqt_vec_t*>(
auto* accPsqtIn = reinterpret_cast<const psqt_vec_t*>(
&(computed->*accPtr).psqtAccumulation[Perspective][0]);
auto accPsqtOut =
auto* accPsqtOut =
reinterpret_cast<psqt_vec_t*>(&(next->*accPtr).psqtAccumulation[Perspective][0]);

const IndexType offsetPsqtR0 = PSQTBuckets * removed[0];
auto columnPsqtR0 = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtR0]);
auto* columnPsqtR0 = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtR0]);
const IndexType offsetPsqtA = PSQTBuckets * added[0];
auto columnPsqtA = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtA]);
auto* columnPsqtA = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtA]);

if (removed.size() == 1)
{
Expand All @@ -550,7 +554,8 @@ class FeatureTransformer {
else
{
const IndexType offsetPsqtR1 = PSQTBuckets * removed[1];
auto columnPsqtR1 = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtR1]);
auto* columnPsqtR1 =
reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offsetPsqtR1]);

for (std::size_t i = 0;
i < PSQTBuckets * sizeof(PSQTWeightType) / sizeof(psqt_vec_t); ++i)
Expand All @@ -561,69 +566,69 @@ class FeatureTransformer {
}
else
{
for (IndexType i = 0; i < HalfDimensions / TileHeight; ++i)
for (IndexType i = 0; i < HalfDimensions / Tiling::TileHeight; ++i)
{
// Load accumulator
auto accTileIn = reinterpret_cast<const vec_t*>(
&(computed->*accPtr).accumulation[Perspective][i * TileHeight]);
for (IndexType j = 0; j < NumRegs; ++j)
auto* accTileIn = reinterpret_cast<const vec_t*>(
&(computed->*accPtr).accumulation[Perspective][i * Tiling::TileHeight]);
for (IndexType j = 0; j < Tiling::NumRegs; ++j)
acc[j] = vec_load(&accTileIn[j]);

// Difference calculation for the deactivated features
for (const auto index : removed)
{
const IndexType offset = HalfDimensions * index + i * TileHeight;
auto column = reinterpret_cast<const vec_t*>(&weights[offset]);
for (IndexType j = 0; j < NumRegs; ++j)
const IndexType offset = HalfDimensions * index + i * Tiling::TileHeight;
auto* column = reinterpret_cast<const vec_t*>(&weights[offset]);
for (IndexType j = 0; j < Tiling::NumRegs; ++j)
acc[j] = vec_sub_16(acc[j], column[j]);
}

// Difference calculation for the activated features
for (const auto index : added)
{
const IndexType offset = HalfDimensions * index + i * TileHeight;
auto column = reinterpret_cast<const vec_t*>(&weights[offset]);
for (IndexType j = 0; j < NumRegs; ++j)
const IndexType offset = HalfDimensions * index + i * Tiling::TileHeight;
auto* column = reinterpret_cast<const vec_t*>(&weights[offset]);
for (IndexType j = 0; j < Tiling::NumRegs; ++j)
acc[j] = vec_add_16(acc[j], column[j]);
}

// Store accumulator
auto accTileOut = reinterpret_cast<vec_t*>(
&(next->*accPtr).accumulation[Perspective][i * TileHeight]);
for (IndexType j = 0; j < NumRegs; ++j)
auto* accTileOut = reinterpret_cast<vec_t*>(
&(next->*accPtr).accumulation[Perspective][i * Tiling::TileHeight]);
for (IndexType j = 0; j < Tiling::NumRegs; ++j)
vec_store(&accTileOut[j], acc[j]);
}

for (IndexType i = 0; i < PSQTBuckets / PsqtTileHeight; ++i)
for (IndexType i = 0; i < PSQTBuckets / Tiling::PsqtTileHeight; ++i)
{
// Load accumulator
auto accTilePsqtIn = reinterpret_cast<const psqt_vec_t*>(
&(computed->*accPtr).psqtAccumulation[Perspective][i * PsqtTileHeight]);
for (std::size_t j = 0; j < NumPsqtRegs; ++j)
auto* accTilePsqtIn = reinterpret_cast<const psqt_vec_t*>(
&(computed->*accPtr).psqtAccumulation[Perspective][i * Tiling::PsqtTileHeight]);
for (std::size_t j = 0; j < Tiling::NumPsqtRegs; ++j)
psqt[j] = vec_load_psqt(&accTilePsqtIn[j]);

// Difference calculation for the deactivated features
for (const auto index : removed)
{
const IndexType offset = PSQTBuckets * index + i * PsqtTileHeight;
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
for (std::size_t j = 0; j < NumPsqtRegs; ++j)
const IndexType offset = PSQTBuckets * index + i * Tiling::PsqtTileHeight;
auto* columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
for (std::size_t j = 0; j < Tiling::NumPsqtRegs; ++j)
psqt[j] = vec_sub_psqt_32(psqt[j], columnPsqt[j]);
}

// Difference calculation for the activated features
for (const auto index : added)
{
const IndexType offset = PSQTBuckets * index + i * PsqtTileHeight;
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
for (std::size_t j = 0; j < NumPsqtRegs; ++j)
const IndexType offset = PSQTBuckets * index + i * Tiling::PsqtTileHeight;
auto* columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
for (std::size_t j = 0; j < Tiling::NumPsqtRegs; ++j)
psqt[j] = vec_add_psqt_32(psqt[j], columnPsqt[j]);
}

// Store accumulator
auto accTilePsqtOut = reinterpret_cast<psqt_vec_t*>(
&(next->*accPtr).psqtAccumulation[Perspective][i * PsqtTileHeight]);
for (std::size_t j = 0; j < NumPsqtRegs; ++j)
auto* accTilePsqtOut = reinterpret_cast<psqt_vec_t*>(
&(next->*accPtr).psqtAccumulation[Perspective][i * Tiling::PsqtTileHeight]);
for (std::size_t j = 0; j < Tiling::NumPsqtRegs; ++j)
vec_store_psqt(&accTilePsqtOut[j], psqt[j]);
}
}
Expand Down Expand Up @@ -702,88 +707,88 @@ class FeatureTransformer {
accumulator.computed[Perspective] = true;

#ifdef VECTOR
vec_t acc[NumRegs];
psqt_vec_t psqt[NumPsqtRegs];
vec_t acc[Tiling::NumRegs];
psqt_vec_t psqt[Tiling::NumPsqtRegs];

for (IndexType j = 0; j < HalfDimensions / TileHeight; ++j)
for (IndexType j = 0; j < HalfDimensions / Tiling::TileHeight; ++j)
{
auto accTile =
reinterpret_cast<vec_t*>(&accumulator.accumulation[Perspective][j * TileHeight]);
auto entryTile = reinterpret_cast<vec_t*>(&entry.accumulation[j * TileHeight]);
auto* accTile = reinterpret_cast<vec_t*>(
&accumulator.accumulation[Perspective][j * Tiling::TileHeight]);
auto* entryTile = reinterpret_cast<vec_t*>(&entry.accumulation[j * Tiling::TileHeight]);

for (IndexType k = 0; k < NumRegs; ++k)
for (IndexType k = 0; k < Tiling::NumRegs; ++k)
acc[k] = entryTile[k];

int i = 0;
for (; i < int(std::min(removed.size(), added.size())); ++i)
std::size_t i = 0;
for (; i < std::min(removed.size(), added.size()); ++i)
{
IndexType indexR = removed[i];
const IndexType offsetR = HalfDimensions * indexR + j * TileHeight;
auto columnR = reinterpret_cast<const vec_t*>(&weights[offsetR]);
const IndexType offsetR = HalfDimensions * indexR + j * Tiling::TileHeight;
auto* columnR = reinterpret_cast<const vec_t*>(&weights[offsetR]);
IndexType indexA = added[i];
const IndexType offsetA = HalfDimensions * indexA + j * TileHeight;
auto columnA = reinterpret_cast<const vec_t*>(&weights[offsetA]);
const IndexType offsetA = HalfDimensions * indexA + j * Tiling::TileHeight;
auto* columnA = reinterpret_cast<const vec_t*>(&weights[offsetA]);

for (unsigned k = 0; k < NumRegs; ++k)
for (IndexType k = 0; k < Tiling::NumRegs; ++k)
acc[k] = vec_add_16(acc[k], vec_sub_16(columnA[k], columnR[k]));
}
for (; i < int(removed.size()); ++i)
for (; i < removed.size(); ++i)
{
IndexType index = removed[i];
const IndexType offset = HalfDimensions * index + j * TileHeight;
auto column = reinterpret_cast<const vec_t*>(&weights[offset]);
const IndexType offset = HalfDimensions * index + j * Tiling::TileHeight;
auto* column = reinterpret_cast<const vec_t*>(&weights[offset]);

for (unsigned k = 0; k < NumRegs; ++k)
for (IndexType k = 0; k < Tiling::NumRegs; ++k)
acc[k] = vec_sub_16(acc[k], column[k]);
}
for (; i < int(added.size()); ++i)
for (; i < added.size(); ++i)
{
IndexType index = added[i];
const IndexType offset = HalfDimensions * index + j * TileHeight;
auto column = reinterpret_cast<const vec_t*>(&weights[offset]);
const IndexType offset = HalfDimensions * index + j * Tiling::TileHeight;
auto* column = reinterpret_cast<const vec_t*>(&weights[offset]);

for (unsigned k = 0; k < NumRegs; ++k)
for (IndexType k = 0; k < Tiling::NumRegs; ++k)
acc[k] = vec_add_16(acc[k], column[k]);
}

for (IndexType k = 0; k < NumRegs; k++)
for (IndexType k = 0; k < Tiling::NumRegs; k++)
vec_store(&entryTile[k], acc[k]);
for (IndexType k = 0; k < NumRegs; k++)
for (IndexType k = 0; k < Tiling::NumRegs; k++)
vec_store(&accTile[k], acc[k]);
}

for (IndexType j = 0; j < PSQTBuckets / PsqtTileHeight; ++j)
for (IndexType j = 0; j < PSQTBuckets / Tiling::PsqtTileHeight; ++j)
{
auto accTilePsqt = reinterpret_cast<psqt_vec_t*>(
&accumulator.psqtAccumulation[Perspective][j * PsqtTileHeight]);
auto entryTilePsqt =
reinterpret_cast<psqt_vec_t*>(&entry.psqtAccumulation[j * PsqtTileHeight]);
auto* accTilePsqt = reinterpret_cast<psqt_vec_t*>(
&accumulator.psqtAccumulation[Perspective][j * Tiling::PsqtTileHeight]);
auto* entryTilePsqt =
reinterpret_cast<psqt_vec_t*>(&entry.psqtAccumulation[j * Tiling::PsqtTileHeight]);

for (std::size_t k = 0; k < NumPsqtRegs; ++k)
for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)
psqt[k] = entryTilePsqt[k];

for (int i = 0; i < int(removed.size()); ++i)
for (std::size_t i = 0; i < removed.size(); ++i)
{
IndexType index = removed[i];
const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
const IndexType offset = PSQTBuckets * index + j * Tiling::PsqtTileHeight;
auto* columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);

for (std::size_t k = 0; k < NumPsqtRegs; ++k)
for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)
psqt[k] = vec_sub_psqt_32(psqt[k], columnPsqt[k]);
}
for (int i = 0; i < int(added.size()); ++i)
for (std::size_t i = 0; i < added.size(); ++i)
{
IndexType index = added[i];
const IndexType offset = PSQTBuckets * index + j * PsqtTileHeight;
auto columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);
const IndexType offset = PSQTBuckets * index + j * Tiling::PsqtTileHeight;
auto* columnPsqt = reinterpret_cast<const psqt_vec_t*>(&psqtWeights[offset]);

for (std::size_t k = 0; k < NumPsqtRegs; ++k)
for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)
psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]);
}

for (std::size_t k = 0; k < NumPsqtRegs; ++k)
for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)
vec_store_psqt(&entryTilePsqt[k], psqt[k]);
for (std::size_t k = 0; k < NumPsqtRegs; ++k)
for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)
vec_store_psqt(&accTilePsqt[k], psqt[k]);
}

Expand Down
Loading