Commit b68b27e

BearBiscuit05 committed Feb 25, 2025
1 parent: a698038
Showing 2 changed files with 3 additions and 0 deletions.
verl/workers/fsdp_workers.py: 2 additions & 0 deletions

@@ -486,6 +486,7 @@ def generate_sequences(self, prompts: DataProto):
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def chat(self, prompts: DataProto):
+        # copied from the generate_sequences function
         prompts = prompts.to('cuda')
 
         assert self._is_rollout
@@ -513,6 +514,7 @@ def chat(self, prompts: DataProto):
         log_gpu_memory_usage('After entering rollout sharding manager', logger=logger)
 
         prompts = self.rollout_sharding_manager.preprocess_data(prompts)
+        # different call than generate_sequences, but the same return format
         output = self.rollout.chat(prompts=prompts)
 
         log_gpu_memory_usage('After rollout generation', logger=logger)
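For context, the hunks above sit inside a worker-level chat entry point that mirrors generate_sequences: move the batch to the GPU, enter the rollout sharding manager, preprocess the prompts, and call self.rollout.chat, which returns data in the same format as generate_sequences. The sketch below reproduces only that call order with plain-Python stand-ins; the stub classes and the context-manager use of the sharding manager are assumptions for illustration, not verl's actual API.

# Minimal, self-contained sketch of the worker-side flow shown in the hunks above.
# The stub classes are illustrative stand-ins; only the call order comes from the diff.
import logging

logger = logging.getLogger(__name__)

class StubShardingManager:
    """Stand-in for the rollout sharding manager (context-manager use is assumed)."""
    def __enter__(self):
        logger.info("After entering rollout sharding manager")
        return self

    def __exit__(self, *exc):
        return False

    def preprocess_data(self, prompts):
        # The real manager reshards the batch across ranks; here it is a no-op.
        return prompts

class StubRollout:
    """Stand-in for the vLLM rollout; chat() keeps the generate_sequences output shape."""
    def chat(self, prompts):
        return {"responses": [p + " <generated>" for p in prompts]}

def worker_chat(prompts, sharding_manager=None, rollout=None):
    # Mirrors fsdp_workers.py: preprocess inside the sharding manager, then
    # delegate to rollout.chat(), whose return matches generate_sequences.
    sharding_manager = sharding_manager or StubShardingManager()
    rollout = rollout or StubRollout()
    with sharding_manager:
        prompts = sharding_manager.preprocess_data(prompts)
        output = rollout.chat(prompts=prompts)
        logger.info("After rollout generation")
    return output

if __name__ == "__main__":
    print(worker_chat(["hello"]))  # {'responses': ['hello <generated>']}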
verl/workers/rollout/vllm_rollout/vllm_rollout.py: 1 addition & 0 deletions

@@ -266,5 +266,6 @@ def chat(self, prompts: DataProto, **kwargs) -> DataProto:
         }
         prompt_data.meta_info.update(meta_info)
         prompt_data.to('cuda')
 
+        # delegate to generate_sequences to build the return value
         return self.generate_sequences(prompt_data, **kwargs)
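On the rollout side, the vllm_rollout.py hunk shows the same delegation from the other direction: chat() assembles a prompt batch, merges its meta_info, moves it to the GPU, and then hands off to generate_sequences so both entry points return the same structure. A self-contained sketch of that wrapper pattern follows; ChatBatch and the meta_info key are illustrative stand-ins for DataProto and its metadata, not verl's real types.

# Sketch of the "chat() wraps generate_sequences()" pattern from vllm_rollout.py.
# ChatBatch stands in for DataProto; only the delegation itself comes from the diff.
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class ChatBatch:
    """Stand-in for DataProto: prompts plus free-form metadata."""
    prompts: List[str]
    meta_info: Dict[str, Any] = field(default_factory=dict)

class ChatRollout:
    def generate_sequences(self, batch: ChatBatch, **kwargs) -> ChatBatch:
        # The real rollout calls the vLLM engine here; this stub just echoes.
        return ChatBatch([p + " <response>" for p in batch.prompts], dict(batch.meta_info))

    def chat(self, batch: ChatBatch, **kwargs) -> ChatBatch:
        # Attach chat-specific metadata (key is illustrative), then reuse
        # generate_sequences so both entry points share one return format.
        batch.meta_info.update({"source": "chat"})
        return self.generate_sequences(batch, **kwargs)

if __name__ == "__main__":
    out = ChatRollout().chat(ChatBatch(["Hi there"]))
    print(out.prompts, out.meta_info)  # ['Hi there <response>'] {'source': 'chat'}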
