In knowledge distillation, it is more efficient to support logits/logprobs that were pre-computed offline with the teacher model beforehand, rather than loading the teacher and forwarding it inside the kernel.
I'd actually like to see both a logit and a logprob implementation, since it's easy to get logprobs offline from vLLM, and that is a faster way to generate the dataset.
So rather than having to keep the teacher model loaded during training, depending on the workload type, it can be faster and more compute-efficient to pre-compute the logits/logprobs offline beforehand. However, vLLM and SGLang only provide the logprobs, and those are not easily back-calculated to logits.
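For illustration, here is a minimal sketch of pre-computing and saving teacher log-probs before training. It uses Hugging Face `transformers` as a stand-in for whatever offline engine is actually used; the checkpoint name and output path are placeholders, not part of the Liger Kernel API.

```python
# Sketch only: pre-compute teacher log-probs offline so the teacher does not
# need to be loaded during student training. Checkpoint and path are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

teacher_name = "Qwen/Qwen2.5-0.5B"  # hypothetical teacher checkpoint
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
teacher = AutoModelForCausalLM.from_pretrained(teacher_name).eval()

text = "Knowledge distillation transfers the teacher's distribution to the student."
batch = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    logits = teacher(**batch).logits              # (1, T, V) raw teacher logits
    logprobs = torch.log_softmax(logits, dim=-1)  # log-space, numerically stable

# Persist whichever representation the training job expects.
torch.save({"logits": logits, "logprobs": logprobs}, "teacher_outputs.pt")
```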
Some other thoughts on using `logits` or `logprobs`: we scaled by `temperature` here, as @winglian mentioned here, while @shivam15s pointed out a concern regarding temperature-scaled `logprobs` here.
Besides, @Tcc0403 suggested that log-space is the right way to go here. As far as I understand, I agree with this idea, given `temperature=1`.
Sorry for the misleading question and the late response. Passing logprobs is totally fine; it's actually better, since staying in log-space avoids underflow issues. Torch's `KLDivLoss` also expects inputs in log-space, and the extra cost of computing log-softmax instead of softmax shouldn't be an issue anyway. So if most APIs expect logprobs as input, then I think that's the way to go.
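As a small illustration of the log-space path, here is a sketch of consuming pre-computed teacher logprobs with `torch.nn.KLDivLoss` entirely in log-space. The tensor names and shapes are made up for the example and are not Liger Kernel arguments.

```python
# Sketch only: KL divergence with both sides kept in log-space.
import torch
import torch.nn.functional as F

student_logits = torch.randn(4, 32000)                           # (B*T, V) student logits
student_logprobs = F.log_softmax(student_logits, dim=-1)
teacher_logprobs = F.log_softmax(torch.randn(4, 32000), dim=-1)  # loaded from disk in practice

# KLDivLoss expects the input in log-space; log_target=True lets the target
# stay in log-space too, avoiding an exp() and the associated underflow risk.
kldiv = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
loss = kldiv(student_logprobs, teacher_logprobs)
```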
In my opinion, it's good to support offline pre-computed values (e.g., logits) from the teacher model beforehand. However, I'm unsure how we should support `log_probs`/`probs` as args in `distillation_loss_fn`: since multiple input vectors can yield the same output probabilities due to the normalization step, `softmax` is not invertible in a strict sense, so the original logits cannot be recovered exactly, which makes it hard to apply `temperature` scaling to these post-`softmax` values.
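A quick illustration of the non-invertibility point: logits that differ by a per-row constant produce identical probabilities, so the original logits cannot be recovered from probs (or log-probs) alone.

```python
# Sketch only: softmax is shift-invariant, so it cannot be inverted exactly.
import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])
shifted = logits + 3.0  # same distribution, different logits

print(torch.allclose(torch.softmax(logits, dim=-1),
                     torch.softmax(shifted, dim=-1)))  # True
```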
## Summary
Addresses part of the issue raised in #441.

Moving the temperature scaling outside the `distillation_loss_fn` is fine as well: it keeps the `loss_fn` simpler, and the scaling can be handled in the `forward` function beforehand. Thanks to @Tcc0403 for the advice.
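A minimal sketch of that split, with hypothetical function names rather than the exact Liger Kernel signatures: temperature is applied in `forward`, so the `distillation_loss_fn` only receives already-scaled log-probs.

```python
# Sketch only: temperature handled in forward, loss_fn stays a pure log-space KL.
import torch
import torch.nn.functional as F

def distillation_loss_fn(student_logprobs, teacher_logprobs):
    # No temperature handling inside the loss.
    return F.kl_div(student_logprobs, teacher_logprobs,
                    reduction="batchmean", log_target=True)

def forward(student_logits, teacher_logits, temperature=2.0):
    # Temperature is applied to the raw logits here, before the loss_fn.
    student_logprobs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_logprobs = F.log_softmax(teacher_logits / temperature, dim=-1)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures.
    return (temperature ** 2) * distillation_loss_fn(student_logprobs, teacher_logprobs)
```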
## Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
---------
Signed-off-by: Austin Liu <[email protected]>