I am trying to fine-tune a translation model using PPO with trl. The model is an encoder-decoder MarianMTModel. As the reward I use the COMET-QE score (a popular MT metric). The idea is very simple:
1. Generate N (5) translations for each source sentence in a batch using beam search or sampling.
2. Score them with COMET (a value between 0 and 1).
3. Run PPO.
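The three steps can be sketched in plain Python as follows. `generate` and `score` are hypothetical callables standing in for MarianMT generation and the COMET-QE model (in practice COMET scores whole batches at once, and trl expects token tensors rather than strings). The same scored candidates also yield the best-of-N targets used for the SFT baseline mentioned below, which makes it easy to compare both approaches on identical data.

```python
from typing import Callable, Dict, List, Tuple

# (source, hypothesis, reward) triple consumed by the RL step
Triple = Tuple[str, str, float]


def build_batch(
    sources: List[str],
    generate: Callable[[str, int], List[str]],  # hypothetical: n candidate translations per source
    score: Callable[[str, str], float],  # hypothetical: QE score (source, hypothesis) -> [0, 1]
    n: int = 5,
) -> Tuple[List[Triple], Dict[str, str]]:
    """Score n candidates per source; return PPO triples and best-of-n SFT targets."""
    triples: List[Triple] = []
    best_of_n: Dict[str, str] = {}
    for src in sources:
        scored = [(hyp, score(src, hyp)) for hyp in generate(src, n)]
        # every candidate becomes one (query, response, reward) sample for PPO
        triples.extend((src, hyp, reward) for hyp, reward in scored)
        # the SFT baseline keeps only the highest-scoring candidate
        best_of_n[src] = max(scored, key=lambda pair: pair[1])[0]
    return triples, best_of_n


# Toy stand-ins so the sketch runs without MarianMT or COMET.
def toy_generate(src: str, n: int) -> List[str]:
    return [f"{src}-{i}" for i in range(n)]


def toy_score(src: str, hyp: str) -> float:
    return int(hyp.rsplit("-", 1)[1]) / 10


triples, best_of_n = build_batch(["a", "b"], toy_generate, toy_score, n=5)
```

With two sources and N=5 this produces 10 PPO samples and one SFT target per source.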
After some tweaking of the training parameters I was able to get training runs to work (the crucial change was setting "kl_penalty" to "full") and got some increase in translation quality as measured by other metrics. The problem, however, is that the gains are smaller than when I simply use SFT on the best-scored outputs.
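For context on the "kl_penalty": "full" setting: as far as I understand trl's legacy `PPOTrainer`, the default "kl" option penalizes each generated token by the single-sample estimate log pi(y) - log pi_ref(y), while "full" computes the exact KL over the whole vocabulary at each position (via `torch.nn.functional.kl_div(..., log_target=True)`). The pure-Python sketch contrasts the two at one position with a made-up 3-token vocabulary. The single-sample value can be negative, and with beam search the tokens are not even sampled from the policy, so it is not an unbiased estimate; the full KL is always non-negative. That may be part of why "full" was the setting that made training stable.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    # numerically stable log-softmax over a toy vocabulary
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sampled_kl(logp: List[float], ref_logp: List[float], token: int) -> float:
    """Single-token estimate log pi(y) - log pi_ref(y) (the per-token "kl" style penalty)."""
    return logp[token] - ref_logp[token]


def full_kl(logp: List[float], ref_logp: List[float]) -> float:
    """Exact KL(pi || pi_ref) summed over the whole vocabulary at this position."""
    return sum(math.exp(lp) * (lp - rlp) for lp, rlp in zip(logp, ref_logp))


policy = log_softmax([2.0, 1.0, 0.1])     # made-up policy logits
reference = log_softmax([1.0, 2.0, 0.1])  # made-up reference logits

# Token 1 is less likely under the policy than under the reference:
# the single-sample estimate is about -1.0, while the full KL is about 0.417.
estimate = sampled_kl(policy, reference, token=1)
exact = full_kl(policy, reference)
```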
I attach the trl plots from one of the training runs that yielded decent results; it looks like the loss blows up after about 80 steps. There is also a plot of wmt22-comet-da, the metric I use for early stopping on a dev set. I experimented with different learning rates, a higher initial KL coefficient and KL target, and different numbers of PPO epochs, but I still cannot get my model to beat SFT. Do you have any idea what could be done better?
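Regarding the initial and target KL: if I read trl's legacy implementation correctly, the KL coefficient is adapted by a proportional controller whose relative error is clipped to +/-20% and scaled by `n_steps / horizon`. The sketch below reimplements that rule in plain Python under that assumption; the numbers (initial coefficient 0.2, target 6, horizon 10000, `n_steps` equal to the batch size of 256) are illustrative, not taken from my runs. It suggests that even if the KL stays at twice the target, the coefficient can grow by only about 1.5x over 80 steps, so the controller may react too slowly to stop the drift visible in the plots unless the horizon is reduced.

```python
from dataclasses import dataclass


@dataclass
class AdaptiveKLController:
    """Proportional controller for the KL coefficient (assumed to mirror trl's adaptive rule)."""

    value: float    # current KL coefficient (starts at the initial coefficient)
    target: float   # desired KL between policy and reference
    horizon: float  # larger horizon -> slower adaptation

    def update(self, current_kl: float, n_steps: int) -> None:
        # relative KL error, clipped to +/-20% so one noisy batch cannot swing the coefficient
        error = min(max(current_kl / self.target - 1.0, -0.2), 0.2)
        self.value *= 1.0 + error * n_steps / self.horizon


# Illustrative: KL stuck at twice the target for 80 steps with batch_size=256.
ctl = AdaptiveKLController(value=0.2, target=6.0, horizon=10_000)
for _ in range(80):
    ctl.update(current_kl=12.0, n_steps=256)
growth = ctl.value / 0.2  # roughly 1.50
```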
The parameters that have worked best for me so far are (the learning rate is the target value after a 100-step warmup):
{
"learning_rate": 5e-6,
"gamma": 0.99,
"ppo_epochs": 4,
"is_encoder_decoder": true,
"batch_size": 256,
"mini_batch_size": 128,
"kl_penalty": "full"
}
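To see what "gamma": 0.99 does with a sentence-level reward, here is a plain-Python sketch of how, as far as I understand trl's legacy `PPOTrainer`, a scalar COMET score becomes per-token rewards (a KL penalty on every token, with the score added at the last token) and then GAE advantages. `lam=0.95` is an assumption, since it is not in my config, and trl additionally whitens the advantages, which is omitted here. With an untrained value head (values of zero), a token 20 positions before the end receives only about (0.99 * 0.95)^20, roughly 0.29, of the terminal score, so early tokens in long sentences get a weak signal; gamma=1.0 may be worth trying.

```python
from typing import List, Tuple


def per_token_rewards(score: float, kl: List[float], kl_coef: float) -> List[float]:
    """KL penalty on every generated token; the sequence-level score lands on the last token."""
    rewards = [-kl_coef * k for k in kl]
    rewards[-1] += score
    return rewards


def gae(
    rewards: List[float], values: List[float], gamma: float = 0.99, lam: float = 0.95
) -> Tuple[List[float], List[float]]:
    """Generalized advantage estimation over one response; returns (advantages, returns)."""
    advantages = [0.0] * len(rewards)
    last = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last = delta + gamma * lam * last
        advantages[t] = last
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


# Three-token response with COMET score 1.0, zero KL, and an untrained value head.
rewards = per_token_rewards(score=1.0, kl=[0.0, 0.0, 0.0], kl_coef=0.2)
advantages, returns = gae(rewards, values=[0.0, 0.0, 0.0], gamma=0.99, lam=0.95)
```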