Describe the bug
We are training a small MoE model with FP8 precision. Setting aside speed and convergence concerns, checkpoint saving and loading do not work correctly.
I made a modification to experts.py, specifically to the sharded_state_dict function. In its loop over submodules I skip fp8_padding and fp8_unpadding:

    for name, module in self._modules.items():
        if name in ['fp8_padding', 'fp8_unpadding']:
            continue

The fp8_padding and fp8_unpadding objects do not have a sharded state dict, so I added a continue statement to skip them during the iteration (see Proposed fix below for a sketch). Loading the saved checkpoint to resume training then fails with the error in the stack trace below.
To Reproduce
Train with the NeMo framework using an FP8 config and an MoE model with grouped GEMM enabled, then try to resume from a saved checkpoint (rough sketch of the overrides below).
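For reference, the kind of run I mean looks roughly like the following. The override names are assumptions based on common NeMo GPT pretraining configs and may differ between NeMo versions, so adapt them to your own config:

    python /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \
        trainer.precision=bf16 \
        model.fp8=True \
        model.fp8_hybrid=True \
        model.num_moe_experts=8 \
        model.moe_grouped_gemm=True \
        exp_manager.resume_if_exists=True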
Expected behavior
Can load the saved FP8 checkpoint and resume training.
Stack trace/logs
52: Error executing job with overrides: []
52: Traceback (most recent call last):
52: File "/opt/NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py", line 66, in main
52: trainer.fit(model)
52: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
52: call._call_and_handle_interrupt(
52: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 46, in _call_and_handle_interrupt
52: return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
52: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
52: return function(*args, **kwargs)
52: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
52: self._run(model, ckpt_path=ckpt_path)
52: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 968, in _run
52: self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
52: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 398, in _restore_modules_and_callbacks
52: self.restore_model()
52: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 272, in restore_model
52: call._call_lightning_module_hook(self.trainer, "on_load_checkpoint", self._loaded_checkpoint)
52: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
52: output = fn(*args, **kwargs)
52: File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1996, in on_load_checkpoint
52: module.load_state_dict(checkpoint_state_dict, strict=True)
52: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2564, in load_state_dict
52: load(self, state_dict)
52: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2552, in load
52: load(child, child_state_dict, child_prefix) # noqa: F821
52: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2552, in load
52: load(child, child_state_dict, child_prefix) # noqa: F821
52: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2552, in load
52: load(child, child_state_dict, child_prefix) # noqa: F821
52: [Previous line repeated 3 more times]
52: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2535, in load
52: module._load_from_state_dict(
52: File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/base.py", line 1104, in _load_from_state_dict
52: self.set_extra_state(state_dict[extra_state_key])
52: File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/base.py", line 645, in set_extra_state
52: self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
52: RuntimeError: The size of tensor a (192) must match the size of tensor b (3) at non-singleton dimension 0
Environment (please complete the following information):
Proposed fix
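The workaround described under "Describe the bug" amounts to skipping the two FP8 padding helpers when sharded_state_dict recurses into submodules. A minimal sketch follows; the surrounding function body is schematic rather than the verbatim Megatron-Core code (the real sharded_state_dict in experts.py also wraps parameters into sharded tensors and handles sharded offsets), and only the continue is the actual change:

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        # Sketch only: building sharded tensors for this module's own
        # parameters is elided; the change of interest is the early `continue`.
        sharded_state_dict = {}
        for name, module in self._modules.items():
            # Per this report, fp8_padding / fp8_unpadding provide no sharded
            # state dict of their own, so skip them when recursing.
            if name in ['fp8_padding', 'fp8_unpadding']:
                continue
            sharded_state_dict.update(
                module.sharded_state_dict(f'{prefix}{name}.', sharded_offsets, metadata)
            )
        return sharded_state_dict

With this change the padding helpers are no longer visited during checkpoint saving, but as the stack trace shows, restoring still fails inside TransformerEngine's set_extra_state when copying the fp8_meta scaling tensors, so the load path likely needs a corresponding fix as well.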
Additional context
Add any other context about the problem here.