[SYCL] Optimize mul_mat for Q4_0 on Intel GPU (#12035)
* optimize performance by reordering Q4_0 data for Intel GPU

* detect hw type, save the opt feature, and print it

* correct name

* optimize the graph once during graph compute, record the opt status in tensor->extra, make CI pass

* add env variable GGML_SYCL_DISABLE_OPT for debug

* use syclex::architecture to replace the custom hw define; update the guide for GGML_SYCL_DISABLE_OPT

* add performance data

* move the getrows functions to separate files

* fix global variables

---------

Co-authored-by: arthw <[email protected]>
NeoZhangJianyu and arthw authored Feb 24, 2025
1 parent 651adf4 commit 08d5986
Showing 14 changed files with 803 additions and 266 deletions.
16 changes: 14 additions & 2 deletions docs/backend/SYCL.md
@@ -42,6 +42,16 @@ The following release is verified with good quality:

## News

- 2025.2
- Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL (Meteor Lake). This increases LLM performance (llama-2-7b.Q4_0.gguf) by 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC); a sketch of the reordered data layout follows the table.

|GPU|Base tokens/s|Optimized tokens/s|Speedup|
|-|-|-|-|
|PVC 1550|39|73|+87%|
|Flex 170|39|50|+28%|
|Arc770|42|55|+30%|
|MTL|13|16|+23%|
|ARL-H|14|17|+21%|
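
The gain comes from repacking Q4_0 weights out of the interleaved `block_q4_0` array-of-structs layout into two contiguous regions: all packed 4-bit quants first, then all fp16 scales. Below is a minimal host-side sketch of that repacking, assuming the standard ggml `block_q4_0` definition; the helper name `reorder_q4_0` is illustrative, not a function from this commit.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

constexpr int QK4_0 = 32;

struct block_q4_0 {
    uint16_t d;              // fp16 scale (raw bits)
    uint8_t  qs[QK4_0 / 2];  // 32 4-bit quants packed into 16 bytes
};

// Repack nb blocks into the reordered layout assumed by the new kernels:
// [qs_0 | qs_1 | ... | qs_{nb-1} | d_0 | d_1 | ... | d_{nb-1}]
std::vector<uint8_t> reorder_q4_0(const block_q4_0 * blocks, int64_t nb) {
    std::vector<uint8_t> out(nb * sizeof(block_q4_0));
    uint8_t * qs_dst = out.data();                     // quants start at offset 0
    uint8_t * d_dst  = out.data() + nb * (QK4_0 / 2);  // scales follow all quants
    for (int64_t ib = 0; ib < nb; ++ib) {
        std::memcpy(qs_dst + ib * (QK4_0 / 2), blocks[ib].qs, QK4_0 / 2);
        std::memcpy(d_dst + ib * sizeof(uint16_t), &blocks[ib].d, sizeof(uint16_t));
    }
    return out;
}
```

With this layout, consecutive work-items in a sub-group read neighboring scales and regularly strided quant chunks, which suits the memory system of Intel GPUs better than the 18-byte interleaved blocks.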

- 2024.11
- Use syclcompat to improve performance on some platforms. This requires oneAPI 2025.0 or newer.

@@ -97,8 +107,8 @@ SYCL backend supports Intel GPU Family:
| Intel Data Center Max Series | Support | Max 1550, 1100 |
| Intel Data Center Flex Series | Support | Flex 170 |
| Intel Arc Series | Support | Arc 770, 730M, Arc A750 |
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake |
| Intel iGPU | Support | iGPU in 13700k, i5-1250P, i7-1260P, i7-1165G7 |
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake |
| Intel iGPU | Support | iGPU in 13700k, iGPU in 13400, i5-1250P, i7-1260P, i7-1165G7 |

*Notes:*

@@ -660,8 +670,10 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| Name | Value | Function |
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable the optimized features based on Intel GPU type, to compare performance with and without them; see the sketch after this table |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
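
For reference, a minimal sketch of how such a toggle can be read at backend start-up; the parsing details here are an assumption, not the exact ggml-sycl code:

```cpp
// Hypothetical sketch: read GGML_SYCL_DISABLE_OPT into the global flag
// declared in common.hpp (extern int g_ggml_sycl_disable_optimize).
#include <cstdlib>

int g_ggml_sycl_disable_optimize = 0;

static void ggml_sycl_read_env() {
    const char * v = std::getenv("GGML_SYCL_DISABLE_OPT");
    if (v != nullptr) {
        g_ggml_sycl_disable_optimize = std::atoi(v);  // "1" disables the optimizations
    }
}
```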


## Known Issues

- `Split-mode:[row]` is not supported.
4 changes: 2 additions & 2 deletions examples/sycl/run-llama2.sh
@@ -3,7 +3,7 @@
# MIT license
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: MIT

export ONEAPI_DEVICE_SELECTOR="level_zero:0"
source /opt/intel/oneapi/setvars.sh

#export GGML_SYCL_DEBUG=1
@@ -13,7 +13,7 @@ source /opt/intel/oneapi/setvars.sh
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
MODEL_FILE=models/llama-2-7b.Q4_0.gguf
NGL=33
CONEXT=8192
CONEXT=4096

if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
2 changes: 2 additions & 0 deletions ggml/src/ggml-sycl/CMakeLists.txt
@@ -1,3 +1,5 @@
message(STATUS "GGML_SYCL_TARGET=${GGML_SYCL_TARGET}")

if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$")
message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD")
endif()
17 changes: 17 additions & 0 deletions ggml/src/ggml-sycl/common.cpp
@@ -99,3 +99,20 @@ catch (sycl::exception const &exc) {
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}


void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams) {
    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
        // release the per-device synchronization events
        for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
            if (extra->events[i][is] != nullptr) {
                SYCL_CHECK(CHECK_TRY_ERROR(dpct::destroy_event(extra->events[i][is])));
            }
        }
        // free the device buffer only when queues were passed in to free on;
        // the default (empty) argument releases just the events
        if (extra->data_device[i] != nullptr && streams.size() > 0) {
            ggml_sycl_set_device(i);
            SYCL_CHECK(
                CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i]))));
        }
    }
    delete extra;
}
58 changes: 49 additions & 9 deletions ggml/src/ggml-sycl/common.hpp
@@ -19,6 +19,9 @@
#include "dpct/helper.hpp"
#include "ggml-sycl.h"
#include "presets.hpp"
#include "sycl_hw.hpp"


#if GGML_SYCL_DNNL
#include "dnnl.hpp"
#include "dnnl_sycl.hpp"
@@ -35,7 +38,10 @@
void* ggml_sycl_host_malloc(size_t size);
void ggml_sycl_host_free(void* ptr);


extern int g_ggml_sycl_debug;
extern int g_ggml_sycl_disable_optimize;

#define GGML_SYCL_DEBUG(...) \
    do { \
        if (g_ggml_sycl_debug) \
@@ -182,18 +188,24 @@ inline dpct::err0 ggml_sycl_set_device(const int device) try {
}

//////////////////////
struct optimize_feature {
    bool reorder = false;
};

struct sycl_device_info {
    int cc;             // compute capability
    // int nsm;         // number of streaming multiprocessors
    // size_t smpb;     // max. shared memory per block
    bool vmm;           // virtual memory support
    size_t total_vram;
    sycl_hw_info hw_info;
    optimize_feature opt_feature;
};


struct ggml_sycl_device_info {
    int device_count;

    struct sycl_device_info {
        int cc;             // compute capability
        // int nsm;         // number of streaming multiprocessors
        // size_t smpb;     // max. shared memory per block
        bool vmm;           // virtual memory support
        size_t total_vram;
    };

    sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {};

    std::array<float, GGML_SYCL_MAX_DEVICES> default_tensor_split = {};
@@ -260,17 +272,46 @@ struct ggml_tensor_extra_gpu {
    // tensors
    dpct::event_ptr events[GGML_SYCL_MAX_DEVICES]
                          [GGML_SYCL_MAX_STREAMS]; // events for synchronizing multiple GPUs
    optimize_feature optimized_feature;
};

void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={});

inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
    optimize_feature opt;

    // enable the reorder optimization on Intel dGPUs (DG1, Arc/ACM, PVC, BMG)
    // and on built-in GPUs since Meteor Lake (MTL, ARL, LNL)
    opt.reorder =
        (arch == syclex::architecture::intel_gpu_dg1 ||
         arch == syclex::architecture::intel_gpu_acm_g10 ||
         arch == syclex::architecture::intel_gpu_acm_g11 ||
         arch == syclex::architecture::intel_gpu_acm_g12 ||
         arch == syclex::architecture::intel_gpu_pvc ||
         arch == syclex::architecture::intel_gpu_pvc_vg ||
         arch == syclex::architecture::intel_gpu_mtl_u ||
         arch == syclex::architecture::intel_gpu_mtl_s ||
         arch == syclex::architecture::intel_gpu_mtl_h ||
         arch == syclex::architecture::intel_gpu_arl_u ||
         arch == syclex::architecture::intel_gpu_arl_s ||
         arch == syclex::architecture::intel_gpu_arl_h ||
         arch == syclex::architecture::intel_gpu_bmg_g21 ||
         arch == syclex::architecture::intel_gpu_lnl_m);

    return opt;
}
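
A minimal sketch of how `opt_feature` can be populated per device; the exact call site is not part of this diff, so the surrounding function is an assumption:

```cpp
// Hypothetical sketch, not code from this commit: query the device
// architecture via the sycl_ext_oneapi_device_architecture extension
// (the syclex alias is the one already used in common.hpp) and store
// the derived optimize features in the device info.
static void init_opt_feature(sycl_device_info & info, sycl::device & dev) {
    syclex::architecture arch =
        dev.get_info<syclex::info::device::architecture>();
    info.opt_feature = check_gpu_optimize_feature(arch);
}
```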

struct ggml_backend_sycl_context {
    int device;
    std::string name;
    optimize_feature opt_feature;
    bool optimized_graph = false;

    queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } };

    explicit ggml_backend_sycl_context(int device) :
        device(device),
        name(GGML_SYCL_NAME + std::to_string(device)) {
        opt_feature = ggml_sycl_info().devices[device].opt_feature;
    }

queue_ptr stream(int device, int stream) {
@@ -680,5 +721,4 @@ bool gpu_has_xmx(sycl::device &dev);
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                          const ggml_tensor *src1, ggml_tensor *dst,
                          const ggml_sycl_op_flatten_t op);

#endif // GGML_SYCL_COMMON_HPP
37 changes: 33 additions & 4 deletions ggml/src/ggml-sycl/convert.cpp
@@ -125,6 +125,25 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
}
}

template <typename dst_t>
static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k,
                                             dpct::queue_ptr stream) {
    dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});

    // one work-item per Q4_0 block: a sub-group of WARP_SIZE threads covers
    // WARP_K = WARP_SIZE * QK4_0 elements, so n_warp sub-groups cover k
    int constexpr WARP_K = WARP_SIZE * QK4_0;
    const int n_warp = (k + WARP_K - 1) / WARP_K;
    GGML_ASSERT(k % 2 == 0);
    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
                                               sycl::range<3>(1, 1, WARP_SIZE),
                                           sycl::range<3>(1, 1, WARP_SIZE)),
                         [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
                             dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
                         });
}

template <typename dst_t>
static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
@@ -452,10 +471,15 @@ static void convert_unary_sycl(const void *__restrict__ vx,
}
}

to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
            if (dst->src[0]->extra &&
                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
                return dequantize_row_q4_0_sycl_reorder;
            } else {
                return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
            }
        case GGML_TYPE_Q4_1:
            return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;
        case GGML_TYPE_Q5_0:
@@ -499,10 +523,15 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
}
}

to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_sycl;
            if (dst->src[0]->extra &&
                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
                return dequantize_row_q4_0_sycl_reorder;
            } else {
                return dequantize_row_q4_0_sycl;
            }
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_sycl;
        case GGML_TYPE_Q5_0:
4 changes: 2 additions & 2 deletions ggml/src/ggml-sycl/convert.hpp
@@ -21,7 +21,7 @@ using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
typedef to_t_sycl_t<float> to_fp32_sycl_t;
typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;

to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type);
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type);
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst);
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst);

#endif // GGML_SYCL_CONVERT_HPP
55 changes: 55 additions & 0 deletions ggml/src/ggml-sycl/dequantize.hpp
@@ -16,6 +16,8 @@
#include "common.hpp"

typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs,
                                            const int iqs, dfloat2 &v);

static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
                                            const int iqs, dfloat2 &v) {
@@ -40,6 +42,29 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
#endif // GGML_SYCL_F16
}

static __dpct_inline__ void dequantize_q4_0_reorder(const void *d_ptr, const int64_t ib, const void *qs,
                                                    const int iqs, dfloat2 &v) {
    // reordered layout: scales and quants live in separate arrays, so they
    // arrive as two pointers instead of a single block_q4_0 pointer
    const dfloat d = (const dfloat) *((const sycl::half *) d_ptr + ib);

    const int vui = *((const uint8_t *) qs + iqs);

    v.x() = vui & 0xF;
    v.y() = vui >> 4;

#ifdef GGML_SYCL_F16
    // v = v - {8.0f, 8.0f};
    // v = v * {d, d};
    v.s0() = (v.s0() - 8.0f) * d;
    v.s1() = (v.s1() - 8.0f) * d;
#else
    v.x() = (v.x() - 8.0f) * d;
    v.y() = (v.y() - 8.0f) * d;
#endif // GGML_SYCL_F16
}

static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
                                            const int iqs, dfloat2 &v) {
    const block_q4_1 * x = (const block_q4_1 *) vx;
@@ -167,6 +192,36 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
}
}

template<typename dst_t>
static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
                                          const sycl::nd_item<3> &item_ct1) {
    const int64_t i = item_ct1.get_group(2);
    auto k = nb32;
    // assume 32 threads: each work-item dequantizes one Q4_0 block
    const int64_t tid = item_ct1.get_local_id(2);
    const int lane_ib = i * WARP_SIZE + tid;

    if (lane_ib >= k / QK4_0) {
        return;
    }

    dst_t * y_ptr = yy + lane_ib * QK4_0;

    // reordered layout: all packed nibbles first (k/2 bytes), then all scales
    auto qs    = (const uint8_t *) vx + lane_ib * QK4_0 / 2;
    auto s_ptr = (const sycl::half *) ((const uint8_t *) vx + k / 2) + lane_ib;

    const float d = float(*s_ptr);

#pragma unroll
    for (int l = 0; l < QK4_0 / 2; ++l) {
        int vq = qs[l];
        y_ptr[l + 0]  = d * ((vq & 0xF) - 8);  // low nibble  -> first half of block
        y_ptr[l + 16] = d * ((vq >> 4) - 8);   // high nibble -> second half
    }
}
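
A standalone worked example of the nibble decode above: one packed byte yields two dequantized values 16 elements apart, matching `y_ptr[l]` and `y_ptr[l + 16]`. The scale and byte values here are made up for illustration:

```cpp
#include <cstdio>

int main() {
    const float d  = 0.5f;  // example scale
    const int   vq = 0xAB;  // packed byte: low nibble 0xB (11), high nibble 0xA (10)
    const float lo = d * ((vq & 0xF) - 8);  // (11 - 8) * 0.5 = 1.5 -> y_ptr[l]
    const float hi = d * ((vq >> 4) - 8);   // (10 - 8) * 0.5 = 1.0 -> y_ptr[l + 16]
    std::printf("lo = %g, hi = %g\n", lo, hi);
    return 0;
}
```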

template<typename dst_t>
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
                                  const sycl::nd_item<3> &item_ct1) {
