diff --git a/docs/qbits.md b/docs/qbits.md
index 00bd2fb083f..1e9a21548da 100644
--- a/docs/qbits.md
+++ b/docs/qbits.md
@@ -65,3 +65,11 @@ qbits.woq_linear(
 activation, pack_weight, bias, output, n, add_bias, compute_type, weight_type, scale_type, asym)
 ```
 please refer [here](https://github.com/intel/intel-extension-for-transformers/tree/main/intel_extension_for_transformers/transformers/llm/operator/csrc/qbits_ut) for more QBits operators usage.
+
+## PyTorch version constraints
+To use QBits, the installed PyTorch version must meet the ITREX requirements listed below:
+
+| ITREX version | PyTorch version |
+| :-----------: | :-------------: |
+| v1.4 | 2.2.0+cpu |
+| v1.4.1 | 2.2.0+cpu |
diff --git a/intel_extension_for_transformers/qbits/CMakeLists.txt b/intel_extension_for_transformers/qbits/CMakeLists.txt
index 2aee942842b..5ac6c0e68e4 100755
--- a/intel_extension_for_transformers/qbits/CMakeLists.txt
+++ b/intel_extension_for_transformers/qbits/CMakeLists.txt
@@ -37,12 +37,15 @@ find_package(PythonLibs 3 REQUIRED)
 endif()
 
 include(FindOpenMP)
+set(BTLA_ENABLE_OPENMP ON CACHE BOOL "BesTLA enable compiling OpenMP threading")
 add_subdirectory(dispatcher)
 add_subdirectory(../transformers/runtime/third_party/pybind11 pybind11)
 
 file(GLOB HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp)
 file(GLOB qbits_src ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp)
 
+add_compile_options(-flto=auto)
+
 # Link against LibTorch
 pybind11_add_module(qbits_py ${qbits_src})
 target_compile_features(qbits_py PRIVATE cxx_std_14)
diff --git a/intel_extension_for_transformers/qbits/dispatcher/CMakeLists.txt b/intel_extension_for_transformers/qbits/dispatcher/CMakeLists.txt
index 240526ca0e1..2789f4a3ca6 100644
--- a/intel_extension_for_transformers/qbits/dispatcher/CMakeLists.txt
+++ b/intel_extension_for_transformers/qbits/dispatcher/CMakeLists.txt
@@ -35,5 +35,5 @@ endif()
 set_target_properties(bestla_dispatcher PROPERTIES POSITION_INDEPENDENT_CODE ON)
 set_target_properties(bestla_dispatcher PROPERTIES LINKER_LANGUAGE CXX)
-target_link_libraries(bestla_dispatcher OpenMP::OpenMP_CXX OpenMP::OpenMP_C "${TORCH_LIBRARIES}" bestla::bestla)
+target_link_libraries(bestla_dispatcher OpenMP::OpenMP_CXX OpenMP::OpenMP_C "${TORCH_LIBRARIES}" bestla)
 set_property(TARGET torch_cpu PROPERTY INTERFACE_COMPILE_OPTIONS "")
diff --git a/intel_extension_for_transformers/qbits/dispatcher/include/bestla_customop.hpp b/intel_extension_for_transformers/qbits/dispatcher/include/bestla_customop.hpp
index 10958560712..0df39eaba6d 100644
--- a/intel_extension_for_transformers/qbits/dispatcher/include/bestla_customop.hpp
+++ b/intel_extension_for_transformers/qbits/dispatcher/include/bestla_customop.hpp
@@ -20,17 +20,17 @@ template
 inline BTLA_CODE alphabeta_dt_cvt_process(float* tmp_dst, const int cachestep, const int M_offset, const int N_offset,
-                                          const int M, const int N, const Param& _param) {
+                                          const int M, const int N, const Param& _param) {
   auto DOffset = M_offset * _param.ldd + N_offset;
   auto dptr = reinterpret_cast(_param.D) + DOffset;
   bestla::kernel::wrapper::AlphaBetaF32F32::template forward(_param.alpha, tmp_dst, cachestep, _param.beta, dptr,
-                                                             _param.ldd, tmp_dst, cachestep, M, N);
+                                                             _param.ldd, tmp_dst, cachestep, M, N);
   auto COffset = M_offset * _param.ldc + N_offset;
   auto cptr = reinterpret_cast(_param.C) + COffset;
   if constexpr (std::is_same_v) {
     return bestla::kernel::wrapper::Memcpy2D::template forward(tmp_dst, cptr, M, N, cachestep,
-                                                               _param.ldc, NULL);
+                                                               _param.ldc, NULL);
   }
   if constexpr (std::is_same_v) {
     return bestla::kernel::wrapper::Memcpy2DFp32CvtBf16::template forward(
@@ -47,8 +47,8 @@ class AlphaBetaProcess {
     int ldc, ldd;
     float alpha, beta;
   };
-  BTLA_CODE forward(float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
-                    const int N, const Param& _param, void* tmpcache = nullptr, size_t cachesize = -1) {
+  static BTLA_CODE forward(float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+                           const int N, const Param& _param, void* tmpcache = nullptr, size_t cachesize = -1) {
     return alphabeta_dt_cvt_process(cacheptr, cachestep, M_offset, N_offset, M, N, _param);
   }
 };
diff --git a/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp b/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp
index fd8b25f68b2..1f33cfc663b 100644
--- a/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp
+++ b/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp
@@ -62,7 +62,8 @@ struct woq_runtime_ctx {
 static std::map wei2bestladt_map{{"int8", BTLA_DTYPE::S8},
                                  {"int4_clip", BTLA_DTYPE::S4_CLIP},
-                                 {"int4_fullrange", BTLA_DTYPE::S4_FULLRANGE},
+                                 {"int3_clip", BTLA_DTYPE::S3_CLIP},
+                                 {"int2_clip", BTLA_DTYPE::S2_CLIP},
                                  {"nf4", BTLA_DTYPE::F4_NF4},
                                  {"fp4_e2m1_bnb", BTLA_DTYPE::F4_BNB},
                                  {"fp4_e2m1", BTLA_DTYPE::F4_E2M1},
diff --git a/intel_extension_for_transformers/qbits/dispatcher/include/dispatcher_utils.hpp b/intel_extension_for_transformers/qbits/dispatcher/include/dispatcher_utils.hpp
index d84b82a9cc5..8a0c99b3b3a 100644
--- a/intel_extension_for_transformers/qbits/dispatcher/include/dispatcher_utils.hpp
+++ b/intel_extension_for_transformers/qbits/dispatcher/include/dispatcher_utils.hpp
@@ -26,10 +26,26 @@ inline bool check_avx_vnni() { return bestla::device::CpuDevice::getInstance()->
 inline bool check_avx512f() { return bestla::device::CpuDevice::getInstance()->AVX512F(); }
 inline bool check_avx2() { return bestla::device::CpuDevice::getInstance()->AVX2(); }
 
+class qbits_threading {
+ public:
+  static bestla::parallel::IThreading* get() {
+    GetCPUDevice();
+    static bestla::parallel::StdThreading OptmizedThreading;
+    static bestla::parallel::OMPThreading DefaultThreading;
+    if (!_cd->isHybrid()) {
+      return &DefaultThreading;
+    }
+    return &OptmizedThreading;
+  }
+
+  static void set_threads(int n_thread) { get()->set_threads(n_thread); }
+};
+
 class env_initer {
  public:
   env_initer() {
     if (check_amx()) bestla::utils::request_perm_xtile_data();
+    qbits_threading::set_threads(bestla::device::CpuDevice::getInstance()->getThreads());
     verbose = std::getenv("QBITS_VERBOSE") != nullptr;
     FLAGS_caffe2_log_level = 0;
   }
@@ -56,7 +72,7 @@ class Timer {
   high_resolution_clock::time_point m_end;
 };
 static Timer timer;
-static bestla::parallel::OMPThreading DefaultThreading(bestla::device::CpuDevice::getInstance()->getThreads());
+
 string get_torch_dt_name(torch::Tensor* tensor);
 }  // namespace dispatcher_utils
diff --git a/intel_extension_for_transformers/qbits/dispatcher/neural_speed.cmake b/intel_extension_for_transformers/qbits/dispatcher/neural_speed.cmake
index 7a8c0ce591c..93605924ed4 100644
--- a/intel_extension_for_transformers/qbits/dispatcher/neural_speed.cmake
+++ b/intel_extension_for_transformers/qbits/dispatcher/neural_speed.cmake
@@ -1,5 +1,5 @@
 set(NEURAL_SPEED_URL https://github.com/intel/neural-speed.git)
-set(NEURAL_SPEED_TAG bestlav0.1)
+set(NEURAL_SPEED_TAG 2f7943681e02c6e87a4c70c3925327f00194c78f)
 
 FetchContent_Declare(
   neural_speed
diff --git a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_gemm_dispatcher.cpp b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_gemm_dispatcher.cpp
index fe4e4a92019..887bf8251da 100644
--- a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_gemm_dispatcher.cpp
+++ b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_gemm_dispatcher.cpp
@@ -40,17 +40,17 @@ void do_gemm(bestla_gemm_runtime_ctx* ctx) {
   packw.assign(tmpbuf);
   if (ctx->matB_trans) {
     launcher.mProB.packWeightTranspose(ctx->n, ctx->k, {reinterpret_cast(ctx->matB->data_ptr()), ctx->k, &packw},
-                                       &dispatcher_utils::DefaultThreading);
+                                       dispatcher_utils::qbits_threading::get());
   } else {
     launcher.mProB.packWeight(ctx->n, ctx->k, {reinterpret_cast(ctx->matB->data_ptr()), ctx->n, &packw},
-                              &dispatcher_utils::DefaultThreading);
+                              dispatcher_utils::qbits_threading::get());
   }
   bestla::utils::GemmProblem gp(1, ctx->m, ctx->n, ctx->k);
   typename Launcher::Param args{gp,
                                 {reinterpret_cast(ctx->matA->data_ptr()), ctx->k},
                                 {reinterpret_cast(ctx->matB->data_ptr()), ctx->n, &packw},
                                 {reinterpret_cast(ctx->matC->data_ptr()), ctx->n}};
-  bestla::parallel::GemmRun(launcher, args, &dispatcher_utils::DefaultThreading);
+  bestla::parallel::GemmRun(launcher, args, dispatcher_utils::qbits_threading::get());
   bestla::utils::afree(tmpbuf);
 }
diff --git a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp
index 3d1067de7e6..399deaba7e0 100644
--- a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp
+++ b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp
@@ -28,9 +28,9 @@ void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx
   *(ctx->output) = torch::empty(qpackw.mSize, torch::kInt8);
   qpackw.assign(ctx->output->data_ptr());
   if (p->enable_act_shuffle)
-    ker.setShuffleIndices(ctx->g_idx->data_ptr(), &qpackw, &dispatcher_utils::DefaultThreading);
+    ker.setShuffleIndices(ctx->g_idx->data_ptr(), &qpackw, dispatcher_utils::qbits_threading::get());
   ker.packQWeight(ctx->n, ctx->k, ctx->qweight->data_ptr(), ctx->n, ctx->scale->data_ptr(),
-                  p->asym ? ctx->zp->data_ptr() : nullptr, &qpackw, &dispatcher_utils::DefaultThreading);
+                  p->asym ? ctx->zp->data_ptr() : nullptr, &qpackw, dispatcher_utils::qbits_threading::get());
 }
 
 std::string get_dtype_str(BTLA_DTYPE dtype) {
@@ -41,8 +41,10 @@ std::string get_dtype_str(BTLA_DTYPE dtype) {
       return "bf16";
     case BTLA_DTYPE::S4_CLIP:
       return "int4_clip";
-    case BTLA_DTYPE::S4_FULLRANGE:
-      return "int4_fullrange";
+    case BTLA_DTYPE::S3_CLIP:
+      return "int3_clip";
+    case BTLA_DTYPE::S2_CLIP:
+      return "int2_clip";
     case BTLA_DTYPE::F4_NF4:
       return "nf4";
     case BTLA_DTYPE::F4_E2M1:
@@ -66,7 +68,6 @@ std::string get_dtype_str(BTLA_DTYPE dtype) {
 std::string get_cmpt_str(bestla::gemm::CompType cmpt) {
   using bestla::gemm::CompType;
   switch (cmpt) {
-    case CompType::COMP_INT8_US_INT32:
     case CompType::COMP_INT8_US_FP32:
       return "int8";
     case CompType::COMP_FP32:
@@ -182,43 +183,34 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
 }
 
 void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
-  TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int4_fullrange",
+  // TODO(zhe): elegant impl.
+  TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
+                  p->weight_type == "int2_clip",
               "Qbits: only support Integer WOQ in PACKQ");
-  // NTILE & compute-dtype determine the padsize.
-  // in qbits:
-  // avx_vnni/avx512f_vnni/amx-int8 NTILE==48, compute-dtype=int8;
-  // avx2/avx512f NTILE==48, compute-dtype=fp32;
-  // amx-bf16 NTILE==64, compute-dtype=bf16.
-  if (task == WOQ_GET_PACKW_SIZE) {
-    if (p->compute_type == "int8")
-      return execute_qpack, BTLA_ISA::AMX_INT8>(p, ctx, task);
-    if (p->compute_type == "fp32")
-      return execute_qpack, BTLA_ISA::AVX512F>(p, ctx, task);
-    if (p->compute_type == "bf16")
-      return execute_qpack, BTLA_ISA::AMX_BF16>(p, ctx, task);
-  }
-  if (p->compute_type == "int8") {
-    if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>::KTILE == 0) {
-      return execute_qpack, BTLA_ISA::AMX_INT8>(p, ctx, task);
+    if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>::KTILE == 0) {
+      return execute_qpack, BTLA_ISA::AMX_INT8>(p, ctx, task);
     }
     if (dispatcher_utils::check_avx512_vnni() &&
         p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>::KTILE == 0) {
      return execute_qpack, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
    }
-    if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>::KTILE == 0) {
-      return execute_qpack, BTLA_ISA::AVX_VNNI>(p, ctx, task);
+    if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>::KTILE == 0) {
+      return execute_qpack, BTLA_ISA::AVX_VNNI>(p, ctx, task);
+    }
+    if (dispatcher_utils::check_avx2() && p->blocksize % bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>::KTILE == 0) {
+      return execute_qpack, BTLA_ISA::AVX2>(p, ctx, task);
    }
    TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type, blocksize:", p->blocksize,
-               ", ISA support vnni:", dispatcher_utils::check_avx_vnni());
+               ", ISA support avx2:", dispatcher_utils::check_avx2());
  }
  if (p->compute_type == "fp32") {
    if (dispatcher_utils::check_avx512f()) {
      return execute_qpack, BTLA_ISA::AVX512F>(p, ctx, task);
    }
    if (dispatcher_utils::check_avx2()) {
-      return execute_qpack, BTLA_ISA::AVX2>(p, ctx, task);
+      return execute_qpack, BTLA_ISA::AVX2>(p, ctx, task);
    }
    TORCH_CHECK(false, "Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32");
  }
diff --git a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp
index efecbb313b4..f9864ddece0 100644
--- a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp
+++ b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp
@@ -43,6 +43,12 @@ concept quant_PrologueA = requires {
   requires !std::is_same_v;
 };
 
+template
+constexpr bool is_int8_cmpt_gemmcore() {
+  return GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
+         GemmCore::ISA == BTLA_ISA::AVX_VNNI || std::is_same_v>;
+}
+
 template
 void dequantize_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
   if (dispatcher_utils::initer.verbose) dispatcher_utils::timer.start();
@@ -53,16 +59,16 @@ void dequantize_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
     kernel.unpackTransposeWeight(ctx->deseries_wei->mN, ctx->deseries_wei->mK,
                                  dynamic_cast(ctx->deseries_wei),
                                  ctx->output->data_ptr(), ctx->deseries_wei->mK,
-                                 &dispatcher_utils::DefaultThreading);
+                                 dispatcher_utils::qbits_threading::get());
   } else {
     kernel.unpackWeight(ctx->deseries_wei->mN, ctx->deseries_wei->mK,
                         dynamic_cast(ctx->deseries_wei),
-                        ctx->output->data_ptr(), ctx->deseries_wei->mN, &dispatcher_utils::DefaultThreading);
+                        ctx->output->data_ptr(), ctx->deseries_wei->mN,
+                        dispatcher_utils::qbits_threading::get());
   }
 }
 
-// TODO(zhe): weight+scale combination check.
 template
 void quantize_to_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
   if (dispatcher_utils::initer.verbose) dispatcher_utils::timer.start();
@@ -72,7 +78,6 @@ void quantize_to_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
   if constexpr (std::is_same_v) {
     TORCH_CHECK(p->scale_type == "fp32" || p->scale_type == "bf16",
                 "Qbits: scale_type must be fp32/bf16 in NInteger Weight.");
-    if (p->scale_type == "bf16") TORCH_CHECK(!p->asym, "Qbits: asym is not supported when scale_type==bf16 currently.");
     packedw = launcher.mProB.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map[p->weight_type],
                                            scale2bestladt_map[p->scale_type], BTLA_DTYPE::BF16, p->asym);
   } else if constexpr (std::is_same_v) {
@@ -92,10 +97,10 @@ void quantize_to_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
   packedw.assign(ctx->output->data_ptr());
   if (ctx->transpose) {
     launcher.mProB.packTransposeWeight(ctx->n, ctx->k, ctx->weight->data_ptr(), ctx->k, &packedw,
-                                       &dispatcher_utils::DefaultThreading);
+                                       dispatcher_utils::qbits_threading::get());
   } else {
     launcher.mProB.packWeight(ctx->n, ctx->k, ctx->weight->data_ptr(), ctx->n, &packedw,
-                              &dispatcher_utils::DefaultThreading);
+                              dispatcher_utils::qbits_threading::get());
   }
   if (dispatcher_utils::initer.verbose) {
     dispatcher_utils::timer.stop();
@@ -128,8 +133,7 @@ void do_compute(woq_config_param* p, woq_runtime_ctx* ctx, ParamA param_a) {
   using StorageWeight = typename Launcher::PrologueB::StorageWeight;
   size_t asym_size = 0, shuf_size = 0;
   int8_t* tmpbuf = nullptr;
-  if constexpr (GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
-                GemmCore::ISA == BTLA_ISA::AVX_VNNI) {
+  if constexpr (is_int8_cmpt_gemmcore()) {
     using Parallel = bestla::parallel::gemm::SchedulerKBlockS;
     bestla::utils::GemmProblem gp(1, ctx->m, ctx->n, ctx->k, p->blocksize);
     StorageWeight* packedw = dynamic_cast(ctx->deseries_wei);
@@ -140,17 +144,17 @@
     if (packedw->ShfIndice()) {
       param_a.reordered->assign(tmpbuf + dyn_q_size);
       param_a.indices = packedw->ShfIndice();
-      launcher.mProA.quantize(param_a, ctx->m, ctx->deseries_wei->mK, &dispatcher_utils::DefaultThreading);
+      launcher.mProA.quantize(param_a, ctx->m, ctx->deseries_wei->mK, dispatcher_utils::qbits_threading::get());
     }
     typename Launcher::Param args{
         gp, param_a, dynamic_cast(ctx->deseries_wei), param_epi};
     if (packedw->ShfIndice()) {
-      bestla::parallel::GemmRun(launcher, args, &dispatcher_utils::DefaultThreading);
+      bestla::parallel::GemmRun(launcher, args, dispatcher_utils::qbits_threading::get());
     } else {
-      bestla::parallel::GemmRunWithA(launcher, args, &dispatcher_utils::DefaultThreading);
+      bestla::parallel::GemmRunWithA(launcher, args, dispatcher_utils::qbits_threading::get());
     }
   } else {
-    using Parallel = bestla::parallel::gemm::SchedulerKBlock;
+    using Parallel = bestla::parallel::gemm::SchedulerBase;
     StorageWeight* packedw = dynamic_cast(ctx->deseries_wei);
     if (p->asym || packedw->ShfIndice()) {
       if (p->asym) asym_size = param_a.reduce->mSize;
@@ -170,18 +174,12 @@ void do_compute(woq_config_param* p, woq_runtime_ctx* ctx, ParamA param_a) {
     bestla::utils::GemmProblem gp(1, ctx->m, ctx->n, ctx->k, p->blocksize);
     typename Launcher::Param args{
-        gp,
-        param_a,
-        dynamic_cast(ctx->deseries_wei),
-        {packedw->template SPtr(), packedw->SDtype(), packedw->CStep(),
-         p->asym ? packedw->template ZPtr() : nullptr,
-         p->asym ? param_a.reduce->template RPtr() : nullptr, p->asym ? param_a.reduce->lda : -1},
-        param_epi};
+        gp, param_a, dynamic_cast(ctx->deseries_wei), param_epi};
     if (p->asym || packedw->ShfIndice()) {
-      bestla::parallel::GemmRunWithA(launcher, args, &dispatcher_utils::DefaultThreading);
+      bestla::parallel::GemmRunWithA(launcher, args, dispatcher_utils::qbits_threading::get());
     } else {
-      bestla::parallel::GemmRun(launcher, args, &dispatcher_utils::DefaultThreading);
+      bestla::parallel::GemmRun(launcher, args, dispatcher_utils::qbits_threading::get());
     }
   }
   if (tmpbuf != woq_workspace && tmpbuf != nullptr) bestla::utils::afree(tmpbuf);
@@ -238,13 +236,11 @@ void execute_task(woq_config_param* p, woq_runtime_ctx* ctx) {
 template class PrologueB, template class PrologueA,
           template class Epilogue>
 void parse_launcher(woq_config_param* p, woq_runtime_ctx* ctx) {
-  if constexpr (GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
-                GemmCore::ISA == BTLA_ISA::AVX_VNNI) {
+  if constexpr (is_int8_cmpt_gemmcore()) {
     using Launcher = bestla::wrapper::gemm::LauncherIntKBlock;
     return execute_task(p, ctx);
   } else {
-    using Launcher = bestla::wrapper::gemm::LauncherKBlock;
+    using Launcher = bestla::wrapper::gemm::LauncherBase;
     return execute_task(p, ctx);
   }
 }
@@ -252,7 +248,6 @@ void parse_launcher(woq_config_param* p, woq_runtime_ctx* ctx) {
 template class PrologueB, template class PrologueA,
           dispatcher_utils::QBITS_DT ACT_DT>
 void parse_store(woq_config_param* p, woq_runtime_ctx* ctx) {
-  auto constexpr ISA = GemmCore::ISA;
   if (p->dst_dt == dispatcher_utils::QBITS_FP32) {
     return parse_launcher(p, ctx);
   }
@@ -265,8 +260,7 @@ template class Pro
 void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
   using namespace bestla::prologue_a::gemm;
   if (p->src_dt == dispatcher_utils::QBITS_FP32) {
-    if constexpr (GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
-                  GemmCore::ISA == BTLA_ISA::AVX_VNNI) {
+    if constexpr (is_int8_cmpt_gemmcore()) {
       return parse_store(
           p, ctx);
     } else {
@@ -275,8 +269,7 @@ void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
     }
   }
   if (p->src_dt == dispatcher_utils::QBITS_BF16) {
-    if constexpr (GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
-                  GemmCore::ISA == BTLA_ISA::AVX_VNNI) {
+    if constexpr (is_int8_cmpt_gemmcore()) {
       return parse_store(
           p, ctx);
     } else {
@@ -289,14 +282,14 @@ void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
 template
 void parse_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
   using namespace bestla::prologue_b::gemm;
-  if (p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int4_fullrange") {
+  if (p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
+      p->weight_type == "int2_clip") {
     return parse_activation(p, ctx);
   }
   if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1" ||
       p->weight_type == "fp8_e4m3" || p->weight_type == "fp8_e5m2") {
     TORCH_CHECK(!p->asym, "Qbits: float-weight unsupports asym quantization.");
-    if constexpr (GemmCore::ISA != BTLA_ISA::AMX_INT8 && GemmCore::ISA != BTLA_ISA::AVX512_VNNI &&
-                  GemmCore::ISA != BTLA_ISA::AVX_VNNI)
+    if constexpr (!is_int8_cmpt_gemmcore())
       return parse_activation(p, ctx);
   }
   TORCH_CHECK(false,
@@ -308,26 +301,28 @@ void parse_gemm_core_online(woq_config_param* p, woq_runtime_ctx* ctx) {
   set_nk(ctx, ctx->weight);
   p->blocksize = p->blocksize == -1 ? ctx->k : p->blocksize;
   if (p->compute_type == "int8") {
-    TORCH_CHECK(p->asym == false, "Qbits: int8 compute_type doesn't support asym quantization currently.")
-    if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>::KTILE == 0) {
-      return parse_weight>(p, ctx);
+    if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>::KTILE == 0) {
+      return parse_weight>(p, ctx);
    }
    if (dispatcher_utils::check_avx512_vnni() &&
        p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>::KTILE == 0) {
      return parse_weight>(p, ctx);
    }
-    if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>::KTILE == 0) {
-      return parse_weight>(p, ctx);
+    if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>::KTILE == 0) {
+      return parse_weight>(p, ctx);
+    }
+    if (dispatcher_utils::check_avx2() && p->blocksize % bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>::KTILE == 0) {
+      return parse_weight>(p, ctx);
    }
    TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type, blocksize:", p->blocksize,
-               ", ISA support vnni:", dispatcher_utils::check_avx_vnni());
+               ", ISA support avx2:", dispatcher_utils::check_avx2());
  }
  if (p->compute_type == "fp32") {
    if (dispatcher_utils::check_avx512f()) {
      return parse_weight>(p, ctx);
    }
    if (dispatcher_utils::check_avx2()) {
-      return parse_weight>(p, ctx);
+      return parse_weight>(p, ctx);
    }
    TORCH_CHECK(false, "Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32");
  }
@@ -349,27 +344,26 @@ void parse_gemm_core_offline(woq_config_param* p, woq_runtime_ctx* ctx) {
                                                   bestla::gemm::CoreAttr::NTILE_SHIFT);
   auto CType = bestla::gemm::CoreAttr::get_mask_val(ctx->deseries_wei->mCoreId, bestla::gemm::CoreAttr::COMP_MASK,
                                                     bestla::gemm::CoreAttr::COMP_SHIFT);
-  if (CType == uint32_t(bestla::gemm::CompType::COMP_INT8_US_INT32)) {
-    TORCH_CHECK(p->asym == false, "Qbits: int8 compute_type doesn't support asym quantization currently.")
-    if (NTile == bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>::NTILE && dispatcher_utils::check_amx()) {
-      return parse_weight>(p, ctx);
-    }
-  }
   if (CType == uint32_t(bestla::gemm::CompType::COMP_INT8_US_FP32)) {
-    TORCH_CHECK(p->asym == false, "Qbits: int8 compute_type doesn't support asym quantization currently.")
+    if (NTile == bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>::NTILE && dispatcher_utils::check_amx()) {
+      return parse_weight>(p, ctx);
+    }
    if (NTile == bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>::NTILE && dispatcher_utils::check_avx512_vnni()) {
      return parse_weight>(p, ctx);
    }
-    if (NTile == bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>::NTILE && dispatcher_utils::check_avx_vnni()) {
-      return parse_weight>(p, ctx);
+    if (NTile == bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>::NTILE && dispatcher_utils::check_avx_vnni()) {
+      return parse_weight>(p, ctx);
+    }
+    if (NTile == bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>::NTILE && dispatcher_utils::check_avx2()) {
+      return parse_weight>(p, ctx);
    }
  }
  if (CType == uint32_t(bestla::gemm::CompType::COMP_FP32)) {
    if (NTile == bestla::gemm::SCoreRowNAvx512f<48, 8>::NTILE && dispatcher_utils::check_avx512f()) {
      return parse_weight>(p, ctx);
    }
-    if (NTile == bestla::gemm::SCoreRowNAvx2<48, 2>::NTILE && dispatcher_utils::check_avx2()) {
-      return parse_weight>(p, ctx);
+    if (NTile == bestla::gemm::SCoreRowNAvx2<24, 4>::NTILE && dispatcher_utils::check_avx2()) {
+      return parse_weight>(p, ctx);
    }
  }
  if (CType == uint32_t(bestla::gemm::CompType::COMP_BF16_FP32)) {
@@ -392,6 +386,8 @@ void parse_gemm_core(woq_config_param* p, woq_runtime_ctx* ctx) {
 }
 
 void dispatch_woq_task(woq_config_param* p, woq_runtime_ctx* ctx, WOQ_TASK task) {
+  TORCH_CHECK(!(p->asym && (p->compute_type == "int8" && p->weight_type == "int8")),
+              "QBits: unsupported bestla_config, asym quantization in int8 compute_type with int8 weight_type.");
   switch (task) {
     case WOQ_QUANTIZE:
       return parse_gemm_core(p, ctx);
diff --git a/intel_extension_for_transformers/qbits/qbits.cpp b/intel_extension_for_transformers/qbits/qbits.cpp
index 2da823abb0c..1da0193fa78 100755
--- a/intel_extension_for_transformers/qbits/qbits.cpp
+++ b/intel_extension_for_transformers/qbits/qbits.cpp
@@ -144,6 +144,8 @@ static void set_woq_workspace(const torch::Tensor& workspace) {
   woq::set_woq_workspace(const_cast(&workspace));
 }
 
+static void set_qbits_threads(int64_t thread_num) { dispatcher_utils::qbits_threading::set_threads(thread_num); }
+
 static void bestlaop_gemm(const torch::Tensor& matA, const torch::Tensor& matB, const torch::Tensor& matC,
                           bool matB_trans) {
   TORCH_CHECK(matA.dim() == 2 && matB.dim() == 2 && matC.dim() == 2,
@@ -185,6 +187,7 @@ PYBIND11_MODULE(qbits_py, m) {
   m.def("repack_quantized_weight", &repack_quantized_weight);
   m.def("get_packed_weight_size", &get_packed_weight_size);
   m.def("set_woq_workspace", &set_woq_workspace);
+  m.def("set_qbits_threads", &set_qbits_threads);
   m.def("matmul", &bestlaop_gemm);
   m.def("acquire_packed_weight_info", &acquire_packed_weight_info);
   m.def("dropout_fwd", &qbits_dropout_fwd);
diff --git a/intel_extension_for_transformers/qbits/qbits_ut/test_packq.py b/intel_extension_for_transformers/qbits/qbits_ut/test_packq.py
index 055cef76f11..b87e9933967 100644
--- a/intel_extension_for_transformers/qbits/qbits_ut/test_packq.py
+++ b/intel_extension_for_transformers/qbits/qbits_ut/test_packq.py
@@ -79,7 +79,7 @@ def test(m, k, n, weight_type, scale_type, compute_type, asym, blocksize, dump_t
     if compute_type == "fp32":
         assert (abs(ref_dst - tar_dst).max() < 0.03)
     elif compute_type == "bf16":
-        assert (abs(ref_dst - tar_dst).max() < 8)
+        assert (abs(ref_dst - tar_dst).max() < 9)
     else:
         assert (abs(ref_dst - tar_dst).max() < 10)
     packw_size = qbits.acquire_packed_weight_info(
diff --git a/intel_extension_for_transformers/qbits/qbits_ut/test_weightonly.py b/intel_extension_for_transformers/qbits/qbits_ut/test_weightonly.py
index cd645230f5f..f71fc929016 100644
--- a/intel_extension_for_transformers/qbits/qbits_ut/test_weightonly.py
+++ b/intel_extension_for_transformers/qbits/qbits_ut/test_weightonly.py
@@ -17,15 +17,14 @@
 from ut_utils import *
 
-cmpt_configs = {"int8": {"int8", "bf16", "fp32"}, "int4_clip": {"int8", "fp32", "bf16"}, "int4_fullrange": {
-    "int8", "fp32", "bf16"}, "fp4_e2m1_bnb": {"fp32", "bf16"}, "fp4_e2m1": {"fp32", "bf16"}, "nf4": {"fp32", "bf16"},
-    "fp8_e5m2": {"fp32", "bf16"}, "fp8_e4m3": {"fp32", "bf16"}
-}
+cmpt_configs = {"int8": {"int8", "bf16", "fp32"}, "int4_clip": {"int8", "fp32", "bf16"}, "int3_clip": {"int8", "fp32", "bf16"}, "int2_clip": {"int8", "fp32", "bf16"}, "fp4_e2m1_bnb": {"fp32", "bf16"}, "fp4_e2m1": {"fp32", "bf16"}, "nf4": {"fp32", "bf16"},
+                "fp8_e5m2": {"fp32", "bf16"}, "fp8_e4m3": {"fp32", "bf16"}
+                }
 
-scale_configs = {"int8": {"fp32", "bf16"}, "int4_clip": {"fp32", "bf16"}, "int4_fullrange": {"fp32", "bf16"}, "fp4_e2m1_bnb": {"fp32", "bf16"}, "fp4_e2m1": {"fp32", "bf16"}, "nf4": {"fp32", "bf16"},
+scale_configs = {"int8": {"fp32", "bf16"}, "int4_clip": {"fp32", "bf16"}, "int3_clip": {"fp32", "bf16"}, "int2_clip": {"fp32", "bf16"}, "fp4_e2m1_bnb": {"fp32", "bf16"}, "fp4_e2m1": {"fp32", "bf16"}, "nf4": {"fp32", "bf16"},
                  "fp8_e5m2": {"fp32", "fp8_e8m0"}, "fp8_e4m3": {"fp32", "fp8_e8m0"}}
 
-asym_configs = {"int8", "int4_clip", "int4_fullrange"}
+asym_configs = {"int8", "int4_clip", "int3_clip", "int2_clip"}
 
 
 @capture_args
@@ -34,7 +33,7 @@
 @pytest.mark.parametrize("k", [512])
 @pytest.mark.parametrize("blocksize", [128, -1])
 @pytest.mark.parametrize("compute_type", ["int8", "bf16", "fp32"])
-@pytest.mark.parametrize("weight_type", ["int8", "int4_clip", "int4_fullrange", "nf4", "fp4_e2m1_bnb", "fp4_e2m1", "fp8_e5m2", "fp8_e4m3"])
+@pytest.mark.parametrize("weight_type", ["int8", "int4_clip", "int3_clip", "int2_clip", "nf4", "fp4_e2m1_bnb", "fp4_e2m1", "fp8_e5m2", "fp8_e4m3"])
 @pytest.mark.parametrize("scale_type", ["fp32", "bf16", "fp8_e8m0"])
 @pytest.mark.parametrize("asym", [True, False])
 @pytest.mark.parametrize("transpose", [True, False])
@@ -42,9 +41,11 @@
 @pytest.mark.parametrize("src_dt", ["fp32", "bf16"])
 @pytest.mark.parametrize("dst_dt", ["fp32", "bf16"])
 def test(m, n, k, blocksize, compute_type, weight_type, scale_type, asym, transpose, add_bias, src_dt, dst_dt, dump_tensor_info=True):
+    if compute_type == "int8" and weight_type == "int8" and (not qbits.check_isa_supported("AVX_VNNI")):
+        pytest.skip()
     if compute_type not in cmpt_configs[weight_type] or scale_type not in scale_configs[weight_type]:
         pytest.skip()
-    if asym and (weight_type not in asym_configs or compute_type == "int8" or scale_type != "fp32"):
+    if asym and (weight_type not in asym_configs or (compute_type == "int8" and weight_type == "int8")):
         pytest.skip()
     torch.manual_seed(0)
     ref_activation = torch.rand(m, k, dtype=torch.float)
diff --git a/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py b/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py
index f3905c2fd34..379a227e812 100644
--- a/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py
+++ b/intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py
@@ -213,9 +213,8 @@ def set_weights_bias(
         else:
             g_idx = torch.empty(0, dtype=torch.int32)
         if q_config.bits == 4:
-            int_weight = (int_weight - 8) * 16
-            gptq_scales = gptq_scales / 16
-            gptq_zeros = (gptq_zeros - 8) * 16
+            int_weight = (int_weight - 8) * 16 // 16
+            gptq_zeros = (gptq_zeros - 8) * 16 // 16
         if q_config.sym:
             gptq_zeros = torch.empty(0, dtype=torch.int8)
@@ -344,13 +343,12 @@ def recover_int_weight(g_idx, int_weight):
        if scales_dtype is None:
            assert False, "scales dtype only support fp32."
        scales = qbits.acquire_packed_weight_info(self.weight, 9)
-        if bits == 4:
-            scales = scales * 16
+        zp = qbits.acquire_packed_weight_info(self.weight, 11)[0] != 0
        if zp:
            qzeros = qbits.acquire_packed_weight_info(self.weight, 10)
            if bits == 4:
-                qzeros = qzeros // 16 + 8
+                qzeros = qzeros + 8
            else:
                qzeros = (qzeros.to(torch.int32) + 128).to(torch.uint8)
        else:
diff --git a/intel_extension_for_transformers/transformers/utils/config.py b/intel_extension_for_transformers/transformers/utils/config.py
index 362f47f0cfc..e9c50553482 100644
--- a/intel_extension_for_transformers/transformers/utils/config.py
+++ b/intel_extension_for_transformers/transformers/utils/config.py
@@ -281,7 +281,6 @@ def post_init_cpu(self):
            )
        if self.bits == 4 and self.weight_dtype not in [
-            "int4_fullrange",
            "int4_clip",
            "nf4",
            "fp4_e2m1_bnb",
@@ -300,7 +299,6 @@ def post_init_cpu(self):
        elif self.weight_dtype not in [
            "int8",
-            "int4_fullrange",
            "int4_clip",
            "nf4",
            "fp4_e2m1_bnb",
@@ -310,7 +308,7 @@ def post_init_cpu(self):
        ]:
            raise ValueError(
                f"weight_dtype must be a string in "
-                f"'int8', 'int4_fullrange', 'int4_clip', 'nf4', 'fp4_e2m1_bnb', 'fp4_e2m1', 'fp8_e5m2, fp8_e4m3'"
+                f"'int8', 'int4_clip', 'nf4', 'fp4_e2m1_bnb', 'fp4_e2m1', 'fp8_e5m2, fp8_e4m3'"
            )
        if self.scale_dtype is not None and self.scale_dtype not in [
diff --git a/tests/CI/test_quantization.py b/tests/CI/test_quantization.py
index 883c1a443dc..1086ea000a4 100644
--- a/tests/CI/test_quantization.py
+++ b/tests/CI/test_quantization.py
@@ -408,18 +408,17 @@ def test_quantization_for_llm(self):
        # weight-only
        # RTN
-        woq_config = RtnConfig(bits=4, weight_dtype="int4_fullrange")
+        woq_config = RtnConfig(bits=4)
        woq_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                         quantization_config=woq_config,
                                                         use_neural_speed=False
                                                         )
        woq_model.eval()
        output = woq_model(dummy_input)
-        self.assertTrue(isclose(float(output[0][0][0][0]), 0.16387596726417542, rel_tol=1e-04))
+        self.assertTrue(isclose(float(output[0][0][0][0]), 0.17631684243679047, rel_tol=1e-04))
        # AWQ
        woq_config = AwqConfig(bits=4,
-                               weight_dtype="int4_fullrange",
                               zero_point=False,
                               calib_iters=5,
                               tokenizer=tokenizer
                               )
@@ -431,13 +430,13 @@ def test_quantization_for_llm(self):
                                                         )
        woq_model.eval()
        output = woq_model(dummy_input)
-        self.assertTrue(isclose(float(output[0][0][0][0]), 0.17998121678829193 , rel_tol=1e-04))
+        self.assertTrue(isclose(float(output[0][0][0][0]), 0.18019595742225647 , rel_tol=1e-04))
        # TEQ
-        woq_config = TeqConfig(bits=4, weight_dtype="int4_fullrange",
-                               calib_iters=5,
-                               tokenizer=tokenizer,
-                               )
+        woq_config = TeqConfig(bits=4,
+                               calib_iters=5,
+                               tokenizer=tokenizer,
+                               )
        woq_model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                         quantization_config=woq_config,
                                                         use_neural_speed=False
diff --git a/tests/CI/test_weight_only.py b/tests/CI/test_weight_only.py
index d2d02cd30f1..1952d8e5e69 100644
--- a/tests/CI/test_weight_only.py
+++ b/tests/CI/test_weight_only.py
@@ -86,9 +86,9 @@ def tearDownClass(cls) -> None:
 
    def test_woq_config(self):
        config = RtnConfig(
-            bits=4, weight_dtype="int4_fullrange", group_size=32)
+            bits=4, weight_dtype="int4_clip", group_size=32)
        diff_res = config.to_diff_dict()
-        ref_config = {'weight_dtype': 'int4_fullrange'}
+        ref_config = {'weight_dtype': 'int4_clip'}
        self.assertEqual(diff_res, ref_config)
        print(diff_res)
        print(config.to_dict())
@@ -133,10 +133,10 @@ def test_int8(self):
    def test_int4(self):
        raw_wei = torch.rand(2, 32, dtype=torch.float)
        compress_wei = qbits.quantize_to_packed_weight(
-            raw_wei, True, 32, "fp32", "int4_fullrange", "fp32", False)
+            raw_wei, True, 32, "fp32", "nf4", "fp32", False)
        revert_wei = torch.zeros(2, 32, dtype=torch.float)
        qbits.dequantize_packed_weight(compress_wei, revert_wei, True,
-                                       "fp32", "int4_fullrange", "fp32")
+                                       "fp32", "nf4", "fp32")
        for bias in [True, False]:
            model = M(with_bias=bias)
            with torch.no_grad():
@@ -146,7 +146,7 @@ def test_int4(self):
        with torch.no_grad():
            model.linear.weight = torch.nn.Parameter(raw_wei)
        config = RtnConfig(
-            bits=4, weight_dtype="int4_fullrange", group_size=32)
+            bits=4, weight_dtype="nf4", group_size=32)
        config.post_init_cpu()
        convert_to_quantized_model(model, config)
        output_quant = model(activation)