diff --git a/CMakeLists.txt b/CMakeLists.txt
index e14caf7f..66400bd7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -10,11 +10,16 @@ FIND_PACKAGE(MAGMA)
 IF (NOT WIN32)
   SET(CMAKE_C_FLAGS "-std=c99 -Werror=implicit-function-declaration ${CMAKE_C_FLAGS}")
 ENDIF (NOT WIN32)
-IF(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
+IF(CUDA_HAS_FP16 OR (NOT ${CUDA_VERSION} LESS 7.5 AND ${CUDA_VERSION} LESS 10.0))
   ADD_DEFINITIONS(-DTH_GENERIC_USE_HALF=1)
   ADD_DEFINITIONS(-DCUDA_HAS_FP16=1)
 ENDIF()
 
+IF(MSVC)
+  # turn off CRT func warnings
+  ADD_DEFINITIONS(-D_CRT_SECURE_NO_WARNINGS)
+ENDIF(MSVC)
+
 INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS})
 
 ADD_SUBDIRECTORY(lib)
diff --git a/init.c b/init.c
index 8b32a1a4..28f1302c 100644
--- a/init.c
+++ b/init.c
@@ -935,7 +935,11 @@ static int cutorch_isManagedPtr(lua_State *L)
     lua_pushboolean(L, 0);
   } else {
     THCudaCheck(res);
+#if CUDART_VERSION >= 10000
+    lua_pushboolean(L, (attributes.type == cudaMemoryTypeManaged) ? 1 : 0);
+#else
     lua_pushboolean(L, attributes.isManaged);
+#endif
   }
   return 1;
 }
diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt
index 7a214396..63cf7d18 100644
--- a/lib/THC/CMakeLists.txt
+++ b/lib/THC/CMakeLists.txt
@@ -41,6 +41,7 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
   endif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.9.3")
 endif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
 
+
 IF(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
   IF(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.7" OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL "4.7" )
     SET(CXX_VERSION "c++11")
@@ -203,14 +204,14 @@ endforeach()
 
 MESSAGE(STATUS "got cuda version " ${CUDA_VERSION})
 
-IF(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
+IF(CUDA_HAS_FP16 OR (NOT ${CUDA_VERSION} LESS 7.5 AND ${CUDA_VERSION} LESS 10.0))
   MESSAGE(STATUS "Found CUDA with FP16 support, compiling with torch.CudaHalfTensor")
   LIST(APPEND src-cuda THCHalf.cu)
-  LIST(APPEND CUDA_NVCC_FLAGS "-DCUDA_HAS_FP16=1")
+  LIST(APPEND CUDA_NVCC_FLAGS "-DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF2_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__")
   SET(CMAKE_C_FLAGS "-DCUDA_HAS_FP16=1 ${CMAKE_C_FLAGS}")
-ELSE(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
+ELSE(CUDA_HAS_FP16 OR (NOT ${CUDA_VERSION} LESS 7.5 AND ${CUDA_VERSION} LESS 10.0))
   MESSAGE(STATUS "Could not find CUDA with FP16 support, compiling without torch.CudaHalfTensor")
-ENDIF(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
+ENDIF(CUDA_HAS_FP16 OR (NOT ${CUDA_VERSION} LESS 7.5 AND ${CUDA_VERSION} LESS 10.0))
 
 MESSAGE(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}")
 
 IF ("$ENV{STATIC_TH}" STREQUAL "YES")
diff --git a/lib/THC/THCApply.cuh b/lib/THC/THCApply.cuh
index a47e3032..6d5ff96a 100644
--- a/lib/THC/THCApply.cuh
+++ b/lib/THC/THCApply.cuh
@@ -14,14 +14,20 @@
 
 // Threads per block for our apply kernel
 // FIXME: use occupancy calculator instead
-#define THC_APPLY_THREADS_PER_BLOCK 32 * 16
+#if __CUDA_ARCH__ >= 750
+#define THC_APPLY_THREADS_PER_BLOCK (32 * 16)
+#define THC_APPLY_BLOCKS_PER_SM 2
+#else
+#define THC_APPLY_THREADS_PER_BLOCK (32 * 16)
+#define THC_APPLY_BLOCKS_PER_SM 4
+#endif
 
 template <typename Op, typename Ta, typename IndexType, int ADims>
 #if __CUDA_ARCH__ >= 350
-__launch_bounds__(32 * 16, 4)
+__launch_bounds__(THC_APPLY_THREADS_PER_BLOCK, THC_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelPointwiseApply1(TensorInfo<Ta, IndexType> a,
@@ -43,7 +49,7 @@ template
 #if __CUDA_ARCH__ >= 350
-__launch_bounds__(32 * 16, 4)
+__launch_bounds__(THC_APPLY_THREADS_PER_BLOCK, THC_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelPointwiseApply2(TensorInfo<Ta, IndexType> a,
@@ -70,7 +76,7 @@ template
 #if __CUDA_ARCH__ >= 350
-__launch_bounds__(32 * 16, 4)
+__launch_bounds__(THC_APPLY_THREADS_PER_BLOCK, THC_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelPointwiseApply3(TensorInfo<Ta, IndexType> a,
@@ -109,16 +115,16 @@ inline bool getApplyGrid(THCState* state, ptrdiff_t totalElements, dim3& grid) {
     return false;
   }
 
-  // Assume a reasonable number of SMs if no state is available
-  int numSM =
-    state ? THCState_getCurrentDeviceProperties(state)->multiProcessorCount : 15;
+  if(THCState_getCurrentDeviceProperties(state)->major < 3){
+    grid = dim3(min((long long) THCCeilDiv(totalElements,
+                      (ptrdiff_t) THC_APPLY_THREADS_PER_BLOCK), (long long) 64*1024-1));
+    return true;
+  }
 
-  // 16 warps per block * 4 per SM gives 64 warps per SM at maximum,
-  // which seems to be a good sweetspot for latency hiding
-  grid = dim3(min((long long) THCCeilDiv(totalElements,
-                    (ptrdiff_t) THC_APPLY_THREADS_PER_BLOCK),
-                  4LL * numSM));
+  grid = dim3((long long) THCCeilDiv(totalElements,
+                    (ptrdiff_t) THC_APPLY_THREADS_PER_BLOCK) );
   return true;
+
 }
 
 template
diff --git a/lib/THC/THCAtomics.cuh b/lib/THC/THCAtomics.cuh
--- a/lib/THC/THCAtomics.cuh
+++ b/lib/THC/THCAtomics.cuh
@@ @@
+#if !(CUDA_VERSION >= 10000 && (__CUDA_ARCH__ >= 700 || !defined(__CUDA_ARCH__)) )
 static inline __device__ void atomicAdd(half *address, half val) {
   unsigned int * address_as_ui =
     (unsigned int *) ((char *)address - ((size_t)address & 2));
@@ -102,14 +103,22 @@ static inline __device__ void atomicAdd(half *address, half val) {
 
   do {
     assumed = old;
+#if CUDA_VERSION < 9000
     half hsum;
     hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
     hsum = THCNumerics<half>::add(hsum, val);
+#else
+    __half_raw hsum;
+    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+    half tmpres = THCNumerics<half>::add(hsum, val);
+    hsum = __half_raw(tmpres);
+#endif
     old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
     old = atomicCAS(address_as_ui, assumed, old);
   } while (assumed != old);
 }
 #endif
+#endif
 
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
 // from CUDA C Programmic Guide
diff --git a/lib/THC/THCBlas.cu b/lib/THC/THCBlas.cu
index 9db4f0bf..00ece964 100644
--- a/lib/THC/THCBlas.cu
+++ b/lib/THC/THCBlas.cu
@@ -263,35 +263,45 @@ void THCudaBlas_Hgemm(THCState *state, char transa, char transb, long m, long n,
   cublasOperation_t opb = convertTransToCublasOperation(transb);
 
   if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
-  {
-  int i_m = (int)m;
-  int i_n = (int)n;
-  int i_k = (int)k;
-  int i_lda = (int)lda;
-  int i_ldb = (int)ldb;
-  int i_ldc = (int)ldc;
+  {
+    int i_m = (int)m;
+    int i_n = (int)n;
+    int i_k = (int)k;
+    int i_lda = (int)lda;
+    int i_ldb = (int)ldb;
+    int i_ldc = (int)ldc;
 
-  cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
-  cublasSetStream(handle, THCState_getCurrentStream(state));
+    cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
+    cublasSetStream(handle, THCState_getCurrentStream(state));
 
-  // Check for native Hgemm support
-  if (THC_fastHalfInstructions(state)) {
-    THCublasCheck(cublasHgemm(handle, opa, opb,
-                              i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb,
-                              &beta, c, i_ldc));
-  } else { // Simulated Hgemm
     float fAlpha = THC_half2float(alpha);
     float fBeta = THC_half2float(beta);
 
+#if CUDA_VERSION < 9000
     THCublasCheck(cublasSgemmEx(handle, opa, opb,
-                                i_m, i_n, i_k, &fAlpha,
+                                i_m, i_n, i_k, &fAlpha,
                                 a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
-                                i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
+                                i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
+#else
+    cudaDeviceProp* prop = THCState_getCurrentDeviceProperties(state);
+    if (prop->major >= 5){
+      THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+      THCublasCheck(cublasGemmEx(handle, opa, opb,
+                                 i_m, i_n, i_k, &fAlpha,
+                                 a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
+                                 i_ldb, &fBeta, c, CUDA_R_16F, i_ldc,
+                                 CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP));
+      THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+    }else{
+      THCublasCheck(cublasSgemmEx(handle, opa, opb,
+                                  i_m, i_n, i_k, &fAlpha,
+                                  a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
+                                  i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
+    }
+#endif
+    return;
   }
-
-    return;
-  }
   THError("Cublas_Hgemm only supports m, n, k, lda, ldb, ldc"
           "with th bound [val] <= %d", INT_MAX);
 }
diff --git a/lib/THC/THCDeviceUtils.cuh b/lib/THC/THCDeviceUtils.cuh
index bd410554..896671c3 100644
--- a/lib/THC/THCDeviceUtils.cuh
+++ b/lib/THC/THCDeviceUtils.cuh
@@ -33,4 +33,72 @@ __device__ __forceinline__ T doLdg(const T* p) {
 #endif
 }
 
+__device__ __forceinline__ unsigned int ACTIVE_MASK()
+{
+#if CUDA_VERSION >= 9000
+  return __activemask();
+#else
+// will be ignored anyway
+  return 0xffffffff;
+#endif
+}
+
+__device__ __forceinline__ int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff)
+{
+#if CUDA_VERSION >= 9000
+  return __ballot_sync(mask, predicate);
+#else
+  return __ballot(predicate);
+#endif
+}
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if CUDA_VERSION >= 9000
+  return __shfl_xor_sync(mask, value, laneMask, width);
+#else
+  return __shfl_xor(value, laneMask, width);
+#endif
+}
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if CUDA_VERSION >= 9000
+  return __shfl_sync(mask, value, srcLane, width);
+#else
+  return __shfl(value, srcLane, width);
+#endif
+}
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if CUDA_VERSION >= 9000
+  return __shfl_up_sync(mask, value, delta, width);
+#else
+  return __shfl_up(value, delta, width);
+#endif
+}
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if CUDA_VERSION >= 9000
+  return __shfl_down_sync(mask, value, delta, width);
+#else
+  return __shfl_down(value, delta, width);
+#endif
+}
+
+__device__ __forceinline__ bool WARP_ANY(bool cond, unsigned int mask = 0xffffffff)
+{
+#if CUDA_VERSION >= 9000
+  return (bool)__any_sync(mask, (int)cond);
+#else
+  return __any(cond);
+#endif
+}
+
 #endif // THC_DEVICE_UTILS_INC
diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c
index e99487e9..43dcaba8 100644
--- a/lib/THC/THCGeneral.c
+++ b/lib/THC/THCGeneral.c
@@ -939,6 +939,7 @@ void THCHeapUpdate(THCState *state, ptrdiff_t size) {
 #include "THCAllocator.c"
 
 /* from THCHalf.h */
+#ifdef CUDA_HALF_TENSOR
 
 half THC_float2half(float f)
 {
@@ -953,3 +954,6 @@ float THC_half2float(half h)
   TH_halfbits2float(&h.x, &f);
   return f;
 }
+
+#endif
+
diff --git a/lib/THC/THCHalf.h b/lib/THC/THCHalf.h
index 7c055e7a..b376c241 100644
--- a/lib/THC/THCHalf.h
+++ b/lib/THC/THCHalf.h
@@ -4,7 +4,7 @@
 #include "THCGeneral.h"
 
 /* We compile with CudaHalfTensor support if we have this: */
-#if CUDA_VERSION >= 7050 || CUDA_HAS_FP16
+#if (CUDA_VERSION >= 7050 && CUDA_VERSION < 10000) || CUDA_HAS_FP16
 #define CUDA_HALF_TENSOR 1
 #endif
 
@@ -13,6 +13,12 @@
 #include <cuda_fp16.h>
 #include <stdint.h>
 
+#if CUDA_VERSION >= 9000
+#ifndef __cplusplus
+typedef __half_raw half;
+#endif
+#endif
+
 THC_EXTERNC void THCFloat2Half(THCState *state, half *out, float *in, ptrdiff_t len);
 THC_EXTERNC void THCHalf2Float(THCState *state, float *out, half *in, ptrdiff_t len);
 THC_API half THC_float2half(float a);
diff --git a/lib/THC/THCNumerics.cuh b/lib/THC/THCNumerics.cuh
index b6d1dac0..ba86e8f2 100644
--- a/lib/THC/THCNumerics.cuh
+++ b/lib/THC/THCNumerics.cuh
@@ -111,8 +111,13 @@ struct THCNumerics {
 #ifdef CUDA_HALF_TENSOR
 template <>
 struct THCNumerics<half> {
+#if CUDA_VERSION < 9000
   static inline __host__ __device__ half min() { half h; h.x = 0xfbff; return h; }
   static inline __host__ __device__ half max() { half h; h.x = 0x7bff; return h; }
+#else
+  static inline __host__ __device__ half min() { __half_raw h; h.x = 0xfbff; return h; }
+  static inline __host__ __device__ half max() { __half_raw h; h.x = 0x7bff; return h; }
+#endif
 
   static inline __host__ __device__ bool lt(half a, half b) {
 #ifdef __CUDA_ARCH__
diff --git a/lib/THC/THCReduce.cuh b/lib/THC/THCReduce.cuh
index 067d796a..869de795 100644
--- a/lib/THC/THCReduce.cuh
+++ b/lib/THC/THCReduce.cuh
@@ -12,7 +12,13 @@
 #include "THCReduceApplyUtils.cuh"
 
 // Threads per thread block
-#define THC_NONCONTIG_REDUCE_BLOCK_SIZE 32 * 16
+#if __CUDA_ARCH__ >= 750
+#define THC_NONCONTIG_REDUCE_BLOCK_SIZE (32 * 16)
+#define THC_NONCONTIG_REDUCE_BLOCKS_PER_SM 2
+#else
+#define THC_NONCONTIG_REDUCE_BLOCK_SIZE (32 * 16)
+#define THC_NONCONTIG_REDUCE_BLOCKS_PER_SM 4
+#endif
 
 template <typename IndexType>
 __device__ __forceinline__ IndexType getReduceNoncontigDimSliceIndex() {
@@ -27,7 +33,7 @@ template
 #if __CUDA_ARCH__ >= 350
-__launch_bounds__(32 * 16, 4)
+__launch_bounds__(THC_NONCONTIG_REDUCE_BLOCK_SIZE, THC_NONCONTIG_REDUCE_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelReduceNoncontigDim(TensorInfo<T, IndexType> out,
@@ -324,5 +330,6 @@ bool THC_reduceDim(THCState* state,
 }
 
 #undef THC_NONCONTIG_REDUCE_BLOCK_SIZE
+#undef THC_NONCONTIG_REDUCE_BLOCKS_PER_SM
 
 #endif // THC_REDUCE_INC
diff --git a/lib/THC/THCScanUtils.cuh b/lib/THC/THCScanUtils.cuh
index ccf27b79..9a487ca7 100644
--- a/lib/THC/THCScanUtils.cuh
+++ b/lib/THC/THCScanUtils.cuh
@@ -2,6 +2,7 @@
 #define THC_SCAN_UTILS_INC
 
 #include "THCAsmUtils.cuh"
+#include "THCDeviceUtils.cuh"
 
 // Collection of in-kernel scan / prefix sum utilities
 
@@ -152,7 +153,7 @@ __device__ void exclusivePrefixScan(T* smem, T in, T* out, T* carry, BinaryFunct
 template <typename T, bool KillWARDependency, class BinaryFunction>
 __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
   // Within-warp, we use warp voting.
-  T vote = __ballot(in);
+  T vote = WARP_BALLOT(in);
   T index = __popc(getLaneMaskLe() & vote);
   T carry = __popc(vote);
diff --git a/lib/THC/THCTensorMath.cuh b/lib/THC/THCTensorMath.cuh
index ae8f5db3..202090e3 100644
--- a/lib/THC/THCTensorMath.cuh
+++ b/lib/THC/THCTensorMath.cuh
@@ -26,6 +26,24 @@ __global__ void THCTensor_copyToDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t
 #define CAT_ARRAY_BATCH_SIZE 1024
 #define CAT_ARRAY_MAX_INPUT_DIMS 4
 
+inline bool getCatGrid(THCState* state, ptrdiff_t nTensors, dim3& grid) {
+  int curDevice = -1;
+  cudaGetDevice(&curDevice);
+
+  if (curDevice == -1) {
+    return false;
+  }
+
+  // Assume a reasonable number of SMs if no state is available
+  int numSM =
+    state ? THCState_getCurrentDeviceProperties(state)->multiProcessorCount : 15;
+  //X dim of grid for cat array cooperates on a single tensor in the cat.
+  //Given half of the GPU, full utilization will always occur.
+  grid = dim3( 2LL * numSM, (long long) nTensors );
+
+  return true;
+}
+
 // Similar to any other IndexToOffset calculation for copying along a given dimension.
 template <typename IndexType, int Dims>
 struct CatArrIndexToOffset {
@@ -77,6 +95,9 @@ struct OutputTensorSizeStride {
  *
  * The most important assumption made is that the input tensors are contiguous.
  */
+
+
+
 template <typename T, typename IndexType, int Dims>
 __global__ void CatArrayBatchedCopy(
     T* output,
@@ -84,19 +105,26 @@ __global__ void CatArrayBatchedCopy(
     OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
     const int concatDim,
     IndexType dimStride) {
-  T* data = inputs[blockIdx.y].input;
-  IndexType offset = inputs[blockIdx.y].offset;
-  IndexType dimSize = inputs[blockIdx.y].dimSize;
-  IndexType nElements = inputs[blockIdx.y].nElements;
-  IndexType dataOffset = offset * dimStride;
-
-  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
-       linearIndex < nElements;
-       linearIndex += gridDim.x * blockDim.x) {
+
+    IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
+    IndexType nElements = inputs[blockIdx.y].nElements;
+
+    if(tid >= nElements) return;
+
+    T* data = inputs[blockIdx.y].input;
+    IndexType offset = inputs[blockIdx.y].offset;
+    IndexType dimSize = inputs[blockIdx.y].dimSize;
+    IndexType dataOffset = offset * dimStride;
+
+    IndexType stride = gridDim.x * blockDim.x;
+
+    while( tid < nElements){
     IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
-      os.outputSize, os.outputStride, dimSize, concatDim, linearIndex);
-    output[dataOffset + elementOffset] = data[linearIndex];
-  }
+          os.outputSize, os.outputStride, dimSize, concatDim, tid);
+    output[dataOffset + elementOffset] = data[tid];
+
+    tid += stride;
+  }
 }
 
 #endif
diff --git a/lib/THC/THCTensorMode.cuh b/lib/THC/THCTensorMode.cuh
index b67ac2a4..3544e06f 100644
--- a/lib/THC/THCTensorMode.cuh
+++ b/lib/THC/THCTensorMode.cuh
@@ -5,6 +5,7 @@
 #include "THCSortUtils.cuh"
 #include "THCScanUtils.cuh"
 
+#ifdef CUDA_HALF_TENSOR
 struct ThrustHalfLess
 {
   __host__ __device__ inline bool operator()(const half& lhs, const half& rhs) {
@@ -35,6 +36,7 @@ struct ThrustHalfEqualToPredicate
 
   half val_;
 };
+#endif
 
 template <typename T>
 struct BinaryAddOp {
diff --git a/lib/THC/THCTensorRandom.cuh b/lib/THC/THCTensorRandom.cuh
index 5afd8fed..65ed7834 100644
--- a/lib/THC/THCTensorRandom.cuh
+++ b/lib/THC/THCTensorRandom.cuh
@@ -43,7 +43,7 @@ __global__ void generateLogNormal(curandStateMtgp32 *state, int size, do
 // Normalizes the L1 norm of every row to 1; used by multinomial
 template <typename T>
 __global__ void renormRowsL1(T* dist, long rows, long cols) {
-  extern __shared__ __align__(sizeof(T)) unsigned char my_smem[];
+  extern __shared__ unsigned char my_smem[];
   T *smem = reinterpret_cast<T *>(my_smem);
 
   for (long row = blockIdx.x; row < rows; row += gridDim.x) {
@@ -104,7 +104,7 @@ sampleMultinomialOnce(long* dest,
                       int categories,
                       T* sampled,
                       T* dist) {
-  extern __shared__ __align__(sizeof(AccT)) unsigned char my_smem[];
+  extern __shared__ unsigned char my_smem[];
   __shared__ bool found;
 
   // Shared Memory hold blockdim.x T for holding the cumulative sum,
diff --git a/lib/THC/THCTensorTopK.cuh b/lib/THC/THCTensorTopK.cuh
index 7269e991..c98ff5cd 100644
--- a/lib/THC/THCTensorTopK.cuh
+++ b/lib/THC/THCTensorTopK.cuh
@@ -118,7 +118,7 @@ struct TopKTypeConfig {
   typedef unsigned int RadixType;
 
   static inline __device__ RadixType convert(half v) {
-#if defined(__CUDACC_VER__) && __CUDACC_VER__ >= 80000
+#if CUDA_VERSION >= 8000
     RadixType x = __half_as_ushort(v);
     RadixType mask = -((x >> 15)) | 0x8000;
     return (x ^ mask);
@@ -129,7 +129,7 @@ struct TopKTypeConfig {
   }
 
   static inline __device__ half deconvert(RadixType v) {
-#if defined(__CUDACC_VER__) && __CUDACC_VER__ >= 80000
+#if CUDA_VERSION >= 8000
     RadixType mask = ((v >> 15) - 1) | 0x8000;
     return __ushort_as_half(v ^ mask);
 #else
@@ -178,7 +178,7 @@ __device__ void countRadixUsingMask(CountType counts[RadixSize],
 #pragma unroll
     for (unsigned int j = 0; j < RadixSize; ++j) {
       bool vote = hasVal && (digitInRadix == j);
-      counts[j] += __popc(__ballot(vote));
+      counts[j] += __popc(WARP_BALLOT(vote, ACTIVE_MASK()));
     }
   }
diff --git a/lib/THC/THCTensorTypeUtils.cu b/lib/THC/THCTensorTypeUtils.cu
index e4c1c34f..431e38e5 100644
--- a/lib/THC/THCTensorTypeUtils.cu
+++ b/lib/THC/THCTensorTypeUtils.cu
@@ -117,7 +117,7 @@ TensorUtils<TENSOR_TYPE>::getDims(THCState* state,                        \
 bool                                                                       \
 TensorUtils<TENSOR_TYPE>::isContiguous(THCState* state,                    \
                                        TENSOR_TYPE* t) {                   \
-  return TENSOR_TYPE##_isContiguous(state, t);                             \
+  return (TENSOR_TYPE##_isContiguous(state, t) != 0);                      \
 }                                                                          \
                                                                            \
 bool                                                                       \
diff --git a/lib/THC/THCTensorTypeUtils.cuh b/lib/THC/THCTensorTypeUtils.cuh
index 37edb763..16a0cdee 100644
--- a/lib/THC/THCTensorTypeUtils.cuh
+++ b/lib/THC/THCTensorTypeUtils.cuh
@@ -149,7 +149,11 @@ struct ScalarNegate {
     return __float2half(-__half2float(v));
 #endif
 #else
+#if CUDA_VERSION < 9000
     half out = v;
+#else
+    __half_raw out = __half_raw(v);
+#endif
     out.x ^= 0x8000;  // toggle sign bit
     return out;
 #endif
@@ -170,11 +174,25 @@ struct ScalarInv {
 };
 
 inline bool operator==(half a, half b) {
+#if CUDA_VERSION < 9000
   return a.x == b.x;
+#else
+  __half_raw araw, braw;
+  araw = __half_raw(a);
+  braw = __half_raw(b);
+  return araw.x == braw.x;
+#endif
 }
 
 inline bool operator!=(half a, half b) {
-  return a.x != b.x;
+#if CUDA_VERSION < 9000
+  return a.x != b.x;
+#else
+  __half_raw araw, braw;
+  araw = __half_raw(a);
+  braw = __half_raw(b);
+  return araw.x != braw.x;
+#endif
 }
 
 #endif // CUDA_HALF_TENSOR
diff --git a/lib/THC/cmake/select_compute_arch.cmake b/lib/THC/cmake/select_compute_arch.cmake
index 4b274418..2075202b 100644
--- a/lib/THC/cmake/select_compute_arch.cmake
+++ b/lib/THC/cmake/select_compute_arch.cmake
@@ -5,9 +5,9 @@
 #   - "Auto" detects local machine GPU compute arch at runtime.
 #   - "Common" and "All" cover common and entire subsets of architectures
 # ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
-# NAME: Fermi Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal
+# NAME: Fermi Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal Volta Turing
 # NUM: Any number. Only those pairs are currently accepted by NVCC though:
-#            2.0 2.1 3.0 3.2 3.5 3.7 5.0 5.2 5.3 6.0 6.2
+#            2.0 2.1 3.0 3.2 3.5 3.7 5.0 5.2 5.3 6.0 6.2 7.0 7.2 7.5
 # Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
 #          Additionally, sets ${out_variable}_readable to the resulting numeric list
 # Example:
@@ -17,25 +17,69 @@
 # More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA
 #
 
+if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language
+  if(CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA"
+      AND CMAKE_CUDA_COMPILER_VERSION MATCHES "^([0-9]+\\.[0-9]+)")
+    set(CUDA_VERSION "${CMAKE_MATCH_1}")
+  endif()
+endif()
+
+# See: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-feature-list
+
 # This list will be used for CUDA_ARCH_NAME = All option
 set(CUDA_KNOWN_GPU_ARCHITECTURES  "Fermi" "Kepler" "Maxwell")
 
 # This list will be used for CUDA_ARCH_NAME = Common option (enabled by default)
 set(CUDA_COMMON_GPU_ARCHITECTURES "3.0" "3.5" "5.0")
 
-if (CUDA_VERSION VERSION_GREATER "6.5")
+if(CUDA_VERSION VERSION_LESS "7.0")
+  set(CUDA_LIMIT_GPU_ARCHITECTURE "5.2")
+endif()
+
+# This list is used to filter CUDA archs when autodetecting
+set(CUDA_ALL_GPU_ARCHITECTURES "3.0" "3.2" "3.5" "5.0")
+
+if(CUDA_VERSION VERSION_GREATER_EQUAL "7.0")
   list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Kepler+Tegra" "Kepler+Tesla" "Maxwell+Tegra")
   list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2")
-endif ()
 
-if (CUDA_VERSION VERSION_GREATER "7.5")
+  if(CUDA_VERSION VERSION_LESS "8.0")
+    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2+PTX")
+    set(CUDA_LIMIT_GPU_ARCHITECTURE "6.0")
+  endif()
+endif()
+
+if(CUDA_VERSION VERSION_GREATER_EQUAL "8.0")
   list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Pascal")
-  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.0" "6.1" "6.1+PTX")
-else()
-  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2+PTX")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.0" "6.1")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "6.0" "6.1" "6.2")
+
+  if(CUDA_VERSION VERSION_LESS "9.0")
+    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.1+PTX")
+    set(CUDA_LIMIT_GPU_ARCHITECTURE "7.0")
+  endif()
 endif ()
 
+if(CUDA_VERSION VERSION_GREATER_EQUAL "9.0")
+  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Volta")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.0")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "7.0")
+
+  if(CUDA_VERSION VERSION_LESS "10.0")
+    list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.0+PTX")
+    set(CUDA_LIMIT_GPU_ARCHITECTURE "7.2")
+  endif()
+endif()
+
+if(CUDA_VERSION VERSION_GREATER_EQUAL "10.0")
+  list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Turing")
+  list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.5")
+  list(APPEND CUDA_ALL_GPU_ARCHITECTURES "7.2" "7.5" "7.5+PTX")
+
+  if(CUDA_VERSION VERSION_LESS "11.0")
+    set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0")
+  endif()
+endif()
 
 ################################################################################################
 # A function for automatic detection of GPUs installed  (if autodetection is enabled)
@@ -44,9 +88,14 @@ endif ()
 #
 function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
   if(NOT CUDA_GPU_DETECT_OUTPUT)
-    set(cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu)
+    if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language
+      set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cu")
+    else()
+      set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cpp")
+    endif()
 
-    file(WRITE ${cufile} ""
+    file(WRITE ${file} ""
+      "#include <cuda_runtime.h>\n"
       "#include <cstdio>\n"
       "int main()\n"
      "{\n"
@@ -62,19 +111,23 @@ function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
      "  return 0;\n"
      "}\n")
 
-    execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "--run" "${cufile}"
-                    "-ccbin" ${CMAKE_CXX_COMPILER}
-                    WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/"
-                    RESULT_VARIABLE nvcc_res OUTPUT_VARIABLE nvcc_out
-                    ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
-
-    if(nvcc_res EQUAL 0)
-      # only keep the last line of nvcc_out
-      STRING(REGEX REPLACE ";" "\\\\;" nvcc_out "${nvcc_out}")
-      STRING(REGEX REPLACE "\n" ";" nvcc_out "${nvcc_out}")
-      list(GET nvcc_out -1 nvcc_out)
-      string(REPLACE "2.1" "2.1(2.0)" nvcc_out "${nvcc_out}")
-      set(CUDA_GPU_DETECT_OUTPUT ${nvcc_out} CACHE INTERNAL "Returned GPU architetures from detect_gpus tool" FORCE)
+    if(CMAKE_CUDA_COMPILER_LOADED) # CUDA as a language
+      try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file}
+              RUN_OUTPUT_VARIABLE compute_capabilities)
+    else()
+      try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file}
+              CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}"
+              LINK_LIBRARIES ${CUDA_LIBRARIES}
+              RUN_OUTPUT_VARIABLE compute_capabilities)
+    endif()
+
+    # Filter unrelated content out of the output.
+    string(REGEX MATCHALL "[0-9]+\\.[0-9]+" compute_capabilities "${compute_capabilities}")
+
+    if(run_result EQUAL 0)
+      string(REPLACE "2.1" "2.1(2.0)" compute_capabilities "${compute_capabilities}")
+      set(CUDA_GPU_DETECT_OUTPUT ${compute_capabilities}
+          CACHE INTERNAL "Returned GPU architectures from detect_gpus tool" FORCE)
     endif()
   endif()
 
@@ -82,7 +135,19 @@ function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
     message(STATUS "Automatic GPU detection failed. Building for common architectures.")
     set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE)
   else()
-    set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT} PARENT_SCOPE)
+    # Filter based on CUDA version supported archs
+    set(CUDA_GPU_DETECT_OUTPUT_FILTERED "")
+    separate_arguments(CUDA_GPU_DETECT_OUTPUT)
+    foreach(ITEM IN ITEMS ${CUDA_GPU_DETECT_OUTPUT})
+      if(CUDA_LIMIT_GPU_ARCHITECTURE AND ITEM VERSION_GREATER_EQUAL CUDA_LIMIT_GPU_ARCHITECTURE)
+        list(GET CUDA_COMMON_GPU_ARCHITECTURES -1 NEWITEM)
+        string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${NEWITEM}")
+      else()
+        string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${ITEM}")
+      endif()
+    endforeach()
+
+    set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT_FILTERED} PARENT_SCOPE)
   endif()
 endfunction()
 
@@ -115,19 +180,20 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
   list(REMOVE_DUPLICATES CUDA_ARCH_LIST)
   foreach(arch_name ${CUDA_ARCH_LIST})
     set(arch_bin)
+    set(arch_ptx)
     set(add_ptx FALSE)
     # Check to see if we are compiling PTX
     if(arch_name MATCHES "(.*)\\+PTX$")
       set(add_ptx TRUE)
      set(arch_name ${CMAKE_MATCH_1})
     endif()
-    if(arch_name MATCHES "(^[0-9]\\.[0-9](\\([0-9]\\.[0-9]\\))?)$")
+    if(arch_name MATCHES "^([0-9]\\.[0-9](\\([0-9]\\.[0-9]\\))?)$")
       set(arch_bin ${CMAKE_MATCH_1})
       set(arch_ptx ${arch_bin})
     else()
       # Look for it in our list of known architectures
       if(${arch_name} STREQUAL "Fermi")
-        set(arch_bin "2.0 2.1(2.0)")
+        set(arch_bin 2.0 "2.1(2.0)")
       elseif(${arch_name} STREQUAL "Kepler+Tegra")
         set(arch_bin 3.2)
       elseif(${arch_name} STREQUAL "Kepler+Tesla")
@@ -143,6 +209,17 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
       elseif(${arch_name} STREQUAL "Pascal")
         set(arch_bin 6.0 6.1)
         set(arch_ptx 6.1)
+      elseif(${arch_name} STREQUAL "Volta")
+        if(CUDA_VERSION VERSION_GREATER_EQUAL "10.0")
+          set(arch_bin 7.0 7.2)
+          set(arch_ptx 7.2)
+        else()
+          set(arch_bin 7.0)
+          set(arch_ptx 7.0)
+        endif()
+      elseif(${arch_name} STREQUAL "Turing")
+        set(arch_bin 7.5)
+        set(arch_ptx 7.5)
       else()
         message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS")
Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS") endif() diff --git a/lib/THC/generic/THCTensorMath.cu b/lib/THC/generic/THCTensorMath.cu index 0eed5a9a..f9e81332 100644 --- a/lib/THC/generic/THCTensorMath.cu +++ b/lib/THC/generic/THCTensorMath.cu @@ -191,7 +191,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result, // Template Declarations for dim = 1, 2, 3, 4 #define HANDLE_CASE(DIMS) \ - CatArrayBatchedCopy<<stream>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]); + CatArrayBatchedCopy<<stream>>>(data, d_inputs, param, cat_dimension, param.outputStride[cat_dimension]); // Now we loop offset = 0; @@ -227,15 +227,12 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result, // is based on. dim3 applyBlock = getApplyBlock(); - // We also re-use the applyGrid - but note that we use the maximum number of - // elements for a given tensor in this grouping to determine the count - dim3 applyGrid; - getApplyGrid(state, cohortMax, applyGrid); + //Get grid where x dim fills half gpu and y dim is number of tensors. + //This will have cating two tensors fill the entire grid, but prevent + //many threads from needlessly load meta data if their sizes is small. + dim3 catGrid; + getCatGrid(state, j, catGrid); - // Next, we set our grid's y component to be the number of tensors in - // the batch. This will allow the kernel to determine which input - // tensor it is responsible for copying - applyGrid.y = j; switch (maxDim) { case 1: diff --git a/lib/THC/generic/THCTensorSort.cu b/lib/THC/generic/THCTensorSort.cu index 067af89e..a4408b43 100644 --- a/lib/THC/generic/THCTensorSort.cu +++ b/lib/THC/generic/THCTensorSort.cu @@ -8,7 +8,7 @@ THC_API void THCTensor_(sortKeyValueInplace)(THCState* state, THCTensor* key, THCudaLongTensor* value, - int dim, bool dir) { + int dim, int dir) { THLongStorage *valueSize = THCudaLongTensor_newSizeOf(state, value); THArgCheck(THCTensor_(isSize)(state, key, valueSize), 2, "Key tensor must have same size as value tensor"); @@ -157,7 +157,7 @@ void sortViaThrust(THCState* state, THCTensor* sorted, THCudaLongTensor* indices, THCTensor* input, - int dim, bool dir) { + int dim, int dir) { long nDims = THCTensor_(nDimension)(state, input); ptrdiff_t totalElements = THCTensor_(nElement)(state, input); @@ -327,7 +327,7 @@ THC_API void THCTensor_(sort)(THCState* state, } else { // Otherwise, fall back upon Thrust, which handles all other cases // (potentially slowly, with extra copies/memory allocations) - sortViaThrust(state, sorted, indices, input, dim, (bool) order); + sortViaThrust(state, sorted, indices, input, dim, order); } THCudaCheck(cudaGetLastError());