LigerFusedLinearCrossEntropyLoss Causes Training Loss to Diverge After Reaching ~8 #512
Comments
@shivam15s if you happen to have time to take a look ^^ thank you!

Just to clarify, is

Yes, you are right. I'm sorry that my source code wasn't written very clearly.

@penghui-yang Hi, trying to reduce the number of chunks might help you. I've heard this can improve stability.
Hi there, I found the root cause of the issue. Here is the updated code that fixed it:

```python
def compute_fused_loss(self, hidden_states, labels):
    # Cast the hidden states and lm_head weights to fp32 before the fused loss.
    shift_hidden_states = hidden_states[..., :-1, :].float().contiguous()
    shift_labels = labels[..., 1:].contiguous()
    lm_head_weight = self.lm_head.weight.float().contiguous()
    # Flatten to (num_tokens, hidden_size) and (num_tokens,) for the fused kernel.
    shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
    shift_labels = shift_labels.view(-1)
    lce = LigerFusedLinearCrossEntropyLoss(reduction="mean")
    loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
    return loss
```

@yzhangcs Thanks a lot for your wonderful fix! However, the exploration raises a new question. When using the standard `torch.nn.CrossEntropyLoss`, we don't have to cast the weights and hidden states to fp32 manually like this. Is there room for further optimization here? Specifically, could the fused loss function itself handle this type conversion internally (similar to the `torch.nn.CrossEntropyLoss` path)? Looking forward to your feedback and suggestions!
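To make that question concrete, here is a minimal user-side sketch of what doing the conversion in one place could look like. It is only an illustration: the wrapper class name is hypothetical, the `(weight, input, target)` call order follows the snippet above, and the `liger_kernel.transformers` import path is assumed.

```python
import torch.nn as nn
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss


class Fp32LigerFLCE(nn.Module):
    """Hypothetical wrapper: upcast to fp32 once, in front of the fused loss,
    instead of repeating the casts at every call site."""

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        self.lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

    def forward(self, lm_head_weight, hidden_states, labels):
        # Same argument order as in the fixed snippet above.
        return self.lce(lm_head_weight.float(), hidden_states.float(), labels)
```

Handling the cast inside the fused kernel itself (e.g., per chunk) would presumably avoid materializing full fp32 copies of the weight and hidden states, which seems to be what the question is driving at.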
@penghui-yang How about doing matmuls under tf32? I think this would reduce the accumulation errors.
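For reference, a minimal sketch of the TF32 suggestion above. These are standard PyTorch switches; they affect fp32 matmuls on Ampere+ GPUs (e.g., the fp32-cast path in the fix above), not the bf16 ones.

```python
import torch

# Allow TensorFloat-32 for fp32 matmuls: keeps fp32 range and accumulation
# while running on tensor cores with reduced-precision inputs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Equivalent higher-level knob on recent PyTorch versions:
torch.set_float32_matmul_precision("high")
```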
I think the main issue is the gradient calculation of the weight. If the chunk size is too small, each chunk can produce values that are too small for fp16 to represent, and accumulating the gradient in fp16 adds further numerical instability. Perhaps we should force the dw calculation in fp32 to ensure numerical stability.
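A small, self-contained illustration of the precision concern described above (plain PyTorch, not Liger-specific): very small values underflow in fp16, and small addends are lost when accumulated into a larger fp16 value, which is why forcing the dw computation and accumulation into fp32 helps.

```python
import torch

# 1e-8 is below fp16's smallest subnormal (~6e-8), so it underflows to zero.
print(torch.tensor(1e-8).half())  # tensor(0., dtype=torch.float16)

# Accumulating many small steps: fp16 drops each 1e-4 addend because it is
# smaller than half an ulp at 1.0 (~4.9e-4), while fp32 keeps them.
acc16 = torch.tensor(1.0, dtype=torch.float16)
acc32 = torch.tensor(1.0, dtype=torch.float32)
for _ in range(1000):
    acc16 = acc16 + torch.tensor(1e-4, dtype=torch.float16)
    acc32 = acc32 + torch.tensor(1e-4, dtype=torch.float32)
print(acc16)  # stays at 1.0
print(acc32)  # ~1.1, as expected
```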
🐛 Describe the bug

Description

When using `LigerFusedLinearCrossEntropyLoss` (Liger FLCE) from the Liger kernel to replace `torch.nn.CrossEntropyLoss`, the training loss becomes unstable and diverges after reaching a certain value (~8). In contrast, the loss computed using `torch.nn.CrossEntropyLoss` continues to decrease smoothly.

Expected Behavior

The loss computed with `LigerFusedLinearCrossEntropyLoss` should decrease similarly to `torch.nn.CrossEntropyLoss`, without significant oscillations or divergence.

Observed Behavior

Once the loss computed with `LigerFusedLinearCrossEntropyLoss` reaches ~8, it becomes unstable, oscillates, and diverges, as shown in the attached graph.

Screenshots/Logs

Loss curve comparison (attached):
- `torch.nn.CrossEntropyLoss` (stable).
- `LigerFusedLinearCrossEntropyLoss` (unstable and divergent).

Additional Context

The instability only appears when the loss is computed with `LigerFusedLinearCrossEntropyLoss`.

Request for Assistance

Any guidance on what causes the divergence and how to use `LigerFusedLinearCrossEntropyLoss` correctly would be appreciated. Thank you for your assistance!
Reproduce

Code to Reproduce

Original `compute_loss` implementation (works as expected) and new `compute_fused_loss` implementation (causes instability): see the sketches after the steps below.

Steps to Reproduce
1. Replace the `compute_loss` function with the new `compute_fused_loss` function using `LigerFusedLinearCrossEntropyLoss`.
2. Train the model with both loss functions (`torch.nn.CrossEntropyLoss` and `LigerFusedLinearCrossEntropyLoss`) for comparison.
3. With `torch.nn.CrossEntropyLoss`, the loss continues to decrease as expected.
4. With `LigerFusedLinearCrossEntropyLoss`, the loss starts to oscillate and then diverges when it reaches ~8.
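The two implementations referenced under "Code to Reproduce" are sketched below, inferred from the fixed `compute_fused_loss` posted in the comments; they are illustrative reconstructions, not the reporter's exact code. `self.lm_head` and `self.config.hidden_size` follow the fixed snippet, the `.float()` placement in the baseline and the `liger_kernel.transformers` import are assumptions, and the fused variant simply omits the `.float()` casts that the fix adds.

```python
import torch.nn as nn
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss


def compute_loss(self, hidden_states, labels):
    """Baseline sketch: materialize logits, then standard cross entropy (stable)."""
    logits = self.lm_head(hidden_states)
    shift_logits = logits[..., :-1, :].float().contiguous()  # assumed fp32 upcast of logits
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = nn.CrossEntropyLoss(reduction="mean")
    return loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )


def compute_fused_loss(self, hidden_states, labels):
    """Fused variant as originally reported (unstable): identical to the fixed
    version in the comments except for the missing .float() casts, so the fused
    kernel sees bf16 hidden states and a bf16 lm_head weight."""
    shift_hidden_states = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
    shift_labels = shift_labels.view(-1)
    lce = LigerFusedLinearCrossEntropyLoss(reduction="mean")
    return lce(self.lm_head.weight, shift_hidden_states, shift_labels)
```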
Environment
0.3.1
8 * A100 GPU
12.4
2.5.1+cu124
4.46.3
torch.bfloat16
Zero Stage 1