Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KTO Loss #475

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open

Add KTO Loss #475

wants to merge 25 commits into from

Conversation

hebiao064
Copy link
Collaborator

@hebiao064 hebiao064 commented Dec 13, 2024

Summary

Close KTO Item of the Roadmap: #371

Implements the Kahneman-Tversky Optimization (KTO) loss function.

KTO Loss Function

For a policy π compared to a reference policy π₀:

When y is chosen:

$L_{KTO} = 1 - \sigma(\beta \cdot (\log[\frac{\pi(x)}{\pi_0(x)}] - KL(\pi||\pi_0)_y))$

When y is rejected:

$L_{KTO} = 1 - \sigma(\beta \cdot (KL(\pi||\pi_0)_y - \log[\frac{\pi(x)}{\pi_0(x)}]))$

where:

  • σ is the sigmoid function
  • β is a temperature parameter
  • KL(π||π₀)_y is the KL divergence threshold for action y

Intuition

KTO loss is inspired by prospect theory from behavioral economics, which models how humans make decisions under uncertainty.

The loss function is asymmetric, treating gains and losses differently, similar to
human decision-making patterns.

Screenshot 2024-12-13 at 11 10 39 AM

Credit by: https://www.youtube.com/watch?v=nSrj1J6ODoM&t=422s

Benchmark Result

Special thanks to @shivam15s on the optimization PR: #491, otherwise my implementation won't achieve speed as list below

Memory:

image

Speed:
image

Notable learning on optimizing the speed:

  • [Culprit] Repeated calculation of KL when we split to N chunks
  • [Good to have] Remove the unnecessary variables calculation like aux_outputs

Key Changes

  • Implemented LigerFusedLinearKTOLoss class
  • Added LigerFusedLinearKTOFunction for the core KTO computation
  • Created comprehensive test suite in test_kto_loss.py
  • Added reference implementation (HFKTOLoss) based on Hugging Face's implementation

Reference

Testing Done

Test is passing now:
pytest test/chunked_loss/test_kto_loss.py

  • Parameterized tests covering various configurations:
    • Different batch sizes, sequence lengths, hidden dims, and vocab sizes
    • Multiple data types (bfloat16, float32)
    • Bias and reference bias variations
    • Different ignore indices and beta values
  • Correctness tests comparing against reference implementation
  • Gradient checking and backward pass verification
  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@hebiao064 hebiao064 marked this pull request as ready for review December 13, 2024 01:41
Copy link
Collaborator

@ByronHsu ByronHsu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Take a brief look, I am not very familiar with KTO math but why do we not have KL_log_probs but original HF has https://github.com/huggingface/trl/blob/cd7156fb34ddf9a8c04fcd640a4067933461d44e/trl/trainer/kto_trainer.py#L1121. We also need to be careful about scaling. Seems in original HF, kto_loss returns an unreduced version, but we probably need to reduce as mean. cc @shivam15s

@hebiao064
Copy link
Collaborator Author

Take a brief look, I am not very familiar with KTO math but why do we not have KL_log_probs but original HF has https://github.com/huggingface/trl/blob/cd7156fb34ddf9a8c04fcd640a4067933461d44e/trl/trainer/kto_trainer.py#L1121. We also need to be careful about scaling. Seems in original HF, kto_loss returns an unreduced version, but we probably need to reduce as mean. cc @shivam15s

About KL, I'll take a further look in trl about how to support that.

About reduce, HF did averaged it here: loss = losses.nanmean()

hebiao064 and others added 11 commits December 16, 2024 21:34
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
### KTO LOSS
#### Memory

![image](https://github.com/user-attachments/assets/bd8fe4f6-0c18-4cf3-a79a-fc8634dcb492)
#### Speed

![image](https://github.com/user-attachments/assets/256cf0c3-3943-4f46-b256-38a577323a03)

<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->

<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
@hebiao064 hebiao064 enabled auto-merge (squash) December 21, 2024 06:50
@hebiao064
Copy link
Collaborator Author

AMD Test failed due to no gpu available, not related to the PR: FAILED test/transformers/test_swiglu.py::test_correctness_functional[dtype1-10000.0-0.01-9-7-41] - RuntimeError: No HIP GPUs are available

Copy link
Collaborator

@kvignesh1420 kvignesh1420 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution. Did a first pass over the functionality and left some comments.

H: int,
V: int,
dtype: torch.dtype,
bias: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like bias is being used as a boolean, whereas self.KTO_loss requires it to be an optional tensor. Can we modify the param names to avoid confusion?

preference_labels_chunk=None,
ref_input_chunk=None,
):
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss) = fused_fwd_bwd(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it *chunk_grad_bias and not chunk_grad_bias like the other gradients?

Comment on lines +144 to +146
input_chunk = input_chunk
ref_input_chunk = ref_input_chunk if use_ref_model else None
target_chunk = target_chunk
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the reassignment can be avoided.

ref_input_chunk = ref_input_chunk if use_ref_model else None
target_chunk = target_chunk

# mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's mention that the dimension 1 corresponds to sequence length.

"""
Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
Args:
preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this class already has a staticmethod for preference_loss_fn. Why do we need an extra arg here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants