Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integration of fixes to linting, wandb api usage, readme #207

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c973442
Add conda instructions to README.md
naveenarun Jan 26, 2024
c39070a
Allow training notebook to run without logger or wandb config
naveenarun Jan 26, 2024
4d56e31
lint and add unit test for training without wandb
naveenarun Jan 27, 2024
08f60a1
Ran black v24.1.0 on maze_transformers/ and tests/
naveenarun Jan 27, 2024
abd6ecf
cleaned up gitignore
mivanit Jan 28, 2024
d9a849c
update poetry.lock
mivanit Jan 28, 2024
2a42383
Merge pull request #206 from naveenarun/naveenarun-205-black-v24
mivanit Jan 28, 2024
70d494c
Merge pull request #204 from naveenarun/naveenarun-train-nowandb
mivanit Jan 28, 2024
b799a92
Merge pull request #203 from naveenarun/naveenarun-readme-conda
mivanit Jan 28, 2024
8d3916a
old cosmetic change to mazeplot attention rendering
mivanit Jan 28, 2024
4f82556
Merge branch 'misc-code-cleanup' of https://github.com/understanding-…
mivanit Jan 28, 2024
4ef0432
bump black to ^24.1 per #205
mivanit Jan 28, 2024
bbf4d83
reverting config to test in train_model notebook
mivanit Jan 28, 2024
f877656
revert black version for easier diffs, temporary
mivanit Jan 28, 2024
4c3f56d
running black on older version
mivanit Jan 28, 2024
9f5aa1d
Revert "Ran black v24.1.0 on maze_transformers/ and tests/"
mivanit Jan 28, 2024
8ee589d
reverting logger changes to training.py
mivanit Jan 29, 2024
bdd9034
wip, junk, needs review!
mivanit Feb 20, 2024
a0cfe7b
merge updates from main
mivanit Jul 26, 2024
a4088e6
attempting to get things running
mivanit Jul 26, 2024
7a635cc
run format
mivanit Jul 26, 2024
43386d1
fixes to StubLogger, integration tests pass
mivanit Jul 26, 2024
96d3e4e
skipping tests but whatever
mivanit Jul 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# misc
wandb
.vscode/

# generated data/plots
data/**
notebooks/data/**
models/
wandb
tests/_temp/**
tests/**/_temp/**
notebooks/data/**
notebooks/figures/**
notebooks/plots/**
notebooks/figures/**

# testing temp files
tests/_temp/**
tests/**/_temp/**

# coverage
.coverage
htmlcov/

Expand Down
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ Most of the functionality is demonstrated in the ipython notebooks in the `noteb
* Restart VSCode
* In VSCode, select the python interpreter located in `maze-transformer/.venv/bin` as your juptyer kernel

## Instructions for Conda users

* Create a new Conda environment: `conda create -n mazetransformer python=3.10 poetry`
* Activate the environment: `conda activate mazetransformer`
* Update poetry and install dev dependencies
```
poetry self update
poetry config virtualenvs.in-project true
poetry install --with dev
```
* Run unit, integration, and notebook tests
```
make test
```

## Testing & Static analysis

Expand Down
2 changes: 1 addition & 1 deletion maze_transformer/mechinterp/logit_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def logit_diff_residual_stream(
)
# get embedding of answer tokens
answer_residual_directions = vocab_residual_directions[tokens_correct]
# get the directional difference between logits and corrent and logits on {all other tokens, comparison tokens}
# get the directional difference between logits and correct and logits on {all other tokens, comparison tokens}
logit_diff_directions: Float[torch.Tensor, "samples d_model"]
if tokens_compare_to is None:
logit_diff_directions = (
Expand Down
2 changes: 1 addition & 1 deletion maze_transformer/mechinterp/plot_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def mazeplot_attention(
mazeplot.cbar_ax.set_position([pos.x0, new_y0, pos.width, new_height])
# add a title to the colorbar, vertically and to the side
mazeplot.cbar_ax.text(
5.0,
6.0,
0.5,
"Attention",
rotation=90,
Expand Down
13 changes: 8 additions & 5 deletions maze_transformer/test_helpers/stub_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,29 @@ def __init__(self):
def _log(self, *logs):
self.logs.append(logs)

def __call__(self, *args, **kwargs) -> None:
self._log("StubLogger.__call__ called", args, kwargs)

@classmethod
def create(cls, *args, **kwargs) -> "StubLogger":
logger = StubLogger()
logger._log("StubLogger created", args, kwargs)
return logger

def upload_model(self, *args, **kwargs) -> None:
self._log("Model uploaded.", args, kwargs)
self._log("StubLogger.upload_model called", args, kwargs)

def upload_dataset(self, *args, **kwargs) -> None:
self._log("Dataset uploaded.", args, kwargs)
self._log("StubLogger.upload_dataset called", args, kwargs)

def log_metric(self, *args, **kwargs) -> None:
self._log("Metric logged.", args, kwargs)
self._log("StubLogger.log_metric called", args, kwargs)

def log_metric_hist(self, *args, **kwargs) -> None:
self._log("Metric (Statcounter) logged.", args, kwargs)
self._log("StubLogger.log_metric_hist called", args, kwargs)

def summary(self, *args, **kwargs) -> None:
self._log("Summary logged.", args, kwargs)
self._log("StubLogger.summary called", args, kwargs)

def progress(self, message: str) -> None:
msg: str = f"[INFO] - {message}"
Expand Down
78 changes: 48 additions & 30 deletions maze_transformer/training/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from muutils.json_serialize import SerializableDataclass, serializable_dataclass
from muutils.mlutils import get_device
from muutils.mlutils import get_device, pprint_summary
from torch.utils.data import DataLoader

from maze_transformer.training.config import (
Expand Down Expand Up @@ -37,7 +37,7 @@ def __str__(self):

def train_model(
base_path: str | Path,
wandb_project: Union[WandbProject, str],
wandb_project: Union[WandbProject, str] | None,
cfg: ConfigHolder | None = None,
cfg_file: str | Path | None = None,
cfg_names: typing.Sequence[str] | None = None,
Expand All @@ -60,6 +60,8 @@ def train_model(
- model config names: {model_cfg_names}
- train config names: {train_cfg_names}
"""
USES_LOGGER: bool = wandb_project is not None

if help:
print(train_model.__doc__)
return
Expand All @@ -85,26 +87,43 @@ def train_model(
(output_path / TRAIN_SAVE_FILES.checkpoints).mkdir(parents=True)

# set up logger
logger: WandbLogger = WandbLogger.create(
config=cfg.serialize(),
project=wandb_project,
job_type=WandbJobType.TRAIN_MODEL,
logger_cfg_dict = dict(
logger_cfg={
"output_dir": output_path.as_posix(),
"cfg.name": cfg.name,
"data_cfg.name": cfg.dataset_cfg.name,
"train_cfg.name": cfg.train_cfg.name,
"model_cfg.name": cfg.model_cfg.name,
"cfg_summary": cfg.summary(),
"cfg": cfg.serialize(),
},
)
logger.progress("Initialized logger")
logger.summary(
dict(
logger_cfg={
"output_dir": output_path.as_posix(),
"cfg.name": cfg.name,
"data_cfg.name": cfg.dataset_cfg.name,
"train_cfg.name": cfg.train_cfg.name,
"model_cfg.name": cfg.model_cfg.name,
"cfg_summary": cfg.summary(),
"cfg": cfg.serialize(),
},

# Set up logger if wanb project is specified
if USES_LOGGER:
logger: WandbLogger = WandbLogger.create(
config=cfg.serialize(),
project=wandb_project,
job_type=WandbJobType.TRAIN_MODEL,
)
)
logger.progress("Summary logged, getting dataset")
logger.progress("Initialized logger")
else:
logger = None

def log(msg: str | dict, log_type: str = "progress", **kwargs):
# Convenience function to let training routine work whether or not
# logger exists
if logger:
log_fn = getattr(logger, log_type)
log_fn(msg, **kwargs)
else:
if type(msg) == dict:
pprint_summary(msg)
else:
print(msg)

log(logger_cfg_dict, log_type="summary")
log("Summary logged, getting dataset")

# load dataset
if dataset is None:
Expand All @@ -116,10 +135,10 @@ def train_model(
)
else:
if dataset.cfg == cfg.dataset_cfg:
logger.progress(f"passed dataset has matching config, using that")
log(f"passed dataset has matching config, using that")
else:
if allow_dataset_override:
logger.progress(
log(
f"passed dataset has different config than cfg.dataset_cfg, but allow_dataset_override is True, so using passed dataset"
)
else:
Expand All @@ -146,7 +165,8 @@ def train_model(
f"{datasets_cfg_diff = }",
)

logger.progress(f"finished getting training dataset with {len(dataset)} samples")
log(f"finished getting training dataset with {len(dataset)} samples")

# validation dataset, if applicable
val_dataset: MazeDataset | None = None
if cfg.train_cfg.validation_dataset_cfg is not None:
Expand All @@ -168,7 +188,7 @@ def train_model(
dataset.mazes = dataset.mazes[: split_dataset_sizes[0]]
dataset.update_self_config()
val_dataset.update_self_config()
logger.progress(
log(
f"got validation dataset by splitting training dataset into {len(dataset)} train and {len(val_dataset)} validation samples"
)
elif isinstance(cfg.train_cfg.validation_dataset_cfg, MazeDatasetConfig):
Expand All @@ -178,18 +198,16 @@ def train_model(
local_base_path=base_path,
verbose=dataset_verbose,
)
logger.progress(
f"got custom validation dataset with {len(val_dataset)} samples"
)
log(f"got custom validation dataset with {len(val_dataset)} samples")

# get dataloader and then train
dataloader: DataLoader = get_dataloader(dataset, cfg, logger)
dataloader: DataLoader = get_dataloader(dataset, cfg, log)

logger.progress("finished dataloader, passing to train()")
log("finished dataloader, passing to train()")
trained_model: ZanjHookedTransformer = train(
cfg=cfg,
dataloader=dataloader,
logger=logger,
logger=log,
output_dir=output_path,
device=device,
val_dataset=val_dataset,
Expand Down
58 changes: 30 additions & 28 deletions maze_transformer/training/training.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from functools import partial
from pathlib import Path
from typing import Callable

import torch
from jaxtyping import Float
Expand All @@ -16,20 +17,19 @@
from maze_transformer.tokenizer import HuggingMazeTokenizer
from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer
from maze_transformer.training.train_save_files import TRAIN_SAVE_FILES
from maze_transformer.training.wandb_logger import WandbLogger


def collate_batch(batch: list[SolvedMaze], maze_tokenizer: MazeTokenizer) -> list[str]:
return [" ".join(maze.as_tokens(maze_tokenizer)) for maze in batch]


def get_dataloader(
dataset: MazeDataset, cfg: ConfigHolder, logger: WandbLogger
dataset: MazeDataset, cfg: ConfigHolder, logger: Callable
) -> DataLoader:
if len(dataset) == 0:
raise ValueError(f"Dataset is empty: {len(dataset) = }")
logger.progress(f"Loaded {len(dataset)} sequences")
logger.progress("Creating dataloader")
logger(f"Loaded {len(dataset)} sequences")
logger("Creating dataloader")
try:
dataloader: DataLoader = DataLoader(
dataset,
Expand All @@ -52,7 +52,7 @@ def get_dataloader(
def train(
cfg: ConfigHolder,
dataloader: DataLoader,
logger: WandbLogger,
logger: Callable,
output_dir: Path,
device: torch.device,
val_dataset: MazeDataset | None = None,
Expand All @@ -66,24 +66,24 @@ def train(

# init model & optimizer
if model is None:
logger.progress(f"Initializing model")
logger(f"Initializing model")
model: ZanjHookedTransformer = cfg.create_model_zanj()
model.to(device)
else:
logger.progress("Using existing model")
logger("Using existing model")

logger.summary({"device": str(device), "model.device": model.cfg.device})
logger({"device": str(device), "model.device": model.cfg.device})

logger.progress("Initializing optimizer")
logger("Initializing optimizer")
optimizer: torch.optim.Optimizer = cfg.train_cfg.optimizer(
model.parameters(),
**cfg.train_cfg.optimizer_kwargs,
)
logger.summary(dict(model_n_params=model.cfg.n_params))
logger(dict(model_n_params=model.cfg.n_params))

# add wandb run url to model
model.training_records = {
"wandb_url": logger.url,
"wandb_url": getattr(logger, "url", None),
}

# figure out whether to run evals, and validation dataset
Expand Down Expand Up @@ -116,10 +116,8 @@ def train(
key: value if not key.startswith("eval") else float("inf")
for key, value in intervals.items()
}
logger.summary(
{"n_batches": n_batches, "n_samples": n_samples, "intervals": intervals}
)
logger.progress(
logger({"n_batches": n_batches, "n_samples": n_samples, "intervals": intervals})
logger(
f"will train for {n_batches} batches, {evals_enabled=}, with intervals: {intervals}"
)

Expand All @@ -128,7 +126,7 @@ def train(
# start up training
# ==============================
model.train()
logger.progress("Starting training")
logger("Starting training")

for iteration, batch in enumerate(dataloader):
# forward pass
Expand All @@ -153,7 +151,7 @@ def train(
if evals_enabled:
for interval_key, evals_dict in PathEvals.PATH_EVALS_MAP.items():
if iteration % intervals[interval_key] == 0:
logger.progress(f"Running evals: {interval_key}")
logger(f"Running evals: {interval_key}")
scores: dict[str, StatCounter] = evaluate_model(
model=model,
dataset=val_dataset,
Expand All @@ -163,12 +161,10 @@ def train(
max_new_tokens=cfg.train_cfg.evals_max_new_tokens,
)
metrics.update(scores)
logger.log_metric_hist(metrics)
logger(metrics)

if iteration % intervals["print_loss"] == 0:
logger.progress(
f"iteration {iteration}/{n_batches}: loss={loss.item():.3f}"
)
logger(f"iteration {iteration}/{n_batches}: loss={loss.item():.3f}")

del loss

Expand All @@ -180,19 +176,25 @@ def train(
/ TRAIN_SAVE_FILES.checkpoints
/ TRAIN_SAVE_FILES.model_checkpt_zanj(iteration)
)
logger.progress(f"Saving model checkpoint to {model_save_path.as_posix()}")
logger(f"Saving model checkpoint to {model_save_path.as_posix()}")
zanj.save(model, model_save_path)
logger.upload_model(
model_save_path, aliases=["latest", f"iter-{iteration}"]
)
try:
logger.upload_model(
model_save_path, aliases=["latest", f"iter-{iteration}"]
)
except Exception as e:
logger(f"Failed to upload model: {e}")

# save the final model
# ==============================
final_model_path: Path = output_dir / TRAIN_SAVE_FILES.model_final_zanj
logger.progress(f"Saving final model to {final_model_path.as_posix()}")
logger(f"Saving final model to {final_model_path.as_posix()}")
zanj.save(model, final_model_path)
logger.upload_model(final_model_path, aliases=["latest", "final"])
try:
logger.upload_model(final_model_path, aliases=["latest", "final"])
except Exception as e:
logger(f"Failed to upload model: {e}")

logger.progress("Done training!")
logger("Done training!")

return model
Loading
Loading