Update megatron-lm to core_r0.11.0 #392

Draft
wants to merge 6 commits into main

Conversation

ISEEKYAN
Copy link

Support Megatron mcore 0.11

Description

This PR introduces official support for Megatron mcore 0.11 with the following updates:

  • Upgraded Megatron to version core_r0.11.0
  • Applied compatibility patch patches/mcore_r0.11.patch
  • Removed legacy version support for cleaner implementation

Special thanks to @chendong-1998 for:

  • Original Megatron upgrade from 0.4 to 0.6 (#93f6a7e)

Compatibility Notes

Current implementation requires careful handling due to dependency conflicts:

  • megatron-core==0.11.0 requires torch>=2.6
  • vllm==0.6.3 requires torch==2.4

Installation constraints:

  1. Must use vLLM's torch dependency (2.4) as the baseline
  2. Do NOT run `pip install -e .` in the mcore directory (it will upgrade torch to 2.6)
  3. Apply the compatibility patch manually after installation
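As a rough illustration of constraints 1 and 2, the installed environment can be sanity-checked after setup. This is a hypothetical snippet for illustration only, not part of this PR:

# Hypothetical post-install sanity check; the version pins follow the
# constraints listed above, and this script is not part of the verl codebase.
import importlib.metadata as md

torch_version = md.version("torch")
vllm_version = md.version("vllm")
print(f"torch=={torch_version}, vllm=={vllm_version}")

# vllm==0.6.3 pins torch==2.4, so megatron-core must not be allowed to
# upgrade torch to 2.6 (i.e. never run `pip install -e .` inside mcore).
assert torch_version.startswith("2.4"), (
    "torch was upgraded past 2.4; reinstall using vllm's pinned torch"
)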

Testing

Tested with `verl/examples/ppo_trainer/run_deepseek_megatron.sh`.

Chendong98 and others added 4 commits February 25, 2025 08:55
Signed-off-by: chendong-1998 <[email protected]>
Signed-off-by: chendong-1998 <[email protected]>
patch megatron-lm with `patches/mcore_r0.11.patch`
Can't run `pip install -e .` in the megatron directory, because mcore 0.11 depends on torch 2.6, while vLLM 0.6.3 requires torch 2.4.
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
Collaborator

Could you keep the v0.4 patch file for now, in case others want to run v0.4 for comparison? Thanks!

Collaborator

we can remove the v0.4 patch after the next stable release of verl

Author

OK, I will add that back.

CLAassistant commented Feb 26, 2025

CLA assistant check
All committers have signed the CLA.

@PeterSH6 (Collaborator) left a comment

Brilliant work!



def get_model_config(model):
    return get_attr_wrapped_model(model, 'megatron_config', allow_none=False)
Collaborator

I think get_model_config() is no longer necessary if we change the attributes of the customized model class.
Currently in ParallelLlamaForCausalLMRmPadPP:

  • config -> Hugging Face config
  • megatron_config -> megatron.core.ModelParallelConfig

We could rename megatron_config to config and rename config to hf_config.
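A minimal sketch of the proposed attribute layout (hypothetical names only, not the current verl code):

import torch.nn as nn

class ParallelLlamaForCausalLMRmPadPP(nn.Module):
    # Hypothetical sketch of the proposed renaming, for illustration only.
    def __init__(self, hf_config, megatron_config):
        super().__init__()
        self.hf_config = hf_config     # was `config`: the Hugging Face PretrainedConfig
        self.config = megatron_config  # was `megatron_config`: megatron.core.ModelParallelConfig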

Author

I agree that the multiple config types in ParallelLlamaForCausalLMRmPadPP can be confusing. However, directly renaming megatron_config to config and config to hf_config might introduce breaking changes to existing code that references these properties.
I suggest merging the current PR first, then handling the renaming separately with proper migration planning.

@@ -216,7 +225,7 @@ class FakeTimers:
     """Disable All Megatron Timing with FakeTimers"""

     def __init__(self):
-        from megatron.timers import DummyTimer
+        from megatron.core.timers import DummyTimer
Collaborator

Do we still need Timers in MCore v0.11?
This FakeTimers class is mainly for optimizer.step() in MCore v0.4.
As there is no need to use a timer in optimizer.step() in MCore v0.11, I suggest simply deleting this class.
(Also delete its usage in L212.)

Author

I tried it; it works without the timer. I will delete the whole timer class.

self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta)
if pp_rank == self._pp_rank:
from verl.utils.memory_buffer import MemoryBuffer
# The code here is very hard-coded, based on the following assumptions:
Collaborator

That's true. Just wondering whether the MemoryBuffer will complicate the weight synchronization process when we enable EP?
If so, we can abandon the MemoryBuffer and switch to per-parameter synchronization.

Collaborator

I agree with abandoning the MemoryBuffer and switching to per-parameter synchronization.

Collaborator

Essentially, weight binding can be generalized to a per-parameter all-gather and redistribute.
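A minimal sketch of the per-parameter approach (hypothetical helper; assumes an initialized torch.distributed process group for the tensor-parallel ranks):

import torch
import torch.distributed as dist

def allgather_param(param: torch.Tensor, tp_group) -> torch.Tensor:
    # Hypothetical sketch: all-gather one tensor-parallel-sharded parameter so it
    # can be redistributed to the rollout engine, instead of packing every weight
    # into a single MemoryBuffer.
    world_size = dist.get_world_size(group=tp_group)
    shards = [torch.empty_like(param) for _ in range(world_size)]
    dist.all_gather(shards, param.detach(), group=tp_group)
    # The concat dim depends on how the layer is sharded (e.g. dim 0 for
    # column-parallel, dim 1 for row-parallel); dim 0 is assumed here.
    return torch.cat(shards, dim=0)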

@@ -139,7 +155,7 @@ def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):

 def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig:
     print(f'megatron config {megatron_config}')
-    dt = PrecisionType.to_dtype(megatron_config['param_dtype'])
+    dt = torch.bfloat16
Collaborator

How is the parameter dtype of Megatron passed in the current implementation?

Author

If we modify L22-L24 in verl/utils/torch_dtypes.py like this:

HALF_LIST = [16, "16", "fp16", "float16", torch.float16]
FLOAT_LIST = [32, "32", "fp32", "float32", torch.float32]
BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16]

then the dtype could be passed via dt = PrecisionType.to_dtype(megatron_config.params_dtype).
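For illustration, a sketch of the lookup this would enable (hypothetical standalone function; the real PrecisionType in verl/utils/torch_dtypes.py may differ):

import torch

HALF_LIST = [16, "16", "fp16", "float16", torch.float16]
FLOAT_LIST = [32, "32", "fp32", "float32", torch.float32]
BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16]

def to_dtype(value) -> torch.dtype:
    # Map a string/int/torch.dtype spec to a concrete torch.dtype.
    if value in HALF_LIST:
        return torch.float16
    if value in FLOAT_LIST:
        return torch.float32
    if value in BFLOAT_LIST:
        return torch.bfloat16
    raise ValueError(f"unsupported dtype spec: {value}")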

from megatron.model import Float16Module
from megatron.model import DistributedDataParallel as LocalDDP

from megatron.training.utils import print_rank_0, unwrap_model
Collaborator

We copied several module and utility functions from the Megatron-LM package into megatron_utils.py.
It would be better if we could avoid importing from outside megatron.core.

Author

Sure, fixing this.
