
[bug] build is verrrrrrrrrrrrrrrrrrrry slow #945

Open
wongdi opened this issue May 11, 2024 · 27 comments


wongdi commented May 11, 2024

I compiled with the latest source code, and the compilation was so slow that I had to fall back to commit 2.5.8. The previous version took me about 3-5 minutes to complete (70% CPU and 230 GB memory usage), but with this version the CPU is barely working at all. What happened?

"MAX_JOBS" doesn't get the CPU excited either.

CentOS: 7.9.2009
Python: 3.10.14
GCC: 12.3.0
cmake: 3.27.9
nvcc: 12.2.140
Installing from the prebuilt wheel is OK.
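
For reference, MAX_JOBS is normally set on the install command itself, along the lines of the README example (the value 4 here is just an illustration):

MAX_JOBS=4 pip install flash-attn --no-build-isolation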


CHDev93 commented May 22, 2024

Did you install ninja? That sped things up for me considerably
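
If not, a minimal sequence to try is roughly the following (the flash-attn build picks ninja up automatically once it is importable):

pip install ninja
pip install flash-attn --no-build-isolation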


wongdi commented May 22, 2024

> Did you install ninja? That sped things up for me considerably

I have ninja 1.11.1.1. I don't think that's the cause of the problem, because compilation speed was fine at the previous commit.

@HuBocheng

Following your suggestion, I attempted to install version 2.5.7 of flash-attention. However, the build process is still very slow, with CPU usage remaining below 1%. What could be causing this? 😭

pip install flash-attn==2.5.7 --no-build-isolation
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple, https://pypi.ngc.nvidia.com
Collecting flash-attn==2.5.7
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/21/cb/33a1f833ac4742c8adba063715bf769831f96d99dbbbb4be1b197b637872/flash_attn-2.5.7.tar.gz (2.5 MB)
     ━━━━━━━━━━━━━━━━━━ 2.5/2.5 MB 54.0 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Requirement already satisfied: torch in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from flash-attn==2.5.7) (2.3.0+cu118)
Collecting einops (from flash-attn==2.5.7)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/44/5a/f0b9ad6c0a9017e62d4735daaeb11ba3b6c009d69a26141b258cd37b5588/einops-0.8.0-py3-none-any.whl (43 kB)
     ━━━━━━━━━━━━━━━━ 43.2/43.2 kB 82.2 MB/s eta 0:00:00
Requirement already satisfied: packaging in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from flash-attn==2.5.7) (24.0)
Requirement already satisfied: ninja in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from flash-attn==2.5.7) (1.11.1.1)
Requirement already satisfied: filelock in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (3.14.0)
Requirement already satisfied: typing-extensions>=4.8.0 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (4.12.0)
Requirement already satisfied: sympy in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (1.12)
Requirement already satisfied: networkx in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (3.2.1)
Requirement already satisfied: jinja2 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (3.1.3)
Requirement already satisfied: fsspec in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (2024.5.0)
Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.8.89 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (11.8.89)
Requirement already satisfied: nvidia-cuda-runtime-cu11==11.8.89 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (11.8.89)
Requirement already satisfied: nvidia-cuda-cupti-cu11==11.8.87 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (11.8.87)
Requirement already satisfied: nvidia-cudnn-cu11==8.7.0.84 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (8.7.0.84)
Requirement already satisfied: nvidia-cublas-cu11==11.11.3.6 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (11.11.3.6)
Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (10.9.0.58)
Requirement already satisfied: nvidia-curand-cu11==10.3.0.86 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (10.3.0.86)
Requirement already satisfied: nvidia-cusolver-cu11==11.4.1.48 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (11.4.1.48)
Requirement already satisfied: nvidia-cusparse-cu11==11.7.5.86 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (11.7.5.86)
Requirement already satisfied: nvidia-nccl-cu11==2.20.5 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (2.20.5)
Requirement already satisfied: nvidia-nvtx-cu11==11.8.86 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (11.8.86)
Requirement already satisfied: triton==2.3.0 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from torch->flash-attn==2.5.7) (2.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from jinja2->torch->flash-attn==2.5.7) (2.1.5)
Requirement already satisfied: mpmath>=0.19 in /zhuzixuan/conda/envs/bunny/lib/python3.10/site-packages (from sympy->torch->flash-attn==2.5.7) (1.3.0)
Building wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... -

CPU:
top - 08:39:29 up 61 days, 2:08, 1 user, load avera
Tasks: 26 total, 1 running, 25 sleeping, 0 stopp
%Cpu(s): 0.5 us, 0.1 sy, 0.0 ni, 99.4 id, 0.0 wa,
top - 08:45:03 up 61 days, 2:14, 1 user, load avera
Tasks: 24 total, 1 running, 23 sleeping, 0 stopp
%Cpu(s): 1.1 us, 0.2 sy, 0.0 ni, 98.7 id, 0.0 wa,
MiB Mem : 515612.6 total, 491855.4 free, 19313.1 used
MiB Swap: 0.0 total, 0.0 free, 0.0 used

PID USER      PR  NI    VIRT    RES    SHR S 
386 root      20   0 1364196 498040  44128 S 
668 root      20   0 1045352  74208  38872 S 
  1 root      20   0    8408   1756   1472 S 
  7 root      20   0   12272   5592   4560 S 
 21 root      20   0   18012   4776   3556 S 
 33 root      20   0   10212   1828   1544 S 
382 root      20   0    9700   4360   4012 S 
769 root      20   0  853504  53268  38888 S 

14750 root 20 0 19496 10636 8960 S
14761 root 20 0 9896 4684 4292 S


ComDec commented Jun 20, 2024

Try cloning the repo and running python setup.py install instead. That works most of the time. Check your top panel to see whether there are multiple cicc processes. @CHDev93 @wongdi @HuBocheng @no-execution @SiyangJ
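
A minimal sketch of that workflow (assuming the upstream Dao-AILab repo; the MAX_JOBS value is just an example):

# build from source; ninja is used automatically if installed
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
MAX_JOBS=8 python setup.py install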


YudiZh commented Sep 11, 2024

Have you solved this problem yet? I have encountered the same problem.


SiyangJ commented Sep 11, 2024

My XPS 15 windows is taking hours to build...

@WonderRico

I had the same issue: building flash-attn was slow and the CPU load was very low. Only 2 instances of the "NVIDIA cicc" process were running at a time.
Running "pip install ninja" seems to help, as suggested before. Now I have 10 instances of NVIDIA cicc running, and my CPU is at 37% (Ryzen 9 7950X3D). I guess it will be about 5 times quicker now.
(Windows 11, by the way.)

@zhangyuqi-1

The same issue. I couldn't build it all night. Following the suggestion to revert to commit 2.5.8, the CPU was fully utilized and the build succeeded.


xFranv8 commented Oct 8, 2024

Same issue here (32 GB RAM and a 13th-gen i7).


luhuaei commented Oct 16, 2024

Same issue here (Jetson AGX Orin, pip install flash-attn==2.5.8 --no-build-isolation --verbose --no-cache-dir).

@SaeedNajafi

Super slow build.


yshuolu commented Oct 25, 2024

Same issue

@no-execution

Same.
Any solution?


jarredou commented Nov 1, 2024

same


JerryYC commented Nov 2, 2024

same

@tingwl0122

same

@zhuofuAMZ

same


lixali commented Nov 9, 2024

same


hail75 commented Nov 11, 2024

Same

@LittleHeroZZZX

Same for me, only 4 of 20 cores are occupied.

@back2yes

Win10, with ninja installed from conda. The CPU load reaches 100% across 44 cores, with ~100 GB RAM usage. The build took about 1 hour.


woojh3690 commented Nov 25, 2024

Same.
In Task Manager, two NVIDIA cicc processes are working, and the load is only on two CPU cores.

  • AMD 7800X3D
  • Win 11
  • ninja: 1.11.1.git.kitware.jobserver-1
  • Python: 3.11.10

@AllenDou

Actually, there is no need to compile every file listed in setup.py:

sources=[
                "csrc/flash_attn/flash_api.cpp",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
                ...
]

I keep only the necessary files, such as:

sources=[
                "csrc/flash_attn/flash_api.cpp",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
]

My GPU is an L20, so I don't need to compile for sm90, and I remove the code below:

    # if CUDA_HOME is not None:
    #     if bare_metal_version >= Version("11.8"):
    #         cc_flag.append("-gencode")
    #         cc_flag.append("arch=compute_90,code=sm_90")

Apply this patch as well, or it will cause an undefined symbol error:

diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h
index a57702f..1636d15 100644
--- a/csrc/flash_attn/src/static_switch.h
+++ b/csrc/flash_attn/src/static_switch.h
@@ -17,13 +17,8 @@
 
 #define BOOL_SWITCH(COND, CONST_NAME, ...)      \
   [&] {                                         \
-    if (COND) {                                 \
-      constexpr static bool CONST_NAME = true;  \
-      return __VA_ARGS__();                     \
-    } else {                                    \
       constexpr static bool CONST_NAME = false; \
       return __VA_ARGS__();                     \
-    }                                           \
   }()
 
 #ifdef FLASHATTENTION_DISABLE_DROPOUT
@@ -78,37 +73,14 @@
 
 #define FP16_SWITCH(COND, ...)               \
   [&] {                                      \
-    if (COND) {                              \
       using elem_type = cutlass::half_t;     \
       return __VA_ARGS__();                  \
-    } else {                                 \
-      using elem_type = cutlass::bfloat16_t; \
-      return __VA_ARGS__();                  \
-    }                                        \
   }()
 
 #define HEADDIM_SWITCH(HEADDIM, ...)   \
   [&] {                                    \
-    if (HEADDIM <= 32) {                   \
-      constexpr static int kHeadDim = 32;  \
-      return __VA_ARGS__();                \
-    } else if (HEADDIM <= 64) {            \
+    if (HEADDIM <= 64) {            \
       constexpr static int kHeadDim = 64;  \
       return __VA_ARGS__();                \
-    } else if (HEADDIM <= 96) {            \
-      constexpr static int kHeadDim = 96;  \
-      return __VA_ARGS__();                \
-    } else if (HEADDIM <= 128) {           \
-      constexpr static int kHeadDim = 128; \
-      return __VA_ARGS__();                \
-    } else if (HEADDIM <= 160) {           \
-      constexpr static int kHeadDim = 160; \
-      return __VA_ARGS__();                \
-    } else if (HEADDIM <= 192) {           \
-      constexpr static int kHeadDim = 192; \
-      return __VA_ARGS__();                \
-    } else if (HEADDIM <= 256) {           \
-      constexpr static int kHeadDim = 256; \
-      return __VA_ARGS__();                \
     }                                      \
   }()

Compile the source code with:

MAX_JOBS=8 NVCC_THREADS=1 time pip install -e . -v

The whole build takes less than a minute.

and run

pytest tests/test_flash_attn.py -k test_flash_attn_kvcache -s

This approach only works for research, not production.

@AllenDou

@tridao Could the flash-attn team support a dedicated compilation feature? For instance, we could specify HDIM=64 and DTYPE=float16 during installation (pip install -e . -v) to build a version only for head dimension 64 and torch.float16. This would greatly facilitate the development of flash-attn.
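
A hypothetical sketch of how such an install could look (HDIM and DTYPE are the variable names proposed above, not existing flags):

HDIM=64 DTYPE=float16 pip install -e . -v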


tridao commented Dec 11, 2024

Sure, happy to review a PR!

@AllenDou

@tridao I have created a PR, #1384. Could you please take a look at it when you have spare time? :)

@bobma-resideo

ninja + MAX_JOBS=256 = Done!
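
Assuming ninja is already installed (pip install ninja), that combination is roughly the following; note that a MAX_JOBS this high needs a lot of RAM, since each nvcc job can use several GB:

MAX_JOBS=256 pip install flash-attn --no-build-isolation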
