Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

builtin: update inline spirv hlsl #739

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions include/nbl/builtin/hlsl/glsl_compat/core.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -135,21 +135,21 @@ SquareMatrix inverse(NBL_CONST_REF_ARG(SquareMatrix) mat)
*/
// TODO: Extemely annoying that HLSL doesn't have references, so we can't transparently alias the variables as `&` :(
//void gl_Position() {spirv::}
uint32_t gl_VertexIndex() {return spirv::VertexIndex;}
uint32_t gl_InstanceIndex() {return spirv::InstanceIndex;}
uint32_t gl_VertexIndex() {return spirv::builtin::VertexIndex;}
uint32_t gl_InstanceIndex() {return spirv::builtin::InstanceIndex;}

/**
* For Compute Shaders
*/

// TODO: Extemely annoying that HLSL doesn't have references, so we can't transparently alias the variables as `const&` :(
uint32_t3 gl_NumWorkGroups() {return spirv::NumWorkGroups;}
uint32_t3 gl_NumWorkGroups() {return spirv::builtin::NumWorkgroups;}
// TODO: DXC BUG prevents us from defining this!
uint32_t3 gl_WorkGroupSize();
uint32_t3 gl_WorkGroupID() {return spirv::WorkgroupId;}
uint32_t3 gl_LocalInvocationID() {return spirv::LocalInvocationId;}
uint32_t3 gl_GlobalInvocationID() {return spirv::GlobalInvocationId;}
uint32_t gl_LocalInvocationIndex() {return spirv::LocalInvocationIndex;}
uint32_t3 gl_WorkGroupID() {return spirv::builtin::WorkgroupId;}
uint32_t3 gl_LocalInvocationID() {return spirv::builtin::LocalInvocationId;}
uint32_t3 gl_GlobalInvocationID() {return spirv::builtin::GlobalInvocationId;}
uint32_t gl_LocalInvocationIndex() {return spirv::builtin::LocalInvocationIndex;}

void barrier() {
spirv::controlBarrier(spv::ScopeWorkgroup, spv::ScopeWorkgroup, spv::MemorySemanticsAcquireReleaseMask | spv::MemorySemanticsWorkgroupMemoryMask);
Expand Down Expand Up @@ -187,7 +187,7 @@ struct bitfieldExtract<T, true, true>
{
static T __call( T val, uint32_t offsetBits, uint32_t numBits )
{
return spirv::bitFieldSExtract<T>( val, offsetBits, numBits );
return spirv::bitFieldSExtract( val, offsetBits, numBits );
}
};

Expand All @@ -196,7 +196,7 @@ struct bitfieldExtract<T, false, true>
{
static T __call( T val, uint32_t offsetBits, uint32_t numBits )
{
return spirv::bitFieldUExtract<T>( val, offsetBits, numBits );
return spirv::bitFieldUExtract( val, offsetBits, numBits );
}
};

Expand Down
187 changes: 159 additions & 28 deletions include/nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#ifndef _NBL_BUILTIN_HLSL_GLSL_COMPAT_SUBGROUP_ARITHMETIC_INCLUDED_
#define _NBL_BUILTIN_HLSL_GLSL_COMPAT_SUBGROUP_ARITHMETIC_INCLUDED_

#include "nbl/builtin/hlsl/spirv_intrinsics/subgroup_arithmetic.hlsl"
#include "nbl/builtin/hlsl/spirv_intrinsics/core.hlsl"

namespace nbl
{
Expand All @@ -17,93 +17,224 @@ namespace glsl
// TODO: Furthermore you'll need `bitfieldExtract`-like struct dispatcher to choose between int/float add/mul and sint/uint/float min/max
template<typename T>
T subgroupAdd(T value) {
return spirv::groupAdd(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
return spirv::groupNonUniformIAdd_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
}
template<typename T>
T subgroupInclusiveAdd(T value) {
return spirv::groupAdd(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
return spirv::groupNonUniformIAdd_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
}
template<typename T>
T subgroupExclusiveAdd(T value) {
return spirv::groupAdd(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
return spirv::groupNonUniformIAdd_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
}

template<typename T>
T subgroupMul(T value) {
return spirv::groupMul(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
return spirv::groupNonUniformIMul_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
}
template<typename T>
T subgroupInclusiveMul(T value) {
return spirv::groupMul(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
return spirv::groupNonUniformIMul_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
}
template<typename T>
T subgroupExclusiveMul(T value) {
return spirv::groupMul(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
return spirv::groupNonUniformIMul_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
}

template<typename T>
T subgroupAnd(T value) {
return spirv::groupBitwiseAnd(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
return spirv::groupNonUniformBitwiseAnd_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
}
template<typename T>
T subgroupInclusiveAnd(T value) {
return spirv::groupBitwiseAnd(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
return spirv::groupNonUniformBitwiseAnd_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
}
template<typename T>
T subgroupExclusiveAnd(T value) {
return spirv::groupBitwiseAnd(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
return spirv::groupNonUniformBitwiseAnd_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
}

template<typename T>
T subgroupOr(T value) {
return spirv::groupBitwiseOr(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
return spirv::groupNonUniformBitwiseOr_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
}
template<typename T>
T subgroupInclusiveOr(T value) {
return spirv::groupBitwiseOr(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
return spirv::groupNonUniformBitwiseOr_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
}
template<typename T>
T subgroupExclusiveOr(T value) {
return spirv::groupBitwiseOr(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
return spirv::groupNonUniformBitwiseOr_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
}

template<typename T>
T subgroupXor(T value) {
return spirv::groupBitwiseXor(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
return spirv::groupNonUniformBitwiseXor_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
}
template<typename T>
T subgroupInclusiveXor(T value) {
return spirv::groupBitwiseXor(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
return spirv::groupNonUniformBitwiseXor_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
}
template<typename T>
T subgroupExclusiveXor(T value) {
return spirv::groupBitwiseXor(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
return spirv::groupNonUniformBitwiseXor_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
}

namespace impl
{

template<typename T, bool isSigned>
struct subgroupMin {};

template<typename T>
struct subgroupMin<T, true>
{
static T __call(T val)
{
return spirv::groupNonUniformSMin_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationReduce, val);
}
};

template<typename T>
struct subgroupMin<T, false>
{
static T __call(T val)
{
return spirv::groupNonUniformUMin_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationReduce, val);
}
};

template<typename T, bool isSigned>
struct subgroupInclusiveMin {};

template<typename T>
struct subgroupInclusiveMin<T, true>
{
static T __call(T val)
{
return spirv::groupNonUniformSMin_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, val);
}
};

template<typename T>
struct subgroupInclusiveMin<T, false>
{
static T __call(T val)
{
return spirv::groupNonUniformUMin_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, val);
}
};

template<typename T, bool isSigned>
struct subgroupExclusiveMin {};

template<typename T>
struct subgroupExclusiveMin<T, true>
{
static T __call(T val)
{
return spirv::groupNonUniformSMin_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, val);
}
};

template<typename T>
struct subgroupExclusiveMin<T, false>
{
static T __call(T val)
{
return spirv::groupNonUniformUMin_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, val);
}
};

template<typename T, bool isSigned>
struct subgroupMax {};

template<typename T>
struct subgroupMax<T, true>
{
static T __call(T val)
{
return spirv::groupNonUniformSMax_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationReduce, val);
}
};

template<typename T>
struct subgroupMax<T, false>
{
static T __call(T val)
{
return spirv::groupNonUniformUMax_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationReduce, val);
}
};

template<typename T, bool isSigned>
struct subgroupInclusiveMax {};

template<typename T>
struct subgroupInclusiveMax<T, true>
{
static T __call(T val)
{
return spirv::groupNonUniformSMax_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, val);
}
};

template<typename T>
struct subgroupInclusiveMax<T, false>
{
static T __call(T val)
{
return spirv::groupNonUniformUMax_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, val);
}
};

template<typename T, bool isSigned>
struct subgroupExclusiveMax {};

template<typename T>
struct subgroupExclusiveMax<T, true>
{
static T __call(T val)
{
return spirv::groupNonUniformSMax_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, val);
}
};

template<typename T>
struct subgroupExclusiveMax<T, false>
{
static T __call(T val)
{
return spirv::groupNonUniformUMax_GroupNonUniformArithmetic<T>(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, val);
}
};

}

template<typename T>
T subgroupMin(T value) {
return spirv::groupBitwiseMin(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
T subgroupMin(T val) {
return impl::subgroupMin<T, is_signed<T>::value>::__call(val);
}
template<typename T>
T subgroupInclusiveMin(T value) {
return spirv::groupBitwiseMin(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
T subgroupInclusiveMin(T val) {
return impl::subgroupInclusiveMin<T, is_signed<T>::value>::__call(val);
}
template<typename T>
T subgroupExclusiveMin(T value) {
return spirv::groupBitwiseMin(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
T subgroupExclusiveMin(T val) {
return impl::subgroupExclusiveMin<T, is_signed<T>::value>::__call(val);
}

template<typename T>
T subgroupMax(T value) {
return spirv::groupBitwiseMax(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
T subgroupMax(T val) {
return impl::subgroupMax<T, is_signed<T>::value>::__call(val);
}
template<typename T>
T subgroupInclusiveMax(T value) {
return spirv::groupBitwiseMax(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
T subgroupInclusiveMax(T val) {
return impl::subgroupInclusiveMax<T, is_signed<T>::value>::__call(val);
}
template<typename T>
T subgroupExclusiveMax(T value) {
return spirv::groupBitwiseMax(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
T subgroupExclusiveMax(T val) {
return impl::subgroupExclusiveMax<T, is_signed<T>::value>::__call(val);
}

}
Expand Down
32 changes: 16 additions & 16 deletions include/nbl/builtin/hlsl/glsl_compat/subgroup_ballot.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#ifndef _NBL_BUILTIN_HLSL_GLSL_COMPAT_SUBGROUP_BALLOT_INCLUDED_
#define _NBL_BUILTIN_HLSL_GLSL_COMPAT_SUBGROUP_BALLOT_INCLUDED_

#include "nbl/builtin/hlsl/spirv_intrinsics/subgroup_ballot.hlsl"
#include "nbl/builtin/hlsl/spirv_intrinsics/core.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"

namespace nbl
Expand All @@ -15,62 +15,62 @@ namespace glsl
{

// TODO: Extemely annoying that HLSL doesn't have referencies, so we can't transparently alias the variables as `const&` :(
uint32_t4 gl_SubgroupEqMask() {return spirv::BuiltInSubgroupEqMask;}
uint32_t4 gl_SubgroupGeMask() {return spirv::BuiltInSubgroupGeMask;}
uint32_t4 gl_SubgroupGtMask() {return spirv::BuiltInSubgroupGtMask;}
uint32_t4 gl_SubgroupLeMask() {return spirv::BuiltInSubgroupLeMask;}
uint32_t4 gl_SubgroupLtMask() {return spirv::BuiltInSubgroupLtMask;}
uint32_t4 gl_SubgroupEqMask() {return spirv::builtin::SubgroupEqMask;}
uint32_t4 gl_SubgroupGeMask() {return spirv::builtin::SubgroupGeMask;}
uint32_t4 gl_SubgroupGtMask() {return spirv::builtin::SubgroupGtMask;}
uint32_t4 gl_SubgroupLeMask() {return spirv::builtin::SubgroupLeMask;}
uint32_t4 gl_SubgroupLtMask() {return spirv::builtin::SubgroupLtMask;}

template<typename T>
T subgroupBroadcastFirst(T value)
{
return spirv::subgroupBroadcastFirst<T>(spv::ScopeSubgroup, value);
return spirv::groupNonUniformBroadcastFirst<T>(spv::ScopeSubgroup, value);
}

template<typename T>
T subgroupBroadcast(T value, const uint32_t invocationId)
{
return spirv::subgroupBroadcast<T>(spv::ScopeSubgroup, value, invocationId);
return spirv::groupNonUniformBroadcast<T>(spv::ScopeSubgroup, value, invocationId);
}

uint32_t4 subgroupBallot(bool value)
{
return spirv::subgroupBallot(spv::ScopeSubgroup, value);
return spirv::groupNonUniformBallot(spv::ScopeSubgroup, value);
}

bool subgroupInverseBallot(uint32_t4 value)
{
return spirv::subgroupInverseBallot(spv::ScopeSubgroup, value);
return spirv::groupNonUniformInverseBallot(spv::ScopeSubgroup, value);
}

bool subgroupBallotBitExtract(uint32_t4 value, uint32_t index)
{
return spirv::subgroupBallotBitExtract(spv::ScopeSubgroup, value, index);
return spirv::groupNonUniformBallotBitExtract(spv::ScopeSubgroup, value, index);
}

uint32_t subgroupBallotBitCount(uint32_t4 value)
{
return spirv::subgroupBallotBitCount(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
return spirv::groupNonUniformBallotBitCount(spv::ScopeSubgroup, spv::GroupOperationReduce, value);
}

uint32_t subgroupBallotInclusiveBitCount(uint32_t4 value)
{
return spirv::subgroupBallotBitCount(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
return spirv::groupNonUniformBallotBitCount(spv::ScopeSubgroup, spv::GroupOperationInclusiveScan, value);
}

uint32_t subgroupBallotExclusiveBitCount(uint32_t4 value)
{
return spirv::subgroupBallotBitCount(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
return spirv::groupNonUniformBallotBitCount(spv::ScopeSubgroup, spv::GroupOperationExclusiveScan, value);
}

uint32_t subgroupBallotFindLSB(uint32_t4 value)
{
return spirv::subgroupBallotFindLSB(spv::ScopeSubgroup, value);
return spirv::groupNonUniformBallotFindLSB(spv::ScopeSubgroup, value);
}

uint32_t subgroupBallotFindMSB(uint32_t4 value)
{
return spirv::subgroupBallotFindMSB(spv::ScopeSubgroup, value);
return spirv::groupNonUniformBallotFindMSB(spv::ScopeSubgroup, value);
}
}
}
Expand Down
Loading