
Reproducing on H100, 4 GPUS #54

Open
karrtikiyer-tw opened this issue Feb 9, 2025 · 5 comments

@karrtikiyer-tw commented Feb 9, 2025

We are trying to reproduce this experiment on 4 H100 GPUs (80 GB each), and we have not modified any code so far.
The training has now been running for 4+ days. We see the following message in the logs:

```
(WorkerDict pid=1874660) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dtype in Qwen2ForTokenClassification is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
(WorkerDict pid=1874660) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
```

Is this a cause for concern?

The base model we are using is Qwen/Qwen2.5-3B.
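
For context, the load pattern the warning recommends, bf16 plus Flash Attention 2, looks roughly like this (a minimal standalone sketch using the model named above, not TinyZero's actual loading code):

```python
import torch
from transformers import AutoModelForCausalLM

# Minimal sketch: load the base model directly in bf16 so Flash Attention 2.0
# does not warn about fp32 weights, then move it to GPU.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-3B",
    torch_dtype=torch.bfloat16,              # FA2 supports fp16/bf16 only
    attn_implementation="flash_attention_2",
)
model.to("cuda")
```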

@VincentXWD

Same concern.

@karrtikiyer-tw
Author

I ended up making the changes below to fix these Flash Attention messages on H100:

Change 1: In the `_build_model_optimizer` method

Replace this block:

```python
torch_dtype = fsdp_config.get('model_dtype', None)
if torch_dtype is None:
    torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
else:
    torch_dtype = PrecisionType.to_dtype(torch_dtype)
```

With:

```python
torch_dtype = fsdp_config.get('model_dtype', None)
if torch_dtype is None:
    torch_dtype = torch.bfloat16  # Always use bfloat16 for Flash Attention 2.0
else:
    torch_dtype = PrecisionType.to_dtype(torch_dtype)
```

Change 2: In the model initialization block

Replace:

```python
actor_module = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=local_path,
    torch_dtype=torch_dtype,
    config=actor_model_config,
    attn_implementation='flash_attention_2',
    trust_remote_code=trust_remote_code
)
actor_module.to(torch_dtype)  # Remove this line
```

With:

```python
actor_module = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=local_path,
    torch_dtype=torch_dtype,
    config=actor_model_config,
    attn_implementation='flash_attention_2',
    device_map="auto",  # Add this line
    trust_remote_code=trust_remote_code
)
```

Change 3: In the `_build_critic_model_optimizer` method

Replace:

```python
critic_module = AutoModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=local_path,
    config=critic_model_config,
    torch_dtype=torch_dtype,
    attn_implementation='flash_attention_2',
    trust_remote_code=trust_remote_code
)
critic_module.to(torch_dtype)  # Remove this line
```

With:

```python
critic_module = AutoModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=local_path,
    config=critic_model_config,
    torch_dtype=torch_dtype,
    attn_implementation='flash_attention_2',
    device_map="auto",  # Add this line
    trust_remote_code=trust_remote_code
)
```
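
A quick way to sanity-check Changes 1 and 2 is to confirm that every parameter really ends up in bf16 right after loading, before FSDP wrapping (an illustrative check, not part of the repo; `actor_module` is the variable from the snippet above):

```python
import torch

# Illustrative sanity check: if any parameter is still fp32 here,
# Flash Attention 2.0 will emit the dtype warning again.
not_bf16 = [name for name, p in actor_module.named_parameters()
            if p.dtype != torch.bfloat16]
assert not not_bf16, f"Parameters not in bf16: {not_bf16[:5]}"
```

Also, since the original code reads `fsdp_config.get('model_dtype', None)`, setting `model_dtype` in the FSDP config (and letting `PrecisionType.to_dtype` convert it) may achieve the same effect as hard-coding Change 1, though I have not verified that path end to end.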

@AkaliKong

Hello, I encountered the same issue as you. Could you please provide the specific file names corresponding to the functions that need to be modified? Thank you very much.

> I ended up making the changes below to fix these Flash Attention messages on H100: Change 1: In the `_build_model_optimizer` method […]

@karrtikiyer-tw
Author

> Hello, I encountered the same issue as you. Could you please provide the specific file names corresponding to the functions that need to be modified? Thank you very much.


The code I shared is from the file `/TinyZero/verl/workers/fsdp_workers.py`; I modified that code as explained above.

@richardzhuang0412

Following up on this, I also get the following warning:
```
Some weights of Qwen2ForTokenClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-3B and are newly initialized: ['score.bias', 'score.weight']
```
Is this expected and normal?
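
That warning is generally expected when a causal-LM checkpoint is loaded through a token-classification head: the critic's scalar `score` head does not exist in the Qwen/Qwen2.5-3B checkpoint, so transformers initializes it randomly and it is trained from scratch. A minimal sketch of what triggers it (illustrative only, not the repo's critic-building code):

```python
from transformers import AutoModelForTokenClassification

# Loading a causal-LM checkpoint with a token-classification head:
# the per-token value head ("score") is not in the checkpoint, so it is
# newly initialized and transformers prints the warning quoted above.
critic = AutoModelForTokenClassification.from_pretrained(
    "Qwen/Qwen2.5-3B",
    num_labels=1,  # a single scalar value per token, as a PPO critic needs
)
```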
