Skip to content

Commit

Permalink
apis: add data proto to documentation page. use copy_to_local instead…
Browse files Browse the repository at this point in the history
… of copy_local_path_from_hdfs (#358)
  • Loading branch information
eric-haibin-lin authored Feb 26, 2025
1 parent efd0061 commit 2440aa6
Show file tree
Hide file tree
Showing 30 changed files with 146 additions and 58 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The
- [Logic R1](https://github.com/Unakar/Logic-RL): a reproduced DeepSeek R1 Zero on 2K Tiny Logic Puzzle Dataset.
- [deepscaler](https://github.com/agentica-project/deepscaler): iterative context scaling with GRPO
- [critic-rl](https://github.com/HKUNLP/critic-rl): Teaching Language Models to Critique via Reinforcement Learning
- [Easy-R1](https://github.com/hiyouga/EasyR1): Multi-Modality RL

## Contribution Guide
Contributions from the community are welcome!
Expand Down
6 changes: 3 additions & 3 deletions docs/README_vllm0.7.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Readme for verl(vllm>=0.7) version
# Upgrading to vllm >= 0.7

## Installation

Note: This version of veRL supports **FSDP** for training and **vLLM** for rollout. (Megatron-LM is not supported yet.)
Note: This version of veRL+vllm 0.7+ supports **FSDP** for training and **vLLM** for rollout.

```
# Create the conda environment
Expand Down Expand Up @@ -62,4 +62,4 @@ For a typical job like examples/ppo_trainer/run_qwen2-7b_seq_balance.sh, the rol

1. **num_scheduler_step>1:** not supported yet (weight loading has not been aligned with `MultiStepModelRunner`)
2. **Prefix caching:** not supported yet (vLLM sleep mode does not support prefix caching)
3. **Chunked prefill:** supported
3. **Chunked prefill:** supported
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['recommonmark',
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.autosectionlabel',
]

Expand Down
59 changes: 59 additions & 0 deletions docs/data.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
Data interface
=========================

DataProto is the interface for data exchange.

The :class:`verl.DataProto` class contains two key members:

- batch: a :class:`tensordict.TensorDict` object for the actual data
- meta_info: a :class:`Dict` with additional meta information

TensorDict
~~~~~~~~~~~~

:attr:`DataProto.batch` is built on top of :class:`tensordict`, a project in the PyTorch ecosystem.
A TensorDict is a dict-like container for tensors. To instantiate a TensorDict, you must specify key-value pairs as well as the batch size.

.. code-block:: python
>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict({"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 5)}, batch_size=[2,])
>>> tensordict["twos"] = 2 * torch.ones(2, 5, 6)
>>> zeros = tensordict["zeros"]
>>> tensordict
TensorDict(
fields={
ones: Tensor(shape=torch.Size([2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
twos: Tensor(shape=torch.Size([2, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
zeros: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
One can also index a tensordict along its batch_size. The contents of the TensorDict can be manipulated collectively as well.

.. code-block:: python
>>> tensordict[..., :1]
TensorDict(
fields={
ones: Tensor(shape=torch.Size([1, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
twos: Tensor(shape=torch.Size([1, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
zeros: Tensor(shape=torch.Size([1, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([1]),
device=None,
is_shared=False)
>>> tensordict = tensordict.to("cuda:0")
>>> tensordict = tensordict.reshape(6)
For more about :class:`tensordict.TensorDict` usage, see the official tensordict_ documentation.

.. _tensordict: https://pytorch.org/tensordict/overview.html


Core APIs
~~~~~~~~~~~~~~~~~

.. autoclass:: verl.DataProto
:members: to, select, union, make_iterator, concat
8 changes: 8 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ verl is fast with:
:caption: Performance Tuning Guide

perf/perf_tuning
README_vllm0.7.md

.. toctree::
:maxdepth: 1
Expand All @@ -88,6 +89,13 @@ verl is fast with:
advance/fsdp_extension
advance/megatron_extension

.. toctree::
:maxdepth: 1
:caption: API References

data.rst


.. toctree::
:maxdepth: 1
:caption: FAQ
Expand Down
4 changes: 2 additions & 2 deletions examples/split_placement/main_ppo_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def main(config):

@ray.remote
def main_task(config):
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fs import copy_to_local
from transformers import AutoTokenizer

# print initial config
Expand All @@ -110,7 +110,7 @@ def main_task(config):
OmegaConf.resolve(config)

# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
local_path = copy_to_local(config.actor_rollout_ref.model.path)

# instantiate tokenizer
from verl.utils import hf_tokenizer
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/arithmetic_sequence/rl/main_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fs import copy_to_local
from tests.e2e.envs.digit_completion import CharTokenizer


Expand Down Expand Up @@ -105,7 +105,7 @@ def main(config):
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values

# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
local_path = copy_to_local(config.actor_rollout_ref.model.path)
local_path = os.path.expanduser(local_path)
# instantiate tokenizern
tokenizer = AutoTokenizer.from_pretrained(local_path)
Expand Down
4 changes: 2 additions & 2 deletions tests/rollout/run_fsdp_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def main():
local_cache_path = os.path.expanduser(local_cache_path)
hdfs_path = 'Qwen/Qwen2-7B-Instruct'

from verl.utils.fs import copy_local_path_from_hdfs
local_model_path = copy_local_path_from_hdfs(src=hdfs_path, cache_dir=local_cache_path)
from verl.utils.fs import copy_to_local
local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)
tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)
actor_model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=True)
with torch.device("cuda"):
Expand Down
4 changes: 2 additions & 2 deletions tests/rollout/test_vllm_hf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def test_vllm_with_hf():
local_cache_path = '~/.cache/verl/rlhf'
local_cache_path = os.path.expanduser(local_cache_path)
hdfs_path = 'deepseek-ai/deepseek-llm-7b-chat'
from verl.utils.fs import copy_local_path_from_hdfs
local_model_path = copy_local_path_from_hdfs(src=hdfs_path, cache_dir=local_cache_path)
from verl.utils.fs import copy_to_local
local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)
tokenizer = AutoTokenizer.from_pretrained(local_model_path)

preencode_prompts = [
Expand Down
4 changes: 2 additions & 2 deletions tests/rollout/test_vllm_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def test_vllm_spmd():
local_cache_path = '~/.cache/verl/rlhf'
local_cache_path = os.path.expanduser(local_cache_path)
hdfs_path = 'Qwen/Qwen2-7B-Instruct'
from verl.utils.fs import copy_local_path_from_hdfs
local_model_path = copy_local_path_from_hdfs(src=hdfs_path, cache_dir=local_cache_path)
from verl.utils.fs import copy_to_local
local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side='left')

preencode_prompts = [
Expand Down
4 changes: 4 additions & 0 deletions verl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@
import logging

set_basic_config(level=logging.WARNING)

from . import single_controller

__all__ = ['DataProto', "__version__"]
1 change: 1 addition & 0 deletions verl/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def validate_input(keys):
def union(self, other: 'DataProto') -> 'DataProto':
"""Union with another DataProto. Union batch and meta_info separately.
Throw an error if
- there are conflict keys in batch and they are not equal
- the batch size of two data batch is not the same
- there are conflict keys in meta_info and they are not the same.
Expand Down
6 changes: 6 additions & 0 deletions verl/single_controller/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,11 @@

version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))

# Note(haibin.lin): single_controller.__version__ is deprecated
with open(os.path.join(os.path.join(version_folder, os.pardir), 'version/version')) as f:
__version__ = f.read().strip()

from . import base
from .base import *

__all__ = base.__all__
2 changes: 2 additions & 0 deletions verl/single_controller/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@

from .worker import Worker
from .worker_group import WorkerGroup, ClassWithInitArgs, ResourcePool

__all__ = ['Worker', 'WorkerGroup', 'ClassWithInitArgs', 'ResourcePool']
2 changes: 0 additions & 2 deletions verl/single_controller/base/megatron/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass
from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo


Expand Down
5 changes: 3 additions & 2 deletions verl/single_controller/base/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os
import socket
from dataclasses import dataclass
from verl.single_controller.base.decorator import register, Dispatch, Execute
from .decorator import register, Dispatch, Execute


@dataclass
Expand Down Expand Up @@ -79,6 +79,7 @@ def to_dict(self):

# we assume that in each WorkerGroup, there is a Master Worker
class Worker(WorkerHelper):
"""A (distributed) worker."""

def __new__(cls, *args, **kwargs):
instance = super().__new__(cls)
Expand Down Expand Up @@ -181,4 +182,4 @@ def execute_with_func_generator(self, func, *args, **kwargs):
@register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
def execute_func_rank_zero(self, func, *args, **kwargs):
result = func(*args, **kwargs)
return result
return result
4 changes: 3 additions & 1 deletion verl/single_controller/base/worker_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
import time
from typing import List, Any, Callable, Dict

from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn


class ResourcePool:
"""The resource pool with meta info such as world_size."""

def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None:
if process_on_nodes is None:
Expand Down Expand Up @@ -89,6 +90,7 @@ def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1)


class WorkerGroup:
"""A group of workers"""

def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
self._is_init_with_detached_workers = True if resource_pool is None else False
Expand Down
3 changes: 1 addition & 2 deletions verl/single_controller/ray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
from .megatron import (MegatronRayWorkerGroup, DistRankInfo, DistGlobalInfo)
from .base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls
6 changes: 3 additions & 3 deletions verl/trainer/fsdp_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
from verl.utils.dataset import SFTDataset
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fs import copy_to_local
from verl.utils.tracking import Tracking
from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group
from torch.distributed.device_mesh import DeviceMesh
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceM
self.ulysses_device_mesh = ulysses_device_mesh
self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
# build tokenizer first
local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)
local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True)
from verl.utils import hf_tokenizer
self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code)
if self.config.data.chat_template is not None:
Expand Down Expand Up @@ -182,7 +182,7 @@ def _build_model_optimizer(self):
# TODO (zhangchi.usc1992):
# 1. support pretrain from random weights
# 2. support init directly from sharded weights
local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True)
local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True)

if self.config.model.get('external_lib', None) is not None:
# This is used to import external_lib into the huggingface systems
Expand Down
4 changes: 2 additions & 2 deletions verl/trainer/main_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""

import hydra
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fs import copy_to_local
from verl.utils.reward_score import math, gsm8k
import pandas as pd
import numpy as np
Expand All @@ -33,7 +33,7 @@ def select_reward_fn(data_source):

@hydra.main(config_path='config', config_name='evaluation', version_base=None)
def main(config):
local_path = copy_local_path_from_hdfs(config.data.path)
local_path = copy_to_local(config.data.path)
dataset = pd.read_parquet(local_path)
prompts = dataset[config.data.prompt_key]
responses = dataset[config.data.response_key]
Expand Down
4 changes: 2 additions & 2 deletions verl/trainer/main_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from transformers import AutoTokenizer

from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fs import copy_to_local
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.hdfs_io import makedirs
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
Expand All @@ -42,7 +42,7 @@ def main(config):
from omegaconf import OmegaConf
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
local_path = copy_local_path_from_hdfs(config.model.path)
local_path = copy_to_local(config.model.path)
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)

Expand Down
4 changes: 2 additions & 2 deletions verl/trainer/main_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ def run_ppo(config, compute_score=None):

@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
def main_task(config, compute_score=None):
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fs import copy_to_local
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)

# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
local_path = copy_to_local(config.actor_rollout_ref.model.path)

# instantiate tokenizer
from verl.utils import hf_tokenizer
Expand Down
8 changes: 4 additions & 4 deletions verl/utils/checkpoint/fsdp_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp import ShardedStateDictConfig, ShardedOptimStateDictConfig

from verl.utils.fs import copy_local_path_from_hdfs, is_non_local
from verl.utils.fs import copy_to_local, is_non_local

from transformers import PreTrainedTokenizer

Expand Down Expand Up @@ -59,9 +59,9 @@ def load_checkpoint(self, path=None, del_local_after_load=False, *args, **kwargs
print(
f'[rank-{self.rank}]: Loading from {remote_model_path} and {remote_optim_path} and {remote_extra_state_path}'
)
local_model_path = copy_local_path_from_hdfs(remote_model_path)
local_optim_path = copy_local_path_from_hdfs(remote_optim_path)
local_extra_state_path = copy_local_path_from_hdfs(remote_extra_state_path)
local_model_path = copy_to_local(remote_model_path)
local_optim_path = copy_to_local(remote_optim_path)
local_extra_state_path = copy_to_local(remote_extra_state_path)

model_state_dict = torch.load(local_model_path)
optimizer_state_dict = torch.load(local_optim_path)
Expand Down
Loading

0 comments on commit 2440aa6

Please sign in to comment.