reinforcement_learn.py

import torch
from torch.distributions import Categorical
from torch.optim import AdamW
import wandb
from tokens import PAD, SEQ_LENGTH, START
from board_ops import (
    batch_check_winner,
    batch_next_valid_move_from_seq,
    batch_board_full,
    batch_seq_to_board,
    batch_detect_illegal_moves,
)
from setup import load_from_checkpoint, device, save_checkpoint

NUM_EPOCHS = 20000
BATCH_SIZE = 1024
LR = 1e-5
WANDB = True

if WANDB:
    run = wandb.init(
        project="ttt",
        config={
            "epochs": NUM_EPOCHS,
            "batch_size": BATCH_SIZE,
            "learning_rate": LR,
        },
    )

model = load_from_checkpoint()
model.eval()
model.to(device)
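
# Reward scheme for a batch of completed game sequences: a win for player 1
# (the model) scores 1, a loss scores 0, a draw scores 0.5, and games with no
# result (no winner and the board not full, e.g. cut short by an invalid move)
# score -0.5.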
def compute_rewards(sequences):
    illegal = batch_detect_illegal_moves(sequences[:, 1:])
    valid_sequences = ~illegal

    rewards = torch.full(
        (sequences.shape[0],), 0.5, device=sequences.device, dtype=torch.float
    )
    rewards[illegal] = 0

    if torch.any(valid_sequences):
        boards = batch_seq_to_board(sequences[valid_sequences, 1:])
        winners = torch.zeros(
            sequences.shape[0], dtype=torch.long, device=sequences.device
        )
        full = torch.zeros(
            sequences.shape[0], dtype=torch.bool, device=sequences.device
        )
        winners[valid_sequences] = batch_check_winner(boards)
        full[valid_sequences] = batch_board_full(boards)
        rewards[winners == 1] = 1
        rewards[winners == -1] = 0
        # incomplete games (due to an invalid move)
        incomplete = (winners == 0) & ~full
        rewards[incomplete] = -0.5

    return rewards


optimizer = AdamW(model.parameters(), lr=LR)
input_ids = torch.tensor([START], dtype=torch.long, device=device)[None, ...]
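
# Each epoch plays out BATCH_SIZE games in parallel from the START token:
# the model (player 1) samples its moves, an opponent (player 2) responds
# with valid moves via batch_next_valid_move_from_seq, and finished or
# invalid games are frozen through active_mask. The completed sequences are
# then scored with compute_rewards and used for one REINFORCE-style
# policy-gradient update.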
for epoch in range(NUM_EPOCHS):
    model.train()

    output_ids = torch.full(
        (BATCH_SIZE, SEQ_LENGTH), PAD, dtype=torch.long, device=device
    )
    output_ids[:, : input_ids.shape[1]] = input_ids

    log_probs_accumulated = torch.zeros((BATCH_SIZE, 1), device=device)
    entropy_accumulated = torch.zeros((BATCH_SIZE, 1), device=device)

    # keep track of which games (within the batch) have completed
    # once a game is complete, its remaining positions are left as PAD
    # and we must stop accumulating log-probs and entropy for that game
    active_mask = torch.ones(BATCH_SIZE, dtype=torch.bool, device=device)
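
    # Moves alternate within the sequence: after the START token at position 0,
    # the model (player 1) samples the moves at odd positions and the scripted
    # opponent (player 2) fills the even positions.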
    for i in range(input_ids.shape[1], SEQ_LENGTH):
        # player 2 - RANDOM MOVE
        if i % 2 == 0:
            next_moves = batch_next_valid_move_from_seq(output_ids[active_mask, 1:])
            output_ids[active_mask, i] = next_moves
            assert torch.all(output_ids[active_mask, i] != PAD)
            illegal = batch_detect_illegal_moves(output_ids[active_mask, 1:])
            assert not torch.any(illegal)
        # player 1 - AI MOVE
        else:
            prompt = output_ids[:, :i].clone()
            logits = model(prompt)[0]

            # Only consider logits of active sequences
            logits_active = logits[active_mask]
            if logits_active.shape[0] == 0:
                # All sequences are finished
                break

            probs = torch.nn.functional.softmax(logits_active, dim=-1)
            entropy_current = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)
            entropy_accumulated[active_mask] += entropy_current

            dist = Categorical(probs)
            next_tokens = dist.sample()
            log_probs_accumulated[active_mask] += dist.log_prob(next_tokens)
            output_ids[active_mask, i] = next_tokens.T

            invalid_move = output_ids[active_mask, i].squeeze(-1) == PAD
            illegal = invalid_move | batch_detect_illegal_moves(
                output_ids[active_mask, 1:]
            )
            active_indices = torch.nonzero(active_mask).squeeze(-1)
            active_mask = active_mask.clone()
            active_mask[active_indices] &= ~illegal

        if torch.any(active_mask):
            boards = batch_seq_to_board(output_ids[active_mask, 1:])
            winners = batch_check_winner(boards)
            full = batch_board_full(boards)
            active_indices = torch.nonzero(active_mask).squeeze(-1)
            active_mask = active_mask.clone()
            active_mask[active_indices] &= ~full
            active_mask[active_indices] &= winners == 0
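
    # REINFORCE-style update: weight each game's accumulated log-probability
    # by its scalar reward (no baseline), with an optional entropy bonus
    # controlled by alpha.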
    normalized_log_probs = log_probs_accumulated / SEQ_LENGTH

    # Compute rewards for the entire batch
    with torch.no_grad():
        rewards = compute_rewards(output_ids)

    # Compute loss for the entire batch
    neg_advantage = (-normalized_log_probs * rewards.unsqueeze(-1)).mean()

    alpha = 0.0  # hyperparameter to be tuned
    average_entropy = entropy_accumulated.mean()
    loss = neg_advantage - alpha * average_entropy

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if WANDB:
        wandb.log(
            {
                "loss": loss,
                "reward": rewards.mean(),
            }
        )

    print(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}: Loss: {loss.item()} Rewards: {rewards.mean()} NegAdv: {neg_advantage}"
    )

    save_checkpoint(model)