loss.py · 65 lines (50 loc) · 1.77 KB
from typing import Optional

import torch
import torch.nn as nn

from replay_buffer import Experience


def approx_kl_divergence(
    log_probs: torch.Tensor,
    log_probs_ref: torch.Tensor,
    action_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Monte-Carlo approximation of the KL divergence between the current policy
    and the reference policy using the k3 estimator,
    see: http://joschu.net/blog/kl-approx.html
    """
    # k3 = exp(log_ratio) - log_ratio - 1, where log_ratio = log(p_ref / p).
    log_ratio = log_probs_ref.float() - log_probs.float()
    if action_mask is not None:
        # Zero out non-action positions; they then contribute exp(0) - 0 - 1 == 0.
        log_ratio = log_ratio * action_mask
    return log_ratio.exp() - log_ratio - 1
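
# Quick sanity check of the k3 estimator (illustrative note, not part of the original file):
# for a single token with log_ratio = log_probs_ref - log_probs = 0.1, the estimate is
# exp(0.1) - 0.1 - 1 ≈ 0.0052; the estimator is always non-negative and is exactly 0 when
# the two policies agree (log_ratio == 0), which is also what masked positions contribute.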

def masked_mean(
    tensor: torch.Tensor,
    mask: Optional[torch.Tensor],
    dim: Optional[int] = None,
) -> torch.Tensor:
    """Mean over `dim`, counting only positions where `mask` is set."""
    if mask is None:
        return tensor.mean(dim=dim)
    return (tensor * mask).sum(dim=dim) / mask.sum(dim=dim)

class GRPOLoss(nn.Module):
    """GRPO actor loss: PPO-style clipped surrogate objective plus a KL penalty."""

    def __init__(self, clip_eps: float, kl_weight: float) -> None:
        super().__init__()
        self.clip_eps = clip_eps
        self.kl_weight = kl_weight

    def forward(
        self,
        log_probs: torch.Tensor,
        experience: Experience,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        old_log_probs = experience.action_log_probs
        log_probs_ref = experience.log_probs_ref
        action_mask = experience.action_mask
        advantages = experience.advantages

        kl = approx_kl_divergence(
            log_probs=log_probs,
            log_probs_ref=log_probs_ref,
            action_mask=action_mask,
        )

        # Importance sampling ratio between the current and the old (sampling) policy.
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        # Clipped surrogate objective (negated for gradient descent) plus KL penalty.
        loss = -torch.min(surr1, surr2) + self.kl_weight * kl

        # Average per-token loss over action tokens, then over the batch.
        loss = masked_mean(loss, action_mask, dim=-1).mean()
        return loss, kl.mean()
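

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file. It assumes per-token
    # log-probabilities of shape (batch, seq_len); `SimpleNamespace` stands in for
    # `Experience` because the full constructor is defined in replay_buffer and not
    # shown here, and the hyperparameter values below are illustrative only.
    from types import SimpleNamespace

    torch.manual_seed(0)
    batch, seq_len = 2, 5
    log_probs = -torch.rand(batch, seq_len)            # current policy log-probs
    exp = SimpleNamespace(
        action_log_probs=-torch.rand(batch, seq_len),  # log-probs at sampling time
        log_probs_ref=-torch.rand(batch, seq_len),     # reference policy log-probs
        action_mask=torch.ones(batch, seq_len),        # 1 for generated (action) tokens
        advantages=torch.randn(batch, 1),              # e.g. one advantage per sequence
    )

    loss_fn = GRPOLoss(clip_eps=0.2, kl_weight=0.01)
    loss, kl = loss_fn(log_probs, exp)
    print(f"loss={loss.item():.4f} kl={kl.item():.4f}")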