Fix dataset split, pt to zanj model conversion code (#190)
* fixed train/val data split
* re-run notebooks
* wandb to zanj notebook
* fixed zero-size dataset issue (only appeared somewhere in this commit history, not sure where)
mivanit authored Jun 16, 2023
1 parent 5aff580 commit 4d1ae4e
Showing 10 changed files with 464 additions and 147 deletions.
4 changes: 3 additions & 1 deletion maze_transformer/test_helpers/stub_logger.py
@@ -32,7 +32,9 @@ def summary(self, *args, **kwargs) -> None:
self._log("Summary logged.", args, kwargs)

def progress(self, message: str) -> None:
self._log(f"[INFO] - {message}")
msg: str = f"[INFO] - {message}"
print(msg)
self._log(msg)

@property
def url(self) -> str:
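For context, a minimal sketch of the pattern this change moves to, using a toy class rather than the project's StubLogger (class and attribute names here are illustrative only): progress messages are echoed to stdout so they are visible while a test runs, and still recorded so tests can assert on them afterwards.

# Illustrative toy, not maze_transformer code.
class TinyStubLogger:
    def __init__(self) -> None:
        self.logs: list[str] = []

    def progress(self, message: str) -> None:
        msg: str = f"[INFO] - {message}"
        print(msg)              # echo to stdout for visibility during long runs
        self.logs.append(msg)   # keep the record for test assertions

logger = TinyStubLogger()
logger.progress("Loaded 100 sequences")
assert logger.logs == ["[INFO] - Loaded 100 sequences"]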
2 changes: 1 addition & 1 deletion maze_transformer/training/config.py
@@ -318,7 +318,7 @@ def summary(self) -> dict:
             eval_fast=4,
             eval_slow=2,
         ),
-        validation_dataset_cfg=10,
+        validation_dataset_cfg=1,
     ),
     TrainConfig(
         name="tiny-v1",
15 changes: 11 additions & 4 deletions maze_transformer/training/train_model.py
@@ -8,7 +8,7 @@
 from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
 from muutils.json_serialize import SerializableDataclass, serializable_dataclass
 from muutils.mlutils import get_device
-from torch.utils.data import DataLoader, random_split
+from torch.utils.data import DataLoader
 
 from maze_transformer.training.config import (
     GPT_CONFIGS,
@@ -116,13 +116,20 @@ def train_model(
     if cfg.train_cfg.validation_dataset_cfg is not None:
         if isinstance(cfg.train_cfg.validation_dataset_cfg, int):
             # split the training dataset
+            assert len(dataset) > cfg.train_cfg.validation_dataset_cfg, (
+                f"{cfg.train_cfg.validation_dataset_cfg = } "
+                + f"is greater than the length of the training dataset: {len(dataset) = }"
+            )
             split_dataset_sizes: tuple[int, int] = [
                 len(dataset) - cfg.train_cfg.validation_dataset_cfg,
                 cfg.train_cfg.validation_dataset_cfg,
             ]
-            sub_dataset, sub_val_dataset = random_split(dataset, split_dataset_sizes)
-            dataset = sub_dataset.dataset
-            val_dataset = sub_val_dataset.dataset
+            val_dataset = MazeDataset(
+                cfg.dataset_cfg,
+                mazes=dataset.mazes[-split_dataset_sizes[1] :],
+                generation_metadata_collected=dataset.generation_metadata_collected,
+            )
+            dataset.mazes = dataset.mazes[: split_dataset_sizes[0]]
+            dataset.update_self_config()
+            val_dataset.update_self_config()
             logger.progress(
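A minimal sketch of the new split behavior, using plain Python lists in place of MazeDataset (variable names here are illustrative): the last validation_dataset_cfg mazes are carved off the end of the training dataset rather than going through torch's random_split. The removed code took `.dataset` off the Subset objects returned by random_split, which points back at the full original dataset, so train and validation were not actually disjoint; slicing the underlying maze list directly avoids that.

# Toy illustration of the deterministic tail split (not project code).
mazes: list[int] = list(range(100))  # stand-in for dataset.mazes
val_size: int = 10                   # stand-in for cfg.train_cfg.validation_dataset_cfg

assert len(mazes) > val_size, "validation split must be smaller than the dataset"

val_mazes = mazes[-val_size:]    # last val_size items become the validation set
train_mazes = mazes[:-val_size]  # the rest remains the training set

assert len(train_mazes) == 90 and len(val_mazes) == 10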
23 changes: 17 additions & 6 deletions maze_transformer/training/training.py
@@ -25,14 +25,25 @@ def collate_batch(batch: list[SolvedMaze], config: MazeDatasetConfig) -> list[st
 def get_dataloader(
     dataset: MazeDataset, cfg: ConfigHolder, logger: WandbLogger
 ) -> DataLoader:
+    if len(dataset) == 0:
+        raise ValueError(f"Dataset is empty: {len(dataset) = }")
     logger.progress(f"Loaded {len(dataset)} sequences")
     logger.progress("Creating dataloader")
-    dataloader: DataLoader = DataLoader(
-        dataset,
-        collate_fn=partial(collate_batch, config=cfg.dataset_cfg),
-        batch_size=cfg.train_cfg.batch_size,
-        **cfg.train_cfg.dataloader_cfg,
-    )
+    try:
+        dataloader: DataLoader = DataLoader(
+            dataset,
+            collate_fn=partial(collate_batch, config=cfg.dataset_cfg),
+            batch_size=cfg.train_cfg.batch_size,
+            **cfg.train_cfg.dataloader_cfg,
+        )
+    except ValueError as e:
+        raise ValueError(
+            "Error creating dataloader with:",
+            f"{len(dataset) = }",
+            f"{cfg.train_cfg.batch_size = }",
+            f"{cfg.train_cfg.dataloader_cfg = }",
+            f"error: {e}",
+        ) from e
 
     return dataloader
 
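A small toy repro (not project code) of the failure mode the try/except above surfaces: torch's DataLoader can raise ValueError at construction time for an invalid configuration, and re-raising with `from e` attaches the offending config values to the chained traceback.

import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(10))
try:
    DataLoader(data, batch_size=0)  # batch_size must be a positive integer
except ValueError as e:
    # the project code re-raises here with the dataset length, batch size,
    # and dataloader kwargs included, chained via `raise ... from e`
    print(f"caught: {e}")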
172 changes: 172 additions & 0 deletions maze_transformer/utils/get_pt_model_wandb.py
@@ -0,0 +1,172 @@
import math
from pathlib import Path

import torch
import wandb
from maze_dataset import MazeDatasetConfig
from muutils.misc import shorten_numerical_to_str
from transformer_lens import HookedTransformer
from wandb.sdk.wandb_run import Artifact, Run

from maze_transformer.training.config import (
    BaseGPTConfig,
    ConfigHolder,
    TrainConfig,
    ZanjHookedTransformer,
)


def get_step(artifact: Artifact) -> int:
    # Find the alias beginning with "step="
    step_alias: list[str] = [
        alias for alias in artifact.aliases if alias.startswith("step=")
    ]
    if len(step_alias) != 1:  # if we have multiple, skip as well
        return -1
    return int(step_alias[0].split("=")[-1])


def load_model(
    config_holder: ConfigHolder, model_path: str, fold_ln: bool = True
) -> HookedTransformer:
    model: HookedTransformer = config_holder.create_model()
    state_dict: dict = torch.load(model_path, map_location=model.cfg.device)
    model.load_and_process_state_dict(
        state_dict,
        fold_ln=False,
        center_writing_weights=True,
        center_unembed=True,
        refactor_factored_attn_matrices=True,
    )
    model.process_weights_(fold_ln=fold_ln)
    model.setup()  # Re-attach layernorm hooks by calling setup
    model.eval()
    return model


def load_wandb_run(
    project="aisc-search/alex",
    run_id="sa973hyn",
    output_path="./downloaded_models",
    checkpoint=None,
) -> tuple[HookedTransformer, ConfigHolder]:
    api: wandb.Api = wandb.Api()

    artifact_name: str = f"{project.rstrip('/')}/{run_id}"

    run: Run = api.run(artifact_name)
    wandb_cfg: wandb.config.Config = run.config  # Get run configuration

    # -- Get / Match checkpoint --
    if checkpoint is not None:
        # Match checkpoint
        available_checkpoints = [
            artifact for artifact in run.logged_artifacts() if artifact.type == "model"
        ]
        available_checkpoints = list(run.logged_artifacts())
        artifact = [aft for aft in available_checkpoints if get_step(aft) == checkpoint]
        if len(artifact) != 1:
            print(f"Could not find checkpoint {checkpoint} in {artifact_name}")
            print("Available checkpoints:")
            [
                print(artifact.name, "| Steps: ", get_step(artifact))
                for artifact in available_checkpoints
            ]
            return

        artifact = artifact[0]
        print("Loading checkpoint", checkpoint)
    else:
        # Get latest checkpoint
        print("Loading latest checkpoint")
        artifact_name = f"{artifact_name}:latest"
        artifact = api.artifact(artifact_name)
        checkpoint = get_step(artifact)

    # -- Initalize configurations --
    # Model cfg
    model_properties = {
        k: wandb_cfg[k] for k in ["act_fn", "d_model", "d_head", "n_layers"]
    }
    model_cfg: BaseGPTConfig = BaseGPTConfig(
        name=f"model {run_id}",
        weight_processing={
            "are_layernorms_folded": True,
            "are_weights_processed": True,
        },
        **model_properties,
    )

    # Dataset cfg
    grid_n: int = math.sqrt(wandb_cfg["d_vocab"] - 11)  #! Jank
    assert grid_n == int(
        grid_n
    ), "grid_n must be a perfect square + 11"  # check integer
    ds_cfg: MazeDatasetConfig = MazeDatasetConfig(
        name=wandb_cfg.get("dataset_name", "no_name"), grid_n=int(grid_n), n_mazes=-1
    )

    cfg: ConfigHolder = ConfigHolder(
        model_cfg=model_cfg,
        dataset_cfg=ds_cfg,
        train_cfg=TrainConfig(
            name=f"artifact '{artifact_name}', checkpoint '{checkpoint}'"
        ),
    )
    download_path: Path = (
        Path(output_path)
        / f'{artifact.name.split(":")[0]}'
        / f"model.iter_{checkpoint}.pt"
    )
    #! Account for final checkpoint
    if not download_path.exists():
        artifact.download(root=download_path.parent)
        print(f"Downloaded model to {download_path}")
    else:
        print(f"Model already downloaded to {download_path}")

    print("Loading model")
    model: HookedTransformer = load_model(cfg, download_path, fold_ln=True)
    return model, cfg


def load_wandb_pt_model_as_zanj(
    run_id: str,
    project: str = "aisc-search/alex",
    checkpoint: int | None = None,
    output_path: str = "./downloaded_models",
    save_zanj_model: bool = True,
    verbose: bool = True,
) -> ZanjHookedTransformer:
    model_kwargs: dict = dict(
        project=project,
        run_id=run_id,
        checkpoint=checkpoint,
    )
    model: HookedTransformer
    cfg: ConfigHolder
    model, cfg = load_wandb_run(**model_kwargs)
    if verbose:
        print(f"{type(model) = } {type(cfg) = }")

    model_zanj: ZanjHookedTransformer = ZanjHookedTransformer(cfg)
    model_zanj.load_state_dict(model.state_dict())
    model_zanj.training_records = {
        "load_wandb_run_kwargs": model_kwargs,
        "train_cfg.name": cfg.train_cfg.name,
    }
    if verbose:
        print(
            f"loaded model with {shorten_numerical_to_str(model_zanj.num_params())} parameters"
        )
        print(model_zanj.training_records)

    if save_zanj_model:
        model_zanj_save_path: Path = (
            Path(output_path) / f"wandb.{model_kwargs['run_id']}.zanj"
        )
        model_zanj.save(model_zanj_save_path)
        if verbose:
            print(f"Saved model to {model_zanj_save_path.as_posix()}")

    return model_zanj
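Hypothetical usage of the new conversion helper (the run id below is just the example default from the function signatures above, not a guaranteed-available run): download a .pt checkpoint from wandb, rebuild the model and dataset configs from the run's logged settings, and save the result as a ZANJ model.

from maze_transformer.utils.get_pt_model_wandb import load_wandb_pt_model_as_zanj

# run_id is a placeholder taken from the defaults above; substitute a real run.
model_zanj = load_wandb_pt_model_as_zanj(
    run_id="sa973hyn",
    project="aisc-search/alex",
    checkpoint=None,                    # None -> the run's latest model artifact
    output_path="./downloaded_models",
    save_zanj_model=True,               # also writes ./downloaded_models/wandb.sa973hyn.zanj
    verbose=True,
)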