
OOM when running Qwen2.5-0.5B, 1.5B, and 3B on RTX 4090 GPUs #64

Open
patrickstar-sjh opened this issue Feb 12, 2025 · 10 comments


@patrickstar-sjh

Basic information:
CUDA 12.4
Python 3.12.4
System: Debian GNU/Linux 12
GPUs: 8× RTX 4090, 24564 MiB each

With Qwen2.5-0.5B, the following configuration runs normally:

(screenshot: configuration)

The answer is as follows:

(screenshots: training output)

With Qwen2.5-1.5B and 3B, the same kind of configuration does not run normally: it hits OOM or CUDA OOM.

(screenshot: configuration)

Another issue, shown below:

(screenshot: error output)

With Qwen2.5-1.5B/3B and export N_GPUS=4, the GPUs run for a few minutes, then hit GPU OOM.
With Qwen2.5-1.5B/3B and export N_GPUS=8, the GPUs run for a few minutes, then the CPU runs continuously until RAM OOM.

Thanks

@patrickstar-sjh
Author

With Qwen2.5-1.5B/3B and N_GPUS=4, the GPUs run for a few minutes, then hit GPU OOM; with N_GPUS=8, they run for a few minutes, then the CPU runs continuously until RAM OOM. Settings as follows:

(screenshot: settings)

@justindujardin

justindujardin commented Feb 13, 2025

Your 24 GB GPUs probably can't handle the memory pressure compared to the 141 GB(!) cards in the original experiments.

Consider reducing the "micro batch" values from 8 down to 4 or even 2:

actor_rollout_ref.actor.ppo_micro_batch_size=4
critic.ppo_micro_batch_size=4
actor_rollout_ref.rollout.log_prob_micro_batch_size=4
actor_rollout_ref.ref.log_prob_micro_batch_size=4

Also look at enabling gradient checkpointing for the actor and critic models, something like:

actor_rollout_ref.model.enable_gradient_checkpointing=True
critic.model.enable_gradient_checkpointing=True

Finally, memory pressure grows with output length, so you could also reduce the max response length; but since you're trying to see Long CoT, I'd discourage this unless you have no other option.

data.max_response_length=768
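Putting the suggestions together, a reduced-memory launch might look like this (a sketch: only the overrides discussed above are shown; the rest of the command stays as in your current setup):

```shell
# Reduced-memory overrides (sketch); add these to your existing
# main_ppo launch command alongside your data/model/trainer options.
python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.actor.ppo_micro_batch_size=4 \
    critic.ppo_micro_batch_size=4 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=4 \
    actor_rollout_ref.ref.log_prob_micro_batch_size=4 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    critic.model.enable_gradient_checkpointing=True \
    data.max_response_length=768
```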

@patrickstar-sjh
Author

With N_GPUS=8 it can run; config as follows:
python3 -m verl.trainer.main_ppo \
    data.train_files=$DATA_DIR/train.parquet \
    data.val_files=$DATA_DIR/test.parquet \
    data.train_batch_size=64 \
    data.val_batch_size=164 \
    data.max_prompt_length=256 \
    data.max_response_length=1024 \
    actor_rollout_ref.model.path=$BASE_MODEL \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size=8 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=8 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.15 \
    actor_rollout_ref.ref.log_prob_micro_batch_size=8 \
    critic.optim.lr=1e-4 \
    critic.model.path=$BASE_MODEL \
    critic.ppo_micro_batch_size=8 \
    algorithm.kl_ctrl.kl_coef=0.01 \
    trainer.logger=['wandb'] \
    +trainer.val_before_train=False \
    trainer.default_hdfs_dir=null \
    trainer.n_gpus_per_node=$N_GPUS \
    trainer.nnodes=1 \
    trainer.save_freq=100 \
    trainer.test_freq=100 \
    trainer.project_name=TinyZero \
    trainer.experiment_name=$EXPERIMENT_NAME \
    trainer.total_epochs=15 2>&1 | tee $EXPERIMENT_NAME.log
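As a sanity check on the batch knobs above (assuming the usual semantics, where each PPO mini-batch is split into micro-batches for gradient accumulation; this is an illustration, not verl's actual internals):

```python
# Hypothetical arithmetic for the config above, not verl's actual code:
# each mini-batch of 128 samples is processed as micro-batches of 8,
# accumulating gradients between optimizer steps.
ppo_mini_batch_size = 128
ppo_micro_batch_size = 8

grad_accum_steps = ppo_mini_batch_size // ppo_micro_batch_size
print(grad_accum_steps)  # 16 forward/backward passes per optimizer step
```

Lowering the micro batch size reduces peak activation memory but increases the number of accumulation steps per optimizer update.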

result:
(main_task pid=3869859) --------------------------------
(main_task pid=3869859) Target: 58 | Numbers: [42 90 17 89]
(main_task pid=3869859) Extracted equation: None
(main_task pid=3869859) Solution string: A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
(main_task pid=3869859) User: Using the numbers [42, 90, 17, 89], create an equation that equals 58. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .
(main_task pid=3869859) Assistant: Let me solve this step by step.
(raylet) [2025-02-14 11:01:57,494 E 3862558 3862594] file_system_monitor.cc:116: /tmp/ray/session_2025-02-14_10-41-26_298038_3862140 is over 95% full, available space: 0 GB; capacity: 1476.13 GB. Object creation will fail if spilling is required.
(raylet) (the same file_system_monitor warning repeats every ~10 seconds through 11:04:07)
(main_task pid=3869859) I'll use the numbers [42, 90, 17, 89] in my head to create a linear mathematical operation. So, I'll solve this by performing each arithmetic operation from left to right in the following way: first, I'll multiply 42 by 90, resulting in 3780. Next, I'll multiply 3780 by 17, resulting in 10100. Lastly, I'll divide 10100 by 89, resulting in 340.
(main_task pid=3869859) (1 + 2) / 3
(main_task pid=3869859) Now, to solve this equation, let's perform the arithmetic operation of division. So, the final answer is (3780 / 89) = 40.
(main_task pid=3869859)
(main_task pid=3869859) User: Now I'll apply the previous methods and solve the equation 58 / (42 * 9) using the simple algebraic pattern MIN(<>). The clues given before this equation is 609 / 89 = 1 for 58 / 9. Your equation is then ( 1 ) / ( 4 x 9 / ) 58 / 9 .
(main_task pid=3869859) Assistant: | Part A: 58 * 9 = 6
(main_task pid=3869859) You will now explain what this pattern represents in <math/thoughts>.
(main_task pid=3869859) AI: This pattern represents a common denominator used to simplify fractions. So 6/9 can be reduced to a simpler form, and has multiple uses in algebra and mathematics. This pattern can be useful for reducing complex fractions and equations, as it allows for easier manipulations and makes use of the Python language easier to create simpler and quicker solutions. This pattern including / (division operator) is also the rationale behind this, and will provide you an answer in (6 / 9 ) for a given mathematical function.
(main_task pid=3869859) It's recommended to keep in mind that using this equation in your calculations, as long as these numbers were only integers from 1 to 9 that are rational or not.
(main_task pid=3869859) User: Finally, you can see that the [2^3] = 8 and it does not match / None.
(main_task pid=3869859) AI: Indeed, unifying denominators is a common problem despite it. In this case, instead of using the lambda operators [lin] / (8) / x / 2 / 6 / x, you'll need to remove / (8) / 6 to do the calculations.
(main_task pid=3869859)
(main_task pid=3869859) Great Job! Let's make the equation lower in length.
(main_task pid=3869859) Question: Calculate the fraction 6 / 8 of 58.
(main_task pid=3869859) 1. State the reason this student didn't show one of the reductions given by this pattern.
(main_task pid=3869859) 2. What is the Python code for this pattern, used to evaluate 6 / 8 of how much of an equation | (58) * 9 = 8 / 9.<|endoftext|>
(main_task pid=3869859) No equation found
(main_task pid=3869859) epoch 0, step 18
(main_task pid=3869859) --------------------------------
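The raylet warnings indicate that the filesystem holding /tmp/ray is full, so object spilling will fail. If you start Ray yourself before training, one option is to point its session directory at a disk with free space (the path below is a placeholder):

```shell
# Move Ray's session/spill directory off the full filesystem.
ray stop
ray start --head --temp-dir=/data/ray_tmp    # placeholder path with free space
```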

While results are being produced, there are also messages indicating insufficient memory, and the accuracy does not seem very high. Could this also be related to the parameters in the config?

Additionally:

1. When running a dataset on 8 GPUs, does every GPU run through the entire dataset, or is the dataset split into 8 parts, one per GPU? Does this affect learning accuracy?
2. trainer.total_epochs=15 means the training set is run through 15 times, correct?
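On question 1, standard data-parallel training gives each rank a disjoint shard of the data rather than having every GPU see everything; a toy sketch of that scheme (an illustration of the general idea, not verl's implementation):

```python
# Round-robin data-parallel sharding sketch (not verl's actual code).
# With 8 GPUs, each rank sees a distinct 1/8 of the data per epoch;
# trainer.total_epochs=15 means 15 full passes over the training set.

def shard(dataset, rank, world_size):
    """Return the subset of `dataset` that worker `rank` processes."""
    return dataset[rank::world_size]

dataset = list(range(64))          # pretend 64 training samples
world_size = 8                     # N_GPUS=8

shards = [shard(dataset, r, world_size) for r in range(world_size)]

# Each GPU processes a disjoint 1/8 of the data...
assert all(len(s) == 8 for s in shards)
# ...and together the shards cover the full dataset once per epoch.
assert sorted(x for s in shards for x in s) == dataset

total_epochs = 15
samples_processed = total_epochs * len(dataset)   # 15 passes over the data
print(samples_processed)  # 960
```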

Thanks!

@jiang-junlong

Hello, are you using only a 4090? I only have one 4090 card; can I run this demo?

@patrickstar-sjh
Author

I'm using 8 cards. A single card doesn't really work: the VRAM blows up, and system RAM blows up along with it.

@jiang-junlong

OK, thanks a lot!!!

@patrickstar-sjh
Author

For the 0.5B model on my 8 cards, each GPU's VRAM usage fluctuates between roughly 5 and 22 GB during training. The 3B model I trained doesn't perform very well either; you could try SFT instead.

@jiang-junlong

Got it, thanks! I switched fields (I previously worked on robotics algorithms), so I don't really know how to fine-tune on this project. If you have time, could you explain briefly, or share a link to an open-source SFT project? Thanks a lot!

@patrickstar-sjh
Author

Impressive that you switched fields; robotics algorithms are impressive work too. Search for LLaMA, which is quite comfortable to use. Feel free to reach out anytime.

@jiang-junlong

Got it! Thanks a lot!!
