
LigerFusedLinearCrossEntropyLoss Causes Training Loss to Diverge After Reaching ~8 #512

Open
penghui-yang opened this issue Jan 4, 2025 · 7 comments

Comments

@penghui-yang

🐛 Describe the bug

Description

When using LigerFusedLinearCrossEntropyLoss (Liger FLCE) from the Liger kernel to replace torch.nn.CrossEntropyLoss, the training loss becomes unstable and diverges after reaching a certain value (~8). In contrast, the loss computed using torch.nn.CrossEntropyLoss continues to decrease smoothly.

Expected Behavior

The loss computed with LigerFusedLinearCrossEntropyLoss should decrease similarly to torch.nn.CrossEntropyLoss without significant oscillations or divergence.

Observed Behavior

  • During the initial training phase, both loss functions exhibit similar behavior, and the loss decreases as expected.
  • When the loss computed with LigerFusedLinearCrossEntropyLoss reaches ~8, it becomes unstable, oscillates, and diverges, as shown in the attached graph.

Screenshots/Logs

Loss curve comparison (attached):

  • The orange curve shows the behavior with torch.nn.CrossEntropyLoss (stable).
  • The purple curve shows the behavior with LigerFusedLinearCrossEntropyLoss (unstable and divergent).

[Screenshot: loss curve comparison]

Additional Context

  • This issue appears to be related to gradient computation or numerical stability with LigerFusedLinearCrossEntropyLoss.
  • No hyperparameter changes were made between the two implementations.

Request for Assistance

  • Please investigate whether there are implementation issues with LigerFusedLinearCrossEntropyLoss.
  • Are there additional configurations or training parameters required to avoid instability?

Thank you for your assistance!

Reproduce

Code to Reproduce

Original compute_loss implementation (works as expected):

def compute_loss(self, hidden_states, labels):
    logits = self.lm_head(hidden_states).float()
    # Using torch.nn.CrossEntropyLoss for loss computation
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(logits[:, :-1].reshape(-1, logits.size(-1)), labels[:, 1:].reshape(-1))
    return loss

New compute_fused_loss implementation (causes instability):

def compute_fused_loss(self, hidden_states, labels):
    shift_hidden_states = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten tokens
    shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
    shift_labels = shift_labels.view(-1)

    lce = LigerFusedLinearCrossEntropyLoss(reduction="mean")
    loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    return loss

Steps to Reproduce

  1. Replace the original compute_loss function with the new compute_fused_loss function using LigerFusedLinearCrossEntropyLoss.
  2. Train a model using both implementations (torch.nn.CrossEntropyLoss and LigerFusedLinearCrossEntropyLoss) for comparison.
  3. Observe the behavior of the loss curves during training.
    • With torch.nn.CrossEntropyLoss, the loss continues to decrease as expected.
    • With LigerFusedLinearCrossEntropyLoss, the loss starts to oscillate and then diverges when it reaches ~8.
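
For a quick standalone sanity check of the two code paths on random data, something like the sketch below can be used (the tensor shapes and the liger_kernel import path are assumptions, and a CUDA GPU is required since the Liger kernels are Triton-based; note the divergence itself only shows up over the course of training):

import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss  # import path assumed

B, T, H, V = 2, 128, 1024, 32000  # hypothetical shapes
hidden = torch.randn(B, T, H, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(V, H, device="cuda", dtype=torch.bfloat16)
labels = torch.randint(0, V, (B, T), device="cuda")

# Reference path: materialize the logits, upcast to fp32, then torch.nn.CrossEntropyLoss
logits = (hidden @ weight.t()).float()
ref = torch.nn.CrossEntropyLoss()(
    logits[:, :-1, :].reshape(-1, V), labels[:, 1:].reshape(-1)
)

# Fused path: pass the lm_head weight and the flattened shifted hidden states directly
lce = LigerFusedLinearCrossEntropyLoss(reduction="mean")
fused = lce(weight, hidden[:, :-1, :].reshape(-1, H), labels[:, 1:].reshape(-1))

print(ref.item(), fused.item())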

Versions

Environment

  • Liger Kernel Version: 0.3.1
  • Hardware: 8 × A100 GPUs
  • CUDA Version: 12.4
  • PyTorch Version: 2.5.1+cu124
  • Transformers Version: 4.46.3
  • Precision: torch.bfloat16
  • Optimizer: ZeRO Stage 1
@qingquansong
Collaborator

@shivam15s if you happen to have time to take a look ^^ thank you!

@Tcc0403
Collaborator

Tcc0403 commented Jan 4, 2025

Code to Reproduce
Original compute_loss implementation (works as expected):

def compute_loss(self, hidden_states, labels):
    logits = self.lm_head(hidden_states).float()
    # Using torch.nn.CrossEntropyLoss for loss computation
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(logits[:, :-1].reshape(-1, logits.size(-1)), labels[:, 1:].reshape(-1))
    return loss

Just to clarify, is logits[:, :-1].reshape(-1, logits.size(-1)) supposed to be logits[:, :-1, :].reshape(-1, logits.size(-1))?

@penghui-yang
Author

Code to Reproduce
Original compute_loss implementation (works as expected):

def compute_loss(self, hidden_states, labels):
    logits = self.lm_head(hidden_states).float()
    # Using torch.nn.CrossEntropyLoss for loss computation
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(logits[:, :-1].reshape(-1, logits.size(-1)), labels[:, 1:].reshape(-1))
    return loss

Just to clarify, is logits[:, :-1].reshape(-1, logits.size(-1)) supposed to be logits[:, :-1, :].reshape(-1, logits.size(-1)) ?

Yes, you are right. I'm sorry that my source code wasn't written very clearly.

@yzhangcs

yzhangcs commented Jan 5, 2025

@penghui-yang Hi, reducing the number of chunks might help you. I heard from a friend that this can improve stability.
Check out my adapted code, which fixes the number of chunks to 8:
https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/fused_linear_cross_entropy.py
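
For intuition, here is a plain PyTorch sketch of what the chunking does (illustration only, with a hypothetical helper name and assumed shapes; the actual Liger/fla implementations are Triton kernels that also fuse the backward pass). num_chunks is the knob being discussed:

import torch
import torch.nn.functional as F

def chunked_linear_ce(hidden, weight, labels, num_chunks=8, ignore_index=-100):
    # hidden: (N, H) flattened shifted hidden states, weight: (V, H), labels: (N,)
    n_valid = (labels != ignore_index).sum().clamp(min=1)
    total = hidden.new_zeros((), dtype=torch.float32)
    for h_chunk, y_chunk in zip(hidden.chunk(num_chunks), labels.chunk(num_chunks)):
        logits = h_chunk @ weight.t()  # only this chunk's logits are materialized
        total += F.cross_entropy(
            logits.float(), y_chunk, ignore_index=ignore_index, reduction="sum"
        )
    return total / n_valid  # mean over non-ignored tokens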

@penghui-yang
Author

Hi there,

I found the root cause of the issue with LigerFusedLinearCrossEntropyLoss. The problem was the data types of the weight and hidden_states variables: both were in fp16 when passed to the fused loss function. After converting them to fp32 before passing them to LigerFusedLinearCrossEntropyLoss, the issue was resolved and the training loss converged normally, just as with torch.nn.CrossEntropyLoss.

Here is the updated code that fixed the issue:

def compute_fused_loss(self, hidden_states, labels):
    # Upcast the hidden states and lm_head weight to fp32 before the fused loss;
    # passing them in half precision is what caused the divergence described above.
    shift_hidden_states = hidden_states[..., :-1, :].float().contiguous()
    shift_labels = labels[..., 1:].contiguous()
    lm_head_weight = self.lm_head.weight.float().contiguous()

    shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
    shift_labels = shift_labels.view(-1)

    lce = LigerFusedLinearCrossEntropyLoss(reduction="mean")
    loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
    return loss

@yzhangcs Thanks a lot for your wonderful fla package! I actually found the solution above while using the FLCE loss in your fla package: I ran into a similar non-convergence problem there and assumed it was my own mistake. After converting these variables to fp32 before passing them to FusedLinearCrossEntropyLoss, that problem was solved as well.

However, this exploration raises a new question. With the standard torch.nn.CrossEntropyLoss, the conversion to fp32 typically happens after the linear layer (the logits are upcast), which keeps the loss computation stable. In my case, explicitly converting both weight and hidden_states to fp32 before calling LigerFusedLinearCrossEntropyLoss solved the problem. This raises the following question:

Is there room for further optimization here?

Specifically, could the fused loss function itself handle this type conversion internally (similar to torch.nn.CrossEntropyLoss), or should the user always perform this conversion manually to ensure stability?

Looking forward to your feedback and suggestions!

@yzhangcs

yzhangcs commented Jan 6, 2025

@penghui-yang How about doing the matmuls under TF32? I think this would reduce the accumulation errors.
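
For concreteness, TF32 matmuls can be enabled globally in PyTorch with the standard flags below (note that TF32 only changes how fp32 matmuls are executed, so it matters once the inputs are upcast to fp32):

import torch

# Allow Ampere+ GPUs to run fp32 matmuls with TF32 tensor cores
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Newer API that controls the same matmul behavior
torch.set_float32_matmul_precision("high")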

@Tcc0403
Collaborator

Tcc0403 commented Jan 7, 2025

I think the main issue is the gradient calculation for the weight. If the chunk size is too small, each chunk can produce values too small for fp16 to represent, and accumulating the weight gradient across chunks in fp16 adds further numerical instability.

Perhaps we should force the dw calculation in fp32 to ensure numerical stability.
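
For illustration, the idea would look roughly like this in plain PyTorch (a sketch with a hypothetical helper name and assumed shapes, not the actual Triton kernel):

import torch

def accumulate_dw_fp32(hidden_chunks, dlogits_chunks, weight):
    # weight: (V, H); each chunk contributes dlogits_chunk.T @ hidden_chunk to dW
    dw = torch.zeros_like(weight, dtype=torch.float32)  # fp32 accumulator
    for h, dl in zip(hidden_chunks, dlogits_chunks):
        dw += dl.float().t() @ h.float()  # per-chunk matmul and accumulation in fp32
    return dw.to(weight.dtype)  # cast back to the model dtype once at the end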
