diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md
index 0cb39e7927cd6..5da439e94e092 100644
--- a/docs/backend/SYCL.md
+++ b/docs/backend/SYCL.md
@@ -42,6 +42,16 @@ The following release is verified with good quality:
 
 ## News
 
+- 2025.2
+  - Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC).
+    |GPU|Base tokens/s|Increased tokens/s|Percent|
+    |-|-|-|-|
+    |PVC 1550|39|73|+87%|
+    |Flex 170|39|50|+28%|
+    |Arc770|42|55|+30%|
+    |MTL|13|16|+23%|
+    |ARL-H|14|17|+21%|
+
 - 2024.11
   - Use syclcompat to improve the performance on some platforms. This requires to use oneAPI 2025.0 or newer.
 
@@ -97,8 +107,8 @@ SYCL backend supports Intel GPU Family:
 | Intel Data Center Max Series | Support | Max 1550, 1100 |
 | Intel Data Center Flex Series | Support | Flex 170 |
 | Intel Arc Series | Support | Arc 770, 730M, Arc A750 |
-| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake |
-| Intel iGPU | Support | iGPU in 13700k, i5-1250P, i7-1260P, i7-1165G7 |
+| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake |
+| Intel iGPU | Support | iGPU in 13700k, iGPU in 13400, i5-1250P, i7-1260P, i7-1165G7 |
 
 *Notes:*
 
@@ -660,8 +670,10 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 | Name | Value | Function |
 |-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
 | GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
+| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimized features based on Intel GPU type, to compare the performance increase |
 | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.
Recommended to use when --split-mode = layer | + ## Known Issues - `Split-mode:[row]` is not supported. diff --git a/examples/sycl/run-llama2.sh b/examples/sycl/run-llama2.sh index 3b9ba3b2da491..8ce71e37000a1 100755 --- a/examples/sycl/run-llama2.sh +++ b/examples/sycl/run-llama2.sh @@ -3,7 +3,7 @@ # MIT license # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: MIT - +export ONEAPI_DEVICE_SELECTOR="level_zero:0" source /opt/intel/oneapi/setvars.sh #export GGML_SYCL_DEBUG=1 @@ -13,7 +13,7 @@ source /opt/intel/oneapi/setvars.sh INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:" MODEL_FILE=models/llama-2-7b.Q4_0.gguf NGL=33 -CONEXT=8192 +CONEXT=4096 if [ $# -gt 0 ]; then GGML_SYCL_DEVICE=$1 diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 3579a311aac07..3ad044432a27d 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -1,3 +1,5 @@ +message(STATUS "GGML_SYCL_TARGET=${GGML_SYCL_TARGET}") + if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$") message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD") endif() diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index 022e7b7637bd3..9069c47865f68 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -99,3 +99,20 @@ catch (sycl::exception const &exc) { << ", line:" << __LINE__ << std::endl; std::exit(1); } + + +void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { + if (extra->events[i][is] != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(dpct::destroy_event(extra->events[i][is]))); + } + } + if (extra->data_device[i] != nullptr && streams.size()>0) { + ggml_sycl_set_device(i); + SYCL_CHECK( + CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i])))); + } + } + delete extra; +} diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index a5cab5065fc80..7c503a1b10e4a 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -19,6 +19,9 @@ #include "dpct/helper.hpp" #include "ggml-sycl.h" #include "presets.hpp" +#include "sycl_hw.hpp" + + #if GGML_SYCL_DNNL #include "dnnl.hpp" #include "dnnl_sycl.hpp" @@ -35,7 +38,10 @@ void* ggml_sycl_host_malloc(size_t size); void ggml_sycl_host_free(void* ptr); + extern int g_ggml_sycl_debug; +extern int g_ggml_sycl_disable_optimize; + #define GGML_SYCL_DEBUG(...) \ do { \ if (g_ggml_sycl_debug) \ @@ -182,18 +188,24 @@ inline dpct::err0 ggml_sycl_set_device(const int device) try { } ////////////////////// +struct optimize_feature { + bool reorder=false; +}; + +struct sycl_device_info { + int cc; // compute capability + // int nsm; // number of streaming multiprocessors + // size_t smpb; // max. shared memory per block + bool vmm; // virtual memory support + size_t total_vram; + sycl_hw_info hw_info; + optimize_feature opt_feature; +}; + struct ggml_sycl_device_info { int device_count; - struct sycl_device_info { - int cc; // compute capability - // int nsm; // number of streaming multiprocessors - // size_t smpb; // max. 
shared memory per block - bool vmm; // virtual memory support - size_t total_vram; - }; - sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {}; std::array default_tensor_split = {}; @@ -260,17 +272,46 @@ struct ggml_tensor_extra_gpu { // tensors dpct::event_ptr events[GGML_SYCL_MAX_DEVICES] [GGML_SYCL_MAX_STREAMS]; // events for synchronizing multiple GPUs + optimize_feature optimized_feature; }; +void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams={}); + +inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) { + optimize_feature opt; + + opt.reorder = + (arch == syclex::architecture::intel_gpu_dg1 || + arch == syclex::architecture::intel_gpu_acm_g10 || + arch == syclex::architecture::intel_gpu_acm_g11 || + arch == syclex::architecture::intel_gpu_acm_g12 || + arch == syclex::architecture::intel_gpu_pvc || + arch == syclex::architecture::intel_gpu_pvc_vg || + arch == syclex::architecture::intel_gpu_mtl_u || + arch == syclex::architecture::intel_gpu_mtl_s || + arch == syclex::architecture::intel_gpu_mtl_h || + arch == syclex::architecture::intel_gpu_arl_u || + arch == syclex::architecture::intel_gpu_arl_s || + arch == syclex::architecture::intel_gpu_arl_h || + arch == syclex::architecture::intel_gpu_bmg_g21 || + arch == syclex::architecture::intel_gpu_lnl_m + ); + + return opt; +} + struct ggml_backend_sycl_context { int device; std::string name; + optimize_feature opt_feature; + bool optimized_graph=false; queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } }; explicit ggml_backend_sycl_context(int device) : device(device), name(GGML_SYCL_NAME + std::to_string(device)) { + opt_feature = ggml_sycl_info().devices[device].opt_feature; } queue_ptr stream(int device, int stream) { @@ -680,5 +721,4 @@ bool gpu_has_xmx(sycl::device &dev); void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const ggml_sycl_op_flatten_t op); - #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 05b01db2d8b05..86b200e07030f 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -125,6 +125,25 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k, } } +template +static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + int constexpr WARP_K = WARP_SIZE * QK4_0; + const int n_warp = (k + WARP_K - 1) / WARP_K; + GGML_ASSERT(k % 2 == 0); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * + sycl::range<3>(1, 1, WARP_SIZE), + sycl::range<3>(1, 1, WARP_SIZE)), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{ + dequantize_block_q4_0_reorder(vx, y, k, item_ct1); + }); + +} + template static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -452,10 +471,15 @@ static void convert_unary_sycl(const void *__restrict__ vx, } } -to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) { +to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) { switch (type) { case GGML_TYPE_Q4_0: - return dequantize_block_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_0_sycl_reorder; + } else { + return dequantize_block_sycl; + } case GGML_TYPE_Q4_1: 
return dequantize_block_sycl; case GGML_TYPE_Q5_0: @@ -499,10 +523,15 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) { } } -to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) { +to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { switch (type) { case GGML_TYPE_Q4_0: - return dequantize_row_q4_0_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_0_sycl_reorder; + } else { + return dequantize_row_q4_0_sycl; + } case GGML_TYPE_Q4_1: return dequantize_row_q4_1_sycl; case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index 0ce2874aaaef9..355dae22b4075 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -21,7 +21,7 @@ using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y, typedef to_t_sycl_t to_fp32_sycl_t; typedef to_t_sycl_t to_fp16_sycl_t; -to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type); -to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type); +to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst); +to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst); #endif // GGML_SYCL_CONVERT_HPP diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index b8304c3a274a2..651c2160d248f 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -16,6 +16,8 @@ #include "common.hpp" typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); +typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs, + const int iqs, dfloat2 &v); static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { @@ -40,6 +42,29 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib, #endif // GGML_SYCL_F16 } +static __dpct_inline__ void dequantize_q4_0_reorder(const void *d_ptr, const int64_t ib, const void *qs, + const int iqs, dfloat2 &v) { + // const block_q4_0 * x = (const block_q4_0 *) vx; + + const dfloat d = (const dfloat)*((const sycl::half*)d_ptr+ib); + + const int vui = *((const uint8_t *)qs+iqs); + + v.x() = vui & 0xF; + v.y() = vui >> 4; + +#ifdef GGML_SYCL_F16 + // v = v - {8.0f, 8.0f}; + // v = v * {d, d}; + v.s0() = (v.s0() - 8.0f) * d; + v.s1() = (v.s1() - 8.0f) * d; + +#else + v.x() = (v.x() - 8.0f) * d; + v.y() = (v.y() - 8.0f) * d; +#endif // GGML_SYCL_F16 +} + static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q4_1 * x = (const block_q4_1 *) vx; @@ -167,6 +192,36 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri } } +template +static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + auto k=nb32; + // assume 32 threads + const int64_t tid = item_ct1.get_local_id(2); + const int lane_ib = i * WARP_SIZE + tid; + + if (lane_ib >= k / QK4_0) { + return; + } + + dst_t * y_ptr = yy + lane_ib * QK4_0; + + auto qs = (const uint8_t*)vx + lane_ib * QK4_0 / 2; + auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib; + + const float d = float(*s_ptr); + +#pragma unroll + for (int l = 0; l < QK4_0 / 2; ++l) { + int vq = qs[l]; + y_ptr[l + 0] = d * ((vq & 0xF) - 8); + y_ptr[l + 16] = d * ((vq >> 4) - 8); + } + +} + 
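The two `*_reorder` kernels above assume the reordered Q4_0 layout produced later by `reorder_qw`: for a tensor of `k` quantized values, all packed 4-bit quants come first (`k / 2` bytes), followed by one `sycl::half` scale per 32-value block starting at byte offset `k / 2`. A minimal CPU reference of that dequantization is sketched below; the helper name is hypothetical and plain `float` scales are used only to keep the sketch dependency-free.

```cpp
#include <cstddef>
#include <cstdint>

// CPU reference for dequantizing a tensor of k values stored in the reordered
// Q4_0 layout: qs = k/2 packed nibbles at the start of the buffer,
// scales = one value per 32-element block (fp16 on device, float here to keep
// the sketch self-contained). Hypothetical helper, not part of the patch.
static void dequantize_q4_0_reordered_ref(const uint8_t * qs, const float * scales,
                                          float * out, size_t k) {
    const size_t QK = 32;                      // QK4_0: values per block
    for (size_t ib = 0; ib < k / QK; ++ib) {
        const float d = scales[ib];
        const uint8_t * q = qs + ib * QK / 2;  // 16 quant bytes per block
        for (size_t l = 0; l < QK / 2; ++l) {
            // low nibble -> lane l, high nibble -> lane l + 16,
            // matching dequantize_block_q4_0_reorder above
            out[ib * QK + l]      = d * (float)((q[l] & 0xF) - 8);
            out[ib * QK + l + 16] = d * (float)((q[l] >> 4) - 8);
        }
    }
}
```

The low nibble of each quant byte feeds lane `l` and the high nibble lane `l + 16`, which is exactly the write pattern used by `dequantize_block_q4_0_reorder`.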
template static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 0d097357ce79b..99d3859de8979 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -3,7 +3,6 @@ #include "dequantize.hpp" #include "presets.hpp" - static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const sycl::half *x = (const sycl::half *)vx; @@ -91,6 +90,112 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * } } +template +static void dequantize_mul_mat_vec_reorder(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, + const sycl::nd_item<3> &item_ct1) { + // qk = quantized weights per x block + // qr = number of quantized weights per data value in x block + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int tid = item_ct1.get_local_id(2); + + + const int ncols_left = ncols % (QK4_0*WARP_SIZE); + const int ncols_align = ncols - ncols_left; + const int iter_stride = 8*2*GGML_SYCL_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter //64/16=4, 512/16/2= 16 + const int y_offset = qr == 1 ? 1 : qk/2; + +// partial sum for each thread +#ifdef GGML_SYCL_F16 + sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics +#else + float tmp = 0.0f; +#endif // GGML_SYCL_F16 + const char *d_ptr = (const char*)vx+ncols*nrows/2; + int i=0; + for (i = 0; i < ncols_align; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/qk; // x block index + const int iqs = (col%qk)/qr; // x quant index + const int iybs = col - col%qk; // y block start index + +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + // process 2 vals per j iter + + // dequantize + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val + dfloat2 v; + dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v); + + // matrix multiplication + // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[iybs + iqs + j / qr + 0], + y[iybs + iqs + j / qr + y_offset]}; + + tmp += v * t1; +#else + tmp += v.x() * y[iybs + iqs + j / qr + 0]; + tmp += v.y() * y[iybs + iqs + j / qr + y_offset]; +#endif // GGML_SYCL_F16 + } + } + + for (; i < ncols; i += iter_stride) { + if (tid>=ncols_left/QK4_0) continue; + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/qk; // x block index + const int iqs = (col%qk)/qr; // x quant index + const int iybs = col - col%qk; // y block start index + +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + // process 2 vals per j iter + + // dequantize + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val + dfloat2 v; + dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v); + + // matrix multiplication + // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[iybs + iqs + j / qr + 
0], + y[iybs + iqs + j / qr + y_offset]}; + + tmp += v * t1; +#else + tmp += v.x() * y[iybs + iqs + j / qr + 0]; + tmp += v.y() * y[iybs + iqs + j / qr + y_offset]; +#endif // GGML_SYCL_F16 + } + } + + // sum up partial sums and write back result + const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2; + for (int mask = mask_start; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { +#ifdef GGML_SYCL_F16 + dst[row] = tmp.x() + tmp.y(); +#else + dst[row] = tmp; +#endif // GGML_SYCL_F16 + } +} + static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, @@ -759,6 +864,28 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa } } +static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec_reorder( + vx, y, dst, ncols, nrows, item_ct1); + }); + } +} + static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y, float *dst, const int ncols, @@ -953,7 +1080,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec( const int64_t ne00 = src0->ne[0]; const int64_t row_diff = row_high - row_low; - GGML_ASSERT(src1->type == GGML_TYPE_F32); // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics #ifdef GGML_SYCL_F16 @@ -967,7 +1093,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec( if (src1_convert_f16) { src1_dfloat = src1_dfloat_a.alloc(ne00); - const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type); + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); GGML_ASSERT(to_fp16_sycl != nullptr); to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream); } @@ -977,7 +1103,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec( switch (src0->type) { case GGML_TYPE_Q4_0: - dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu*)dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q4_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q4_1: dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); @@ -1020,4 +1151,5 @@ void ggml_sycl_op_dequantize_mul_mat_vec( GGML_UNUSED(src1_ddq_i); GGML_UNUSED(src1_ncols); GGML_UNUSED(src1_padded_row_size); + GGML_UNUSED(ctx); } diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp new file mode 100644 index 0000000000000..51c19f6b3b90c --- /dev/null +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -0,0 +1,308 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// 
SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "ggml-impl.h" +#include "common.hpp" +#include "dequantize.hpp" +#include "getrows.hpp" + + +template +static void k_get_rows( + const void * src0, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12, + const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { + + const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2)) * + 2; + const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) / + ne12; + const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) % + ne12; + + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03; + + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel(src0_row, ib, iqs, v); + + dst_row[iybs + iqs + 0] = v.x(); + dst_row[iybs + iqs + y_offset] = v.y(); +} + +template +static void k_get_rows_reorder( + const void * src0, const void *src0_dq, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12, + const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { + + const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2)) * + 2; + const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) / + ne12; + const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) % + ne12; + + if (i00 >= ne00) { + return; + } + auto ncols = ne00; + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + + const int src0_off = i01 * ncols + i00; + const int ib = src0_off / QK4_0; // block index + const int iqs = (i00%qk)/qr; // x quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 
1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel_recorder((const void *)src0_dq, ib, (const void *)src0, src0_off/2, v); + + dst_row[iybs + iqs + 0] = v.x(); + dst_row[iybs + iqs + y_offset] = v.y(); + + GGML_UNUSED(nb01); + GGML_UNUSED(nb02); + GGML_UNUSED(nb03); +} + +template +static void k_get_rows_float( + const src0_t * src0, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12, + const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { + + const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2); + const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) / + ne12; + const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) % + ne12; + + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); + + dst_row[i00] = src0_row[i00]; +} + +template +static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const void *src0_dd, + const int32_t *src1_dd, float *dst_dd, + queue_ptr stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); + const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE); + const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); + //const size_t s13 = nb13 / ggml_element_size(src1); + + GGML_ASSERT(ne00 % 2 == 0); + + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_get_rows( + src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, + s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); + }); + + GGML_UNUSED(dst); + GGML_UNUSED(ctx); +} + +template +static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const void *src0_dd, + const int32_t *src1_dd, float *dst_dd, + queue_ptr stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); + const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE); + const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); 
+ //const size_t s13 = nb13 / ggml_element_size(src1); + + GGML_ASSERT(ne00 % 2 == 0); + + const uint8_t* src0_q = (const uint8_t*)src0_dd; + const size_t ncols = ne00; + const size_t nrows = ne01; + const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{ + k_get_rows_reorder( + src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2, + s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); + }); + + GGML_UNUSED(dst); + GGML_UNUSED(ctx); +} + + +template +static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const src0_t *src0_dd, const int32_t *src1_dd, + float *dst_dd, queue_ptr stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); + const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE; + const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); + //const size_t s13 = nb13 / ggml_element_size(src1); + + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, + s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); + }); + } + + GGML_UNUSED(dst); + GGML_UNUSED(ctx); +} + +void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_d, const float *src1_d, + float *dst_d, const queue_ptr &stream) { + + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); + GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); + + const int32_t * src1_i32 = (const int32_t *) src1_d; + + switch (src0->type) { + case GGML_TYPE_F16: + get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d, + src1_i32, dst_d, stream); + break; + case GGML_TYPE_F32: + get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q4_0: + if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) { + get_rows_sycl_reorder(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + } else { + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + } + break; + case GGML_TYPE_Q4_1: + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q5_0: + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q5_1: + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q8_0: + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + default: + // TODO: k-quants + GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); + GGML_ABORT("fatal 
error"); + break; + } +} + diff --git a/ggml/src/ggml-sycl/getrows.hpp b/ggml/src/ggml-sycl/getrows.hpp new file mode 100644 index 0000000000000..cdbe6c2f41bc3 --- /dev/null +++ b/ggml/src/ggml-sycl/getrows.hpp @@ -0,0 +1,23 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_GETROWS_HPP +#define GGML_SYCL_GETROWS_HPP + +#include "common.hpp" + +void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_d, const float *src1_d, + float *dst_d, const queue_ptr &stream); + +#endif // GGML_SYCL_GETROWS_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index d4c97ad17b8a0..792e0569ca6cf 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -39,9 +39,12 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" +#include "ggml-sycl/sycl_hw.hpp" +#include "ggml-sycl/getrows.hpp" static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; +int g_ggml_sycl_disable_optimize = 0; static ggml_sycl_device_info ggml_sycl_init() { ggml_sycl_device_info info = {}; @@ -64,14 +67,18 @@ static ggml_sycl_device_info ggml_sycl_init() { for (int i = 0; i < info.device_count; ++i) { info.devices[i].vmm = 0; dpct::device_info prop; + sycl::device device = dpct::dev_mgr::instance().get_device(i); + SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( - prop, dpct::dev_mgr::instance().get_device(i)))); + prop, device))); info.default_tensor_split[i] = total_vram; total_vram += prop.get_global_mem_size(); info.devices[i].cc = 100 * prop.get_major_version() + 10 * prop.get_minor_version(); + info.devices[i].hw_info = get_device_hw_info(&device); + info.devices[i].opt_feature = check_gpu_optimize_feature(info.devices[i].hw_info.arch); info.max_work_group_sizes[i] = prop.get_max_work_group_size(); } @@ -110,6 +117,27 @@ void print_device_detail(int id, sycl::device &device, std::string device_type) global_mem_size, device.get_info().c_str()); } +void print_device_opt_feature(int device_count) { + GGML_LOG_INFO("SYCL Optimization Feature:\n"); + GGML_LOG_INFO( + "|ID| Device Type|Reorder|\n"); + GGML_LOG_INFO( + "|--|-------------------|-------|\n"); + std::map DeviceNums; + for (int id = 0; id < device_count; ++id) { + sycl::device device = dpct::dev_mgr::instance().get_device(id); + std::string backend_type = get_device_backend_and_type(device); + int type_id = DeviceNums[backend_type]++; + std::stringstream device_type; + device_type << "[" << backend_type << ":" << std::to_string(type_id) + << "]"; + std::string device_type_s = device_type.str(); + device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), ""); + GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(), + ggml_sycl_info().devices[id].opt_feature.reorder ? 
"Y": "N"); + } + +} void ggml_backend_sycl_print_sycl_devices() { GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n"); int device_count = dpct::dev_mgr::instance().device_count(); @@ -138,6 +166,8 @@ void ggml_backend_sycl_print_sycl_devices() { << "]"; print_device_detail(id, device, device_type.str()); } + + print_device_opt_feature(device_count); } static inline int get_sycl_env(const char *env_name, int default_val) { @@ -159,17 +189,21 @@ static void ggml_check_sycl() try { if (!initialized) { g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); + g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0); GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); - GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); + GGML_LOG_INFO("Running with Environment Variables:\n"); + GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); + GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); + GGML_LOG_INFO("Build with Macros:\n"); #if defined(GGML_SYCL_FORCE_MMQ) - GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: yes\n"); + GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); #else - GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: no\n"); + GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n"); #endif #if defined(GGML_SYCL_F16) - GGML_LOG_INFO("GGML_SYCL_F16: yes\n"); + GGML_LOG_INFO(" GGML_SYCL_F16: yes\n"); #else - GGML_LOG_INFO("GGML_SYCL_F16: no\n"); + GGML_LOG_INFO(" GGML_SYCL_F16: no\n"); #endif /* NOT REMOVE, keep it for next optimize for XMX. @@ -241,19 +275,27 @@ struct ggml_backend_sycl_buffer_context { void * dev_ptr = nullptr; queue_ptr stream; std::string name; + optimize_feature opt_feature; + std::vector tensor_extras; - ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) : + ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) : device(device), dev_ptr(dev_ptr), stream(stream) { check_allow_gpu_index(device); name = (GGML_SYCL_NAME + std::to_string(device)); + opt_feature = ggml_sycl_info().devices[device].opt_feature; } - ~ggml_backend_sycl_buffer_context() { if (dev_ptr != nullptr) { ggml_sycl_set_device(device); SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream))); } + + //release extra used by tensors + for (ggml_tensor_extra_gpu * extra : tensor_extras) { + release_extra_gpu(extra); + } + } }; @@ -291,6 +333,9 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, return; } + ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; + tensor->extra = extra; + ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. if (ggml_is_quantized(tensor->type)) { // initialize padding to 0 to avoid possible NaN values @@ -316,7 +361,6 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, size_t size) try { ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; - ggml_sycl_set_device(ctx->device); auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); SYCL_CHECK( @@ -660,32 +704,7 @@ struct ggml_backend_sycl_split_buffer_type_context { struct ggml_backend_sycl_split_buffer_context { ~ggml_backend_sycl_split_buffer_context() try { for (ggml_tensor_extra_gpu * extra : tensor_extras) { - for (int i = 0; i < ggml_sycl_info().device_count; ++i) { - for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { - if (extra->events[i][is] != nullptr) { - /* - DPCT1009:206: SYCL uses exceptions to report errors and - does not use the error codes. 
The original code was - commented out and a warning string was inserted. You - need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR( - dpct::destroy_event(extra->events[i][is]))); - } - } - if (extra->data_device[i] != nullptr) { - /* - DPCT1009:207: SYCL uses exceptions to report errors and does - not use the error codes. The original code was commented out - and a warning string was inserted. You need to rewrite this - code. - */ - ggml_sycl_set_device(i); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free( - extra->data_device[i], *(streams[i])))); - } - } - delete extra; + release_extra_gpu(extra, streams); } } catch (sycl::exception const &exc) { @@ -723,7 +742,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; ctx->tensor_extras.push_back(extra); - ctx->streams.push_back(&(dpct::get_current_device().default_queue())); + ctx->streams.push_back(&(dpct::get_current_device().default_queue())); for (int i = 0; i < ggml_sycl_info().device_count; ++i) { int64_t row_low, row_high; @@ -1337,83 +1356,6 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, reinterpret_cast(y[ib].ds.y()) = sum; } -template -static void k_get_rows( - const void * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ - /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ - /*size_t s0,*/ size_t s1, size_t s2, size_t s3, - /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, - size_t s10, size_t s11, size_t s12, - const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { - - const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) + - item_ct1.get_local_id(2)) * - 2; - const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1); - const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + - item_ct1.get_local_id(0)) / - ne12; - const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + - item_ct1.get_local_id(0)) % - ne12; - - if (i00 >= ne00) { - return; - } - - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03; - - const int ib = i00/qk; // block index - const int iqs = (i00%qk)/qr; // quant index - const int iybs = i00 - i00%qk; // dst block start index - const int y_offset = qr == 1 ? 
1 : qk/2; - - // dequantize - dfloat2 v; - dequantize_kernel(src0_row, ib, iqs, v); - - dst_row[iybs + iqs + 0] = v.x(); - dst_row[iybs + iqs + y_offset] = v.y(); -} - -template -static void k_get_rows_float( - const src0_t * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ - /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ - /*size_t s0,*/ size_t s1, size_t s2, size_t s3, - /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, - size_t s10, size_t s11, size_t s12, - const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { - - const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + - item_ct1.get_local_id(2); - const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1); - const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + - item_ct1.get_local_id(0)) / - ne12; - const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + - item_ct1.get_local_id(0)) % - ne12; - - if (i00 >= ne00) { - return; - } - - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); - - dst_row[i00] = src0_row[i00]; -} - static void mul_mat_p021_f16_f32( const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y, @@ -1896,81 +1838,6 @@ static void pool2d_nchw_kernel( o_ptr[cur_oh * ow + cur_ow] = res; } -template -static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const void *src0_dd, - const int32_t *src1_dd, float *dst_dd, - queue_ptr stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); - const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE); - const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); - - // strides in elements - //const size_t s0 = nb0 / ggml_element_size(dst); - const size_t s1 = nb1 / ggml_element_size(dst); - const size_t s2 = nb2 / ggml_element_size(dst); - const size_t s3 = nb3 / ggml_element_size(dst); - - const size_t s10 = nb10 / ggml_element_size(src1); - const size_t s11 = nb11 / ggml_element_size(src1); - const size_t s12 = nb12 / ggml_element_size(src1); - //const size_t s13 = nb13 / ggml_element_size(src1); - - GGML_ASSERT(ne00 % 2 == 0); - - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_get_rows( - src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, - s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); - }); - - GGML_UNUSED(dst); - GGML_UNUSED(ctx); -} - -template -static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const src0_t *src0_dd, const int32_t *src1_dd, - float *dst_dd, queue_ptr stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); - const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE; - const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); - - // strides in elements - //const size_t s0 = nb0 / ggml_element_size(dst); - const size_t s1 = nb1 / ggml_element_size(dst); - const size_t s2 = nb2 / ggml_element_size(dst); - const size_t s3 = nb3 / 
ggml_element_size(dst); - - const size_t s10 = nb10 / ggml_element_size(src1); - const size_t s11 = nb11 / ggml_element_size(src1); - const size_t s12 = nb12 / ggml_element_size(src1); - //const size_t s13 = nb13 / ggml_element_size(src1); - - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, - s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); - }); - } - - GGML_UNUSED(dst); - GGML_UNUSED(ctx); -} - static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx, const int ky, const int kx_padded, queue_ptr stream) { @@ -2494,52 +2361,6 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_d, const float *src1_d, - float *dst_d, const queue_ptr &stream) { - - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); - GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); - - const int32_t * src1_i32 = (const int32_t *) src1_d; - - switch (src0->type) { - case GGML_TYPE_F16: - get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d, - src1_i32, dst_d, stream); - break; - case GGML_TYPE_F32: - get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); - break; - case GGML_TYPE_Q4_0: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); - break; - case GGML_TYPE_Q4_1: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); - break; - case GGML_TYPE_Q5_0: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); - break; - case GGML_TYPE_Q5_1: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); - break; - case GGML_TYPE_Q8_0: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); - break; - default: - // TODO: k-quants - GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); - GGML_ABORT("fatal error"); - break; - } -} - - static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_d, const float *src1_d, @@ -2589,11 +2410,10 @@ inline void ggml_sycl_op_mul_mat_sycl( if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { - // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n"); ggml_sycl_pool_alloc src0_as_f16(ctx.pool()); if (src0->type != GGML_TYPE_F16) { - const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type); + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst); GGML_ASSERT(to_fp16_sycl != nullptr); size_t ne = row_diff*ne00; src0_as_f16.alloc(ne); @@ -2605,7 +2425,7 @@ inline void ggml_sycl_op_mul_mat_sycl( ggml_sycl_pool_alloc src1_as_f16(ctx.pool()); if (src1->type != GGML_TYPE_F16) { - const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type); + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); GGML_ASSERT(to_fp16_sycl != nullptr); size_t ne = src1_ncols*ne10; src1_as_f16.alloc(ne); @@ -2626,13 +2446,13 @@ inline void 
ggml_sycl_op_mul_mat_sycl( src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, dst_f16.get(), dpct::library_data_t::real_half, ldc, dpct::library_data_t::real_half))); - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); #else auto dnnl_stream = ctx.stream_dnnl(stream); DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), dst_f16.get(), DnnlGemmWrapper::to_dt()); - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); #endif } @@ -2641,13 +2461,13 @@ inline void ggml_sycl_op_mul_mat_sycl( ggml_sycl_pool_alloc src0_ddq_as_f32(ctx.pool()); ggml_sycl_pool_alloc src1_ddq_as_f32(ctx.pool()); if (src0->type != GGML_TYPE_F32) { - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst); GGML_ASSERT(to_fp32_sycl != nullptr); src0_ddq_as_f32.alloc(row_diff*ne00); to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream); } if (src1->type != GGML_TYPE_F32) { - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst); GGML_ASSERT(to_fp32_sycl != nullptr); src1_ddq_as_f32.alloc(src1_ncols*ne10); to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream); @@ -3085,7 +2905,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0; const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? 
ne11 - src1_col_0 : src1_col_stride; - for (int i = 0; i < ggml_sycl_info().device_count; ++i) { if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) { continue; @@ -3393,7 +3212,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, // convert src1 to fp16 ggml_sycl_pool_alloc src1_f16_alloc(ctx.pool()); if (src1->type != GGML_TYPE_F16) { - const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type); + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); const int64_t ne_src1 = ggml_nelements(src1); src1_f16_alloc.alloc(ne_src1); GGML_ASSERT(to_fp16_sycl != nullptr); @@ -3509,6 +3328,7 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) { } static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer); int64_t min_compute_capability = INT_MAX; @@ -3570,6 +3390,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); + // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream()); } else if (use_mul_mat_vec_q) { ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); } else if (use_mul_mat_q) { @@ -4251,10 +4072,72 @@ catch (sycl::exception const &exc) { std::exit(1); } +void reorder_qw(char *data_device, const int ncols, const int nrows, + size_t size, size_t offset, dpct::queue_ptr stream) { + auto tmp_buf = sycl::malloc_shared(size, *stream); + SYCL_CHECK( + CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size) + .wait())); + GGML_ASSERT((size % sizeof(block_q4_0) == 0)); + GGML_ASSERT((offset % sizeof(block_q4_0) == 0)); + int offset_blks = offset / sizeof(block_q4_0); + auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;; + auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks; + + stream->parallel_for( + size / sizeof(block_q4_0), + [=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] { + const block_q4_0* x = (const block_q4_0*)tmp_buf; + const int ib = i; + + for (int j = 0; j < QK4_0/2; j ++) + { + *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j]; + } + *(d_ptr + ib) = x[ib].d; + }); + + sycl::free(tmp_buf, *stream); +} + +void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) { + char*data_device = (char*)src0->data; + size_t ncols = src0->ne[0]; + size_t nrows = src0->ne[1]; + size_t size = ggml_nbytes(src0); + + reorder_qw(data_device, ncols, nrows, size, 0, stream); +} + +void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) { + ggml_tensor *src0 = dst->src[0]; + ggml_tensor *src1 = dst->src[1]; + + if (dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_0 && + src1->ne[2]==1 && src1->ne[3]==1) { + reorder_qw(src0, stream); + ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra; + GGML_ASSERT(extra); + extra->optimized_feature.reorder = true; //used to decode/dequan in next steps. 
+ } +} + +void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) { + dpct::queue_ptr stream = ctx->stream(); + if (ctx->optimized_graph) { + return; + } + ctx->optimized_graph = true; + + for (int i = 0; i < cgraph->n_nodes; i++) { + if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream); + } +} static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; ggml_sycl_set_main_device(sycl_ctx->device); + if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx); for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; diff --git a/ggml/src/ggml-sycl/sycl_hw.cpp b/ggml/src/ggml-sycl/sycl_hw.cpp new file mode 100644 index 0000000000000..da121ffc261e8 --- /dev/null +++ b/ggml/src/ggml-sycl/sycl_hw.cpp @@ -0,0 +1,13 @@ +#include "sycl_hw.hpp" + + +sycl_hw_info get_device_hw_info(sycl::device *device_ptr) { + sycl_hw_info res; + int32_t id = device_ptr->get_info(); + res.device_id = id; + + syclex::architecture arch = device_ptr->get_info(); + res.arch = arch; + + return res; +} diff --git a/ggml/src/ggml-sycl/sycl_hw.hpp b/ggml/src/ggml-sycl/sycl_hw.hpp new file mode 100644 index 0000000000000..bf689450ce61f --- /dev/null +++ b/ggml/src/ggml-sycl/sycl_hw.hpp @@ -0,0 +1,23 @@ +#ifndef SYCL_HW_HPP +#define SYCL_HW_HPP + +#include +#include +#include +#include + +#include + +namespace syclex = sycl::ext::oneapi::experimental; + +struct sycl_hw_info { + syclex::architecture arch; + int32_t device_id; +}; + +bool is_in_vector(std::vector &vec, int item); + +sycl_hw_info get_device_hw_info(sycl::device *device_ptr); + + +#endif // SYCL_HW_HPP
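For reference, the host-side transform performed by `reorder_qw` can be pictured with the standalone sketch below. It is illustrative only: `block_q4_0_ref` and `reorder_q4_0_ref` are hypothetical stand-ins for the real ggml types, and `uint16_t` stands in for `sycl::half`. Interleaved `{scale, quants}` blocks are rewritten as all quant bytes first, then all scales.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for ggml's block_q4_0: one fp16 scale (raw bits here)
// followed by 16 bytes holding 32 packed 4-bit quants.
struct block_q4_0_ref {
    uint16_t d;       // scale bits (sycl::half in the real code)
    uint8_t  qs[16];  // 32 packed 4-bit quants
};

// Rewrite n_blocks interleaved blocks into the layout the *_reorder kernels
// expect: [all qs bytes][all scales]. Mirrors what reorder_qw() does on device.
static std::vector<uint8_t> reorder_q4_0_ref(const block_q4_0_ref * blocks, size_t n_blocks) {
    std::vector<uint8_t> out(n_blocks * sizeof(block_q4_0_ref));
    uint8_t * qs_out = out.data();              // quants start at offset 0
    uint8_t * d_out  = qs_out + n_blocks * 16;  // scales start at byte offset k / 2
    for (size_t ib = 0; ib < n_blocks; ++ib) {
        std::memcpy(qs_out + ib * 16, blocks[ib].qs, 16);
        std::memcpy(d_out + ib * 2, &blocks[ib].d, 2);
    }
    return out;
}
```

In the backend this transform runs at most once per weight tensor: `ggml_backend_sycl_graph_compute` calls `optimize_graph_once`, which applies `opt_for_reorder` to eligible `GGML_OP_MUL_MAT` nodes and marks the tensor's `ggml_tensor_extra_gpu`, so the reorder-aware dequantize and mul-mat-vec kernels are selected afterwards; setting `GGML_SYCL_DISABLE_OPT=1` skips the whole path for performance comparison.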