Support offline logits for teacher model #441

Open
austin362667 opened this issue Dec 9, 2024 · 0 comments
austin362667 commented Dec 9, 2024

🚀 The feature, motivation and pitch

In knowledge distillation, it is more efficient to support teacher logits/logprobs that are pre-computed offline, rather than loading the teacher model and running its forward pass inside the kernel.

Some other thoughts on using logits or logprobs?

We scaled temperature here.

As @winglian mentioned here.

I'd actually like to see both a logit and logprob implementation since it's easy to get logprobs offline from vllm and that is a faster way to generate the dataset.

So rather than having to have the teacher model loaded during training, depending on the workload type, it can be faster and more compute efficient to pre-compute the logits/logprobs offline beforehand. However, vllm and sglang only provide the logprobs, and that's not easily back-calculated to logits.

@shivam15s raised a concern about temperature-scaled logprobs here:

curious if vllm/sglang support temperature scaled logprobs. This would be needed to enable https://github.com/huggingface/trl/blob/9c5388b69e0842f76edc46a2ff9d0b51e1db4337/trl/trainer/gkd_trainer.py#L174

Besides, @Tcc0403 suggested here that log-space is the right way to go. As I understand it, I agree with this idea given temperature=1.

Sorry for the misleading question and late response. Passing logprobs is totally fine; it's actually better, since staying in log-space avoids underflow issues. Torch's KLDivLoss also expects inputs in log-space, and the extra calculation from softmax to logsoftmax shouldn't be an issue anyway. So if most APIs expect logprobs as input, then I think it's the way to go.
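To make the underflow point concrete, here is a minimal sketch in plain Python (not torch, and not the kernel in this repo). The helper names `log_softmax` and `kl_from_logprobs` are made up for illustration, and the forward KL direction is an assumption.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def kl_from_logprobs(teacher_logprobs, student_logprobs):
    """KL(teacher || student) with both inputs already in log-space."""
    return sum(
        math.exp(t) * (t - s)
        for t, s in zip(teacher_logprobs, student_logprobs)
    )


# A very unlikely token: its probability underflows to 0.0 in float64,
# so log(prob) would be undefined, while the log-prob itself stays finite.
teacher_logprobs = log_softmax([0.0, -800.0])
teacher_probs = [math.exp(lp) for lp in teacher_logprobs]

student_logprobs = log_softmax([0.0, -700.0])
kl = kl_from_logprobs(teacher_logprobs, student_logprobs)
```

Here `teacher_probs[1]` is exactly `0.0`, so taking its log again would fail, whereas `teacher_logprobs[1]` is still `-800.0` and the KL term stays finite.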

In my opinion, it's good to support offline pre-computed values (e.g., logits) for the teacher model. However, I'm unsure how we should support log_probs/probs as args in `distillation_loss_fn`. Because of the normalization step, multiple input vectors can yield the same output probabilities, so softmax is not invertible in a strict sense. As a result, it's hard to apply temperature scaling to these post-softmax values.
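A small sketch of the non-invertibility, in plain Python with a hypothetical `log_softmax` helper: shifting every logit by a constant leaves the distribution unchanged, so the original logits cannot be recovered. One observation that may be relevant here: the same shift invariance means that full-vocabulary logprobs taken at temperature 1 differ from the logits only by a constant, so dividing them by the temperature and renormalizing should match scaling the logits directly. This assumes the complete logprob vector is stored; with only top-k logprobs the renormalization runs over a truncated support and is approximate at best.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Stable log-softmax of logits / temperature (illustrative helper)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - lse for x in scaled]


logits = [2.0, 0.5, -1.0]
shifted = [x + 3.0 for x in logits]

# Two different logit vectors, one distribution: softmax is not invertible.
same_distribution = all(
    abs(a - b) < 1e-12
    for a, b in zip(log_softmax(logits), log_softmax(shifted))
)

# Temperature scaling applied to logits vs. to full-vocab logprobs (T=1).
T = 2.0
from_logits = log_softmax(logits, T)
from_logprobs = log_softmax(log_softmax(logits), T)
scaling_matches = all(
    abs(a - b) < 1e-12 for a, b in zip(from_logits, from_logprobs)
)
```

Both `same_distribution` and `scaling_matches` come out true for this toy vector. Whether that is usable in practice depends on whether the offline dataset keeps the whole vocabulary, and on the temperature used when the logprobs were generated.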

Alternatives

No response

Additional context

No response

@austin362667 austin362667 self-assigned this Dec 9, 2024
Tcc0403 pushed a commit that referenced this issue Jan 1, 2025
## Summary

Addresses part of the issue raised in
#441

Moving the temperature scaling outside `distillation_loss_fn` is fine
as well: it keeps the `loss_fn` simpler, and the scaling can be handled
in the `forward` function beforehand. Thanks to @Tcc0403 for the advice.
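A rough sketch of that split, in plain Python rather than the actual kernel code. The signatures of `distillation_loss_fn` and `forward` below are hypothetical, and forward KL is only an assumed choice of loss:

```python
import math


def _log_softmax(logits):
    """Stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def distillation_loss_fn(student_logits, teacher_logits):
    """Hypothetical loss: KL(teacher || student), temperature-agnostic."""
    s = _log_softmax(student_logits)
    t = _log_softmax(teacher_logits)
    return sum(math.exp(tj) * (tj - sj) for sj, tj in zip(s, t))


def forward(student_logits, teacher_logits, temperature=1.0):
    """Hypothetical wrapper: scale both logits first, then call the loss."""
    s = [x / temperature for x in student_logits]
    t = [x / temperature for x in teacher_logits]
    return distillation_loss_fn(s, t)
```

With `temperature=1.0` the wrapper reduces to calling `distillation_loss_fn` directly, so the loss function itself never needs a temperature argument.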

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>