
[BUG] can't load saved fp8 checkpoint when resume training #1350

Open
switiz opened this issue Jan 8, 2025 · 0 comments
switiz commented Jan 8, 2025

Describe the bug
We are training a small MoE model with FP8 precision. Setting aside speed and convergence concerns, checkpoint saving and loading do not work correctly.

I made some modifications to the experts.py file, specifically in the sharded_state_dict function.
In the loop:

for name, module in self._modules.items():
    if name in ['fp8_padding', 'fp8_unpadding']:
        continue

The fp8_padding and fp8_unpadding modules do not implement a sharded state dict, so I added a continue statement to skip them during the iteration.

To Reproduce
Train with the NeMo framework using an FP8 config, an MoE model, and grouped GEMM enabled:

  dist_ckpt_format: torch_dist
  dist_ckpt_parallel_save: true
  scale_positional_embedding: true
  restore_from_path: null
  dist_ckpt_load_strictness: log_all
  moe_router_topk: 8
  num_moe_experts: 64
  moe_token_dispatcher_type: alltoall
  moe_aux_loss_coeff: 0.01
  moe_z_loss_coeff: 0.001
  moe_router_load_balancing_type: aux_loss
  mcore_gpt: true
  moe_grouped_gemm: true
  micro_batch_size: 2
  global_batch_size: 2048
  rampup_batch_size: null
  tensor_model_parallel_size: 1
  pipeline_model_parallel_size: 1
  expert_model_parallel_size: 1
  virtual_pipeline_model_parallel_size: null
  encoder_seq_length: 8192
  max_position_embeddings: 8192
  num_layers: 16
  hidden_size: 2048
  ffn_hidden_size: 1024
  num_attention_heads: 32
  num_query_groups: 8
  init_method_std: 0.01
  use_scaled_init_method: true
  hidden_dropout: 0.0
  attention_dropout: 0.0
  ffn_dropout: 0.0
  kv_channels: null
  apply_query_key_layer_scaling: true
  normalization: rmsnorm
  layernorm_epsilon: 1.0e-05
  do_layer_norm_weight_decay: false
  make_vocab_size_divisible_by: 128
  pre_process: true
  post_process: true
  persist_layer_norm: true
  bias: false
  activation: fast-swiglu
  headscale: false
  transformer_block_type: pre_ln
  openai_gelu: false
  normalize_attention_scores: true
  position_embedding_type: rope
  rotary_percentage: 1.0
  rotary_base: 500000
  apply_rope_fusion: true
  attention_type: multihead
  share_embeddings_and_output_weights: false
  tokenizer:
    library: huggingface
    type: {path}
    use_fast: true
  native_amp_init_scale: 4294967296
  native_amp_growth_interval: 1000
  hysteresis: 2
  fp32_residual_connection: false
  fp16_lm_cross_entropy: false
  megatron_amp_O2: true
  grad_allreduce_chunk_size_mb: 125
  grad_div_ar_fusion: true
  gradient_accumulation_fusion: true
  bias_activation_fusion: true
  bias_dropout_add_fusion: true
  masked_softmax_fusion: true
  seed: 1234
  resume_from_checkpoint: null
  use_cpu_initialization: false
  onnx_safe: false
  apex_transformer_log_level: 30
  gradient_as_bucket_view: true
  sync_batch_comm: false
  activations_checkpoint_granularity: null
  activations_checkpoint_method: null
  activations_checkpoint_num_layers: null
  num_micro_batches_with_partial_activation_checkpoints: null
  activations_checkpoint_layers_per_pipeline: null
  sequence_parallel: false
  transformer_engine: true
  activation_func_fp8_input_store: true
  fp8_params: true
  fp8: true
  fp8_e4m3: false
  fp8_hybrid: true
  fp8_margin: 0
  fp8_interval: 1
  fp8_amax_history_len: 1024
  fp8_amax_compute_algo: max
  reduce_amax: true
  use_emha: false
  ub_tp_comm_overlap: false
  ub_tp_comm_overlap_cfg: null
  overlap_p2p_comm: false
  batch_p2p_comm: false
  seq_len_interpolation_factor: null
  use_flash_attention: true
  optim:
    name: mcore_distributed_optim
    lr: 0.0005
    weight_decay: 0.1
    betas:
    - 0.9
    - 0.95
    dtype: bf16
    overlap_grad_sync: true
    overlap_param_sync: true
    grad_sync_dtype: bf16

Expected behavior
The saved FP8 checkpoint can be loaded and training resumes.

Stack trace/logs

  1. save case

  0: [NeMo W 2025-01-08 23:55:02 validation:389] There is difference in the common state dict in different ranks. The differences are {1: ([], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 2: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 3: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 4: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 5: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 6: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>,
  0:  <class 'float'>)]), 7: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 8: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 9: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 10: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 11: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 12: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 13: ([('optimizer_stat
  0: es', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 14: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 15: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 16: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 17: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 18: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 19: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 
  0: 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 20: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 21: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 22: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 23: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 24: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 25: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', '
  0: time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 26: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 27: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 28: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 29: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 30: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 31: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>
  0: , <class 'float'>)]), 32: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 33: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 34: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 35: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 36: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 37: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 38: ([('optimizer_
  0: states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 39: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 40: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 41: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 42: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 43: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 44: ([('optimizer_states', 0, 'optimizer', 'param_groups',
  0:  1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 45: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 46: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 47: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 48: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 49: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 50: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer
  0: ', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 51: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 52: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 53: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 54: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 55: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 56: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'flo
  0: at'>, <class 'float'>)]), 57: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 58: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 59: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 60: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 61: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 62: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 63: ([('optimi
  0: zer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 64: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 65: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 66: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 67: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 68: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 69: ([('optimizer_states', 0, 'optimizer', 'param_grou
  0: ps', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 70: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 71: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 72: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 73: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 74: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 75: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'T
  0: imer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 76: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 77: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 78: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 79: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 80: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 81: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 
  0: 'float'>, <class 'float'>)]), 82: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 83: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 84: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 85: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 86: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 87: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 88: ([('op
  0: timizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 89: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 90: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 91: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 92: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 93: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 94: ([('optimizer_states', 0, 'optimizer', 'param_
  0: groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 95: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 96: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 97: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 98: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 99: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 100: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks
  0: ', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 101: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 102: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 103: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 104: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 105: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 106: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train
  0: '), <class 'float'>, <class 'float'>)]), 107: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 108: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 109: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 110: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 111: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 112: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float
  0: '>)]), 113: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 114: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 115: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 116: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 117: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 118: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 119: ([('optimizer_states',
  0:  0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 120: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 121: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 122: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 123: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 124: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 125: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1
  0: , 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 126: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)]), 127: ([('optimizer_states', 0, 'optimizer', 'param_groups', 1, 'step')], [], [(('callbacks', 'Timer', 'time_elapsed', 'train'), <class 'float'>, <class 'float'>)])}
  2. load case
 52: Error executing job with overrides: []
 52: Traceback (most recent call last):
 52:   File "/opt/NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py", line 66, in main
 52:     trainer.fit(model)
 52:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
 52:     call._call_and_handle_interrupt(
 52:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 46, in _call_and_handle_interrupt
 52:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
 52:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
 52:     return function(*args, **kwargs)
 52:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
 52:     self._run(model, ckpt_path=ckpt_path)
 52:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 968, in _run
 52:     self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
 52:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 398, in _restore_modules_and_callbacks
 52:     self.restore_model()
 52:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 272, in restore_model
 52:     call._call_lightning_module_hook(self.trainer, "on_load_checkpoint", self._loaded_checkpoint)
 52:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
 52:     output = fn(*args, **kwargs)
 52:   File "/opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1996, in on_load_checkpoint
 52:     module.load_state_dict(checkpoint_state_dict, strict=True)
 52:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2564, in load_state_dict
 52:     load(self, state_dict)
 52:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2552, in load
 52:     load(child, child_state_dict, child_prefix)  # noqa: F821
 52:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2552, in load
 52:     load(child, child_state_dict, child_prefix)  # noqa: F821
 52:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2552, in load
 52:     load(child, child_state_dict, child_prefix)  # noqa: F821
 52:   [Previous line repeated 3 more times]
 52:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2535, in load
 52:     module._load_from_state_dict(
 52:   File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/base.py", line 1104, in _load_from_state_dict
 52:     self.set_extra_state(state_dict[extra_state_key])
 52:   File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/base.py", line 645, in set_extra_state
 52:     self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
 52: RuntimeError: The size of tensor a (192) must match the size of tensor b (3) at non-singleton dimension 0
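The final `RuntimeError` comes from an elementwise `copy_` between FP8 scaling tensors of different lengths. One plausible reading (an assumption, not confirmed from the Transformer Engine internals): with 64 grouped experts the live module expects 64 × 3 = 192 forward-scale entries, while the checkpoint entry holds only the 3 entries of a single non-grouped module. The size check can be illustrated in plain Python, with `copy_` as a hypothetical stand-in for the tensor operation:

```python
# Assumed shapes (hypothetical interpretation of the mismatch):
saved_scale = [1.0] * 3      # state["scale_fwd"] as written to the checkpoint
expected_scale = [1.0] * 192  # fp8_meta["scaling_fwd"].scale in the live module


def copy_(dst, src):
    """Stand-in for an elementwise tensor copy: sizes must match exactly."""
    if len(dst) != len(src):
        raise RuntimeError(
            f"The size of tensor a ({len(dst)}) must match the size of "
            f"tensor b ({len(src)}) at non-singleton dimension 0"
        )
    dst[:] = src


try:
    copy_(expected_scale, saved_scale)
except RuntimeError as e:
    print(e)
```

This matches the shape of the error in the trace: the destination (a=192) and source (b=3) disagree at dimension 0, so `set_extra_state` fails before any weights are restored.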

Environment (please complete the following information):

  • NGC nemo docker 24.12
  • transformer_engine 1.13
  • Megatron-LM commit ID 1ce944c

Proposed fix

Skip the fp8_padding and fp8_unpadding submodules in the sharded_state_dict loop in experts.py, as described above.
