Reproducing on H100, 4 GPUs #54
Comments
Same concern.
I ended up making the changes below to fix these Flash Attention messages on H100 (a sketch of the edits follows the file reference below):
Change 1: the line `torch_dtype = fsdp_config.get('model_dtype', None)`, so the model is not left in float32.
Change 2: in the model initialization block, the `actor_module = AutoModelForCausalLM.from_pretrained(...)` call.
Change 3: in the `_build_critic_model_optimizer` method, the `critic_module = AutoModelForTokenClassification.from_pretrained(...)` call.
Hello, I encountered the same issue as you. Could you please provide the specific file names corresponding to the functions that need to be modified? Thank you very much.
The code I shared is from the file /TinyZero/verl/workers/fsdp_workers.py; I modified it as explained above.
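A minimal sketch of what those three edits in verl/workers/fsdp_workers.py could look like, assuming the goal is to load the actor and critic in a half-precision dtype so Flash Attention 2 stops warning. The variable names `fsdp_config` and `local_path` stand in for whatever the repo actually uses at those call sites, and defaulting to bfloat16 is an assumption (float16 would also satisfy the warning):

```python
import torch
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification

# Placeholders for illustration only; in fsdp_workers.py these already exist.
fsdp_config = {}
local_path = "Qwen/Qwen2.5-3B"

# Change 1: fall back to bfloat16 instead of None so the model is not built in float32.
model_dtype = fsdp_config.get('model_dtype', None)
torch_dtype = torch.bfloat16 if model_dtype is None else getattr(torch, model_dtype)

# Change 2: pass the dtype (and the Flash Attention backend) when building the actor.
actor_module = AutoModelForCausalLM.from_pretrained(
    local_path,
    torch_dtype=torch_dtype,
    attn_implementation='flash_attention_2',
    trust_remote_code=True,
)

# Change 3: do the same for the critic in _build_critic_model_optimizer.
critic_module = AutoModelForTokenClassification.from_pretrained(
    local_path,
    torch_dtype=torch_dtype,
    attn_implementation='flash_attention_2',
    trust_remote_code=True,
)
```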
Following up on this, I also get the following warning:
We are trying to reproduce this experiment on 4 H100 GPUs (80 GB each), and we have not modified any code so far.
Training has been running for 4+ days now, and we see the following message in the logs:
(WorkerDict pid=1874660) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dtype in Qwen2ForTokenClassification is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
(WorkerDict pid=1874660) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Is this a cause for concern?
The base model we are using is Qwen/Qwen2.5-3B.
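For reference, the warning itself names the two fixes: run under autocast, or load the model in a half-precision dtype. A standalone illustration of the second option for this base model, outside the repo's code and purely as a sketch, might look like:

```python
import torch
from transformers import AutoModelForCausalLM

# Load the base model from this issue in bfloat16 with Flash Attention 2,
# which is what the first warning is asking for.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-3B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda")  # move to GPU, as the second warning suggests

# Alternative kept by the warning: leave the model in float32 and wrap the
# forward pass in autocast instead, e.g.
# with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#     outputs = model(input_ids)
```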