Fix and update vllm-based GRPO Trainer implementation #85

Open
wants to merge 2 commits into main

Conversation

FUJIsyu0515

To fix the problems in the original vllm-based GRPO trainer implementation, this PR provides a new trainer, Qwen2VLGRPOVLLMTrainerModified (located in src/open-r1-multimodal/src/open_r1/trainer/vllm_grpo_trainer_modified.py).

It no longer uses RepeatRandomSampler, which avoids the issue of the number of training steps doubling. Instead, it performs the multiple samplings and loss calculations for each prompt within a single original batch, keeping the logic consistent with Qwen2VLGRPOTrainer (see the sketch after the list below). It also no longer requires num_generations to match the number of GPUs.

  • The new Trainer has replaced the original Qwen2VLGRPOVLLMTrainer in src/open-r1-multimodal/src/open_r1/grpo.py.
  • The vllm sampling logic has been corrected.
  • Multi-node (multi-machine) support is not yet implemented (TODO).
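
For illustration, the "multiple samplings and loss calculations for each prompt" boil down to grouping the rewards of each prompt's num_generations completions and normalizing within that group. Below is a minimal, self-contained sketch of this GRPO-style step; the helper name, the epsilon, and the flat reward layout are assumptions for illustration, not the exact code in Qwen2VLGRPOTrainer:

    import torch

    def group_relative_advantages(rewards: torch.Tensor, num_generations: int) -> torch.Tensor:
        # rewards is flat with layout [prompt1_gen1, ..., prompt1_genN, prompt2_gen1, ...],
        # i.e. the num_generations completions of each prompt are stored contiguously.
        grouped = rewards.view(-1, num_generations)    # (num_prompts, num_generations)
        mean = grouped.mean(dim=1, keepdim=True)
        std = grouped.std(dim=1, keepdim=True)
        advantages = (grouped - mean) / (std + 1e-4)   # small epsilon avoids division by zero
        return advantages.view(-1)                     # back to the flat layout

    # Example: 2 prompts with 3 generations each
    advantages = group_relative_advantages(torch.tensor([1.0, 0.0, 1.0, 0.5, 0.5, 1.0]), 3)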

# Repeat each (prompt, image) pair num_generations times before calling vLLM
all_multimodal_inputs = []
for prompt, image in zip(all_prompts_text, all_images):
    for _ in range(self.num_generations):
        all_multimodal_inputs.append(
            {"prompt": prompt, "multi_modal_data": {"image": image}}
        )
Collaborator

Could we use the n sampling parameter in vLLM to avoid this for loop?
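
For reference, vLLM's SamplingParams takes an n argument, so a single generate() call returns n completions per prompt and the inputs do not need to be repeated beforehand. A minimal illustration (the concrete values are placeholders):

    from vllm import SamplingParams

    # Ask vLLM for n completions per prompt in one call,
    # instead of repeating each prompt n times in the input list.
    sampling_params = SamplingParams(n=8, temperature=1.0, max_tokens=256)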


@TobiasLee
Collaborator

Here is my idea for avoiding the for loop that repeats each prompt n_gen times (which might be slow compared to a direct repeat sampler when num_generations becomes large):

        if self.args.use_vllm:
            # previous code remains the same 

            # Generate completions using vLLM: gather all prompts and use them in a single call
            all_prompts_text = gather_object(prompts_text)
            all_images = gather_object(images)

            # Prepare all inputs: one entry per (prompt, image) pair across the global batch
            all_multimodal_inputs = [{"prompt": prompt, "multi_modal_data": {"image": image}} for prompt, image in zip(all_prompts_text, all_images)]

            # Create sampling params with num_generations
            if self.accelerator.is_main_process:
                # Deep-copy so the original sampling params are not mutated (needs `import copy` at module level)
                sampling_params = copy.deepcopy(self.sampling_params)
                sampling_params.n = self.num_generations
            else:
                sampling_params = None

            # Single generate call with all prompts
            if self.accelerator.is_main_process:
                outputs = self.llm.generate(
                    all_multimodal_inputs,
                    sampling_params=sampling_params,
                    use_tqdm=False,
                )
                # Flatten outputs: [prompt1_gen1, prompt1_gen2..., prompt2_gen1, prompt2_gen2...]
                completion_ids = [out.token_ids for completion in outputs for out in completion.outputs]
            else:
                completion_ids = [None] * len(all_multimodal_inputs) * self.num_generations

            # [Keep the broadcasting and slicing logic unchanged...]
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts) * self.num_generations,
                (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
            )
            completion_ids = completion_ids[process_slice]

            # [Keep the padding and concatenation logic unchanged...]

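To sanity-check the ordering that the process slice above relies on, here is a standalone toy example with made-up sizes; it only mimics the flattened [prompt1_gen1, prompt1_gen2, ...] layout and the slice arithmetic, not the actual trainer:

    # Toy check: 2 processes, 3 prompts per process, num_generations = 2
    num_processes, prompts_per_process, num_generations = 2, 3, 2

    # Flattened order produced by the single generate call:
    # [prompt0_gen0, prompt0_gen1, prompt1_gen0, prompt1_gen1, ...]
    completion_ids = [
        f"prompt{p}_gen{g}"
        for p in range(num_processes * prompts_per_process)
        for g in range(num_generations)
    ]

    for rank in range(num_processes):
        process_slice = slice(
            rank * prompts_per_process * num_generations,
            (rank + 1) * prompts_per_process * num_generations,
        )
        print(rank, completion_ids[process_slice])
    # rank 0 -> completions for prompts 0-2, rank 1 -> prompts 3-5,
    # matching the per-process order in which gather_object collects the prompts.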