# train.py (forked from TransformerLensOrg/TransformerLens)
from dataclasses import dataclass
from typing import Optional
import torch
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer


@dataclass
class HookedTransformerTrainConfig:
"""
Configuration class to store training hyperparameters for a training run of
    a HookedTransformer model.
Args:
num_epochs (int): Number of epochs to train for
batch_size (int): Size of batches to use for training
lr (float): Learning rate to use for training
seed (int): Random seed to use for training
momentum (float): Momentum to use for training
        max_grad_norm (float, *optional*): Maximum gradient norm to use for gradient clipping
weight_decay (float, *optional*): Weight decay to use for training
optimizer_name (str): The name of the optimizer to use
device (str, *optional*): Device to use for training
warmup_steps (int, *optional*): Number of warmup steps to use for training
save_every (int, *optional*): After how many batches should a checkpoint be saved
        save_dir (str, *optional*): Where to save checkpoints
wandb (bool): Whether to use Weights and Biases for logging
        wandb_project_name (str, *optional*): Name of the Weights and Biases project to use
print_every (int, *optional*): Print the loss every n steps
max_steps (int, *optional*): Terminate the epoch after this many steps. Used for debugging.
"""
num_epochs: int
batch_size: int
lr: float = 1e-3
seed: int = 0
momentum: float = 0.0
max_grad_norm: Optional[float] = None
weight_decay: Optional[float] = None
optimizer_name: str = "Adam"
device: Optional[str] = None
warmup_steps: int = 0
save_every: Optional[int] = None
save_dir: Optional[str] = None
wandb: bool = False
wandb_project_name: Optional[str] = None
print_every: Optional[int] = 50
max_steps: Optional[int] = None
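

# Illustrative example (not from the original file): a config for a quick
# debugging run that ends each epoch early after max_steps batches.
#
#     HookedTransformerTrainConfig(num_epochs=1, batch_size=8, lr=1e-4, max_steps=100)

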
def train(
model: HookedTransformer,
config: HookedTransformerTrainConfig,
dataset: Dataset,
) -> HookedTransformer:
"""
    Trains a HookedTransformer model on an autoregressive language modeling task.
Args:
model: The model to train
config: The training configuration
dataset: The dataset to train on - this function assumes the dataset is set up for autoregressive language modeling.
Returns:
The trained model
"""
torch.manual_seed(config.seed)
model.train()
if config.wandb:
if config.wandb_project_name is None:
config.wandb_project_name = "easy-transformer"
wandb.init(project=config.wandb_project_name, config=vars(config))
if config.device is None:
config.device = "cuda" if torch.cuda.is_available() else "cpu"
if config.optimizer_name in ["Adam", "AdamW"]:
# Weight decay in Adam is implemented badly, so use AdamW instead (see PyTorch AdamW docs)
if config.weight_decay is not None:
optimizer = optim.AdamW(
model.parameters(),
lr=config.lr,
weight_decay=config.weight_decay,
)
else:
optimizer = optim.Adam(
model.parameters(),
lr=config.lr,
)
elif config.optimizer_name == "SGD":
optimizer = optim.SGD(
model.parameters(),
lr=config.lr,
weight_decay=config.weight_decay
if config.weight_decay is not None
else 0.0,
momentum=config.momentum,
)
else:
raise ValueError(f"Optimizer {config.optimizer_name} not supported")
scheduler = None
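    # Optional linear warm-up: the learning rate ramps from 0 to config.lr over
    # the first `warmup_steps` optimizer steps and then stays at config.lr.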
if config.warmup_steps > 0:
scheduler = optim.lr_scheduler.LambdaLR(
optimizer,
lr_lambda=lambda step: min(1.0, step / config.warmup_steps),
)
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
model.to(config.device)
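    # Training loop: each batch is expected to provide a "tokens" tensor of token
    # ids; the model returns the autoregressive cross-entropy loss directly.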
for epoch in tqdm(range(1, config.num_epochs + 1)):
samples = 0
for step, batch in tqdm(enumerate(dataloader)):
tokens = batch["tokens"].to(config.device)
loss = model(tokens, return_type="loss")
loss.backward()
if config.max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
optimizer.step()
if config.warmup_steps > 0:
assert scheduler is not None
scheduler.step()
optimizer.zero_grad()
samples += tokens.shape[0]
if config.wandb:
wandb.log(
{"train_loss": loss.item(), "samples": samples, "epoch": epoch}
)
if config.print_every is not None and step % config.print_every == 0:
print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}")
if (
config.save_every is not None
and step % config.save_every == 0
and config.save_dir is not None
):
torch.save(model.state_dict(), f"{config.save_dir}/model_{step}.pt")
if config.max_steps is not None and step >= config.max_steps:
break
return model
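

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module). It
    # assumes a dataset whose items are dicts with a "tokens" tensor of token
    # ids, which is the format train() reads via batch["tokens"]. The model and
    # dataset sizes below are arbitrary toy values chosen so the run is fast.
    from transformer_lens import HookedTransformerConfig

    class RandomTokenDataset(Dataset):
        """Toy dataset of random token ids in the {"tokens": ...} format used above."""

        def __init__(self, num_samples: int = 64, seq_len: int = 32, d_vocab: int = 100):
            self.tokens = torch.randint(0, d_vocab, (num_samples, seq_len))

        def __len__(self) -> int:
            return self.tokens.shape[0]

        def __getitem__(self, idx: int) -> dict:
            return {"tokens": self.tokens[idx]}

    # A deliberately tiny model so the sketch runs quickly, even on CPU.
    model_cfg = HookedTransformerConfig(
        n_layers=1,
        d_model=64,
        n_ctx=32,
        d_head=16,
        n_heads=4,
        d_vocab=100,
        act_fn="relu",
    )
    train_cfg = HookedTransformerTrainConfig(
        num_epochs=1,
        batch_size=8,
        lr=1e-3,
        max_steps=2,  # stop each epoch after a couple of batches; smoke test only
    )
    train(HookedTransformer(model_cfg), train_cfg, RandomTokenDataset())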