training loop evals (#181)
added evals to the training loop, and a bunch of options in the `TrainingConfig` related to this

mega PR, merging without review because I need to run experiments lol. 

---------

Co-authored-by: mivanit <[email protected]>
valedan and mivanit authored Jun 15, 2023
1 parent db88ea8 commit 5aff580
Showing 21 changed files with 1,284 additions and 228 deletions.
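The commit message above describes wiring evals into the training loop and adding related `TrainingConfig` options. As orientation for the diffs that follow, here is a minimal, hedged sketch of how the new `evaluate_model` entry point and the fast/slow eval split might be driven from a training loop; the `fast_eval_interval` / `slow_eval_interval` names and the loop structure are hypothetical, not the actual fields added by this PR.

```python
# Hedged sketch only -- not the committed training loop. It assumes the
# evaluate_model / PathEvals APIs shown in the diffs below; the interval
# argument names are made up for illustration.
from maze_transformer.evaluation.eval_model import evaluate_model
from maze_transformer.evaluation.path_evals import PathEvals


def maybe_run_evals(
    model,
    eval_dataset,
    step: int,
    fast_eval_interval: int = 100,   # hypothetical config value
    slow_eval_interval: int = 1000,  # hypothetical config value
) -> dict:
    """Run cheap path evals often and the full set rarely."""
    scores: dict = {}
    if step % fast_eval_interval == 0:
        # PathEvals.fast holds the evals cheap enough to run frequently
        scores |= evaluate_model(
            model=model, dataset=eval_dataset, eval_functions=PathEvals.fast
        )
    if step % slow_eval_interval == 0:
        # run everything (fast + slow) at the less frequent checkpoints
        scores |= evaluate_model(
            model=model, dataset=eval_dataset, eval_functions=PathEvals.EVALS
        )
    return scores
```

The actual interval handling lives in `TrainingConfig` (see `test_train_cfg_intervals.py` in the coverage table below); the field names there may differ from this sketch.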
4 changes: 2 additions & 2 deletions examples/coverage/coverage.svg
(coverage badge SVG; file diff not rendered)
79 changes: 41 additions & 38 deletions examples/coverage/coverage.txt
@@ -1,38 +1,41 @@
Name Stmts Miss Cover Missing
------------------------------------------------------------------------------------------------------
maze_transformer\__init__.py 0 0 100%
maze_transformer\evaluation\__init__.py 0 0 100%
maze_transformer\evaluation\baseline_models.py 69 7 90% 62-63, 158, 160-163, 169
maze_transformer\evaluation\eval_model.py 65 23 65% 36-45, 60-95
maze_transformer\evaluation\maze_complexity_evals.py 8 0 100%
maze_transformer\evaluation\path_evals.py 98 14 86% 19, 52-56, 83, 155, 165-169, 178-181
maze_transformer\evaluation\plot_attention.py 89 89 0% 2-261
maze_transformer\test_helpers\assertions.py 22 1 95% 44
maze_transformer\test_helpers\stub_logger.py 21 7 67% 15-17, 20, 23, 26, 29
maze_transformer\tokenizer.py 67 11 84% 13, 50-51, 57, 89-90, 96-98, 101-104
maze_transformer\training\__init__.py 0 0 100%
maze_transformer\training\config.py 121 12 90% 95, 259, 281-285, 327-328, 341, 350, 356-359, 405, 410
maze_transformer\training\train_model.py 45 3 93% 33, 60-61
maze_transformer\training\training.py 69 0 100%
maze_transformer\training\wandb_logger.py 40 3 92% 55-57
setup.py 3 3 0% 3-6
tests\conftest.py 8 0 100%
tests\integration\test_create_dataset.py 27 0 100%
tests\integration\test_eval_model.py 49 0 100%
tests\integration\test_train_model.py 8 0 100%
tests\unit\maze_transformer\evaluation\test_baseline_models.py 23 0 100%
tests\unit\maze_transformer\evaluation\test_maze_complexity_evals.py 15 0 100%
tests\unit\maze_transformer\evaluation\test_path_evals.py 48 2 96% 30-32
tests\unit\maze_transformer\test_tokenizers.py 58 0 100%
tests\unit\maze_transformer\training\config\test_base_gpt_config.py 29 0 100%
tests\unit\maze_transformer\training\config\test_cfg_save_load.py 67 0 100%
tests\unit\maze_transformer\training\config\test_config_holder.py 38 5 87% 44-50
tests\unit\maze_transformer\training\config\test_train_config.py 32 0 100%
tests\unit\maze_transformer\training\test_dataset.py 53 7 87% 44-59
tests\unit\maze_transformer\training\test_maze_dataset_construction.py 8 0 100%
tests\unit\maze_transformer\training\test_model_loading_old.py 19 0 100%
tests\unit\maze_transformer\training\test_tokenizer.py 13 0 100%
tests\unit\maze_transformer\training\test_training.py 18 0 100%
tests\unit\maze_transformer\training\zanj\test_zanj_ht_save_load.py 37 0 100%
------------------------------------------------------------------------------------------------------
TOTAL 1267 187 85%
Name Stmts Miss Cover Missing
-------------------------------------------------------------------------------------------------------
maze_transformer\__init__.py 0 0 100%
maze_transformer\evaluation\__init__.py 0 0 100%
maze_transformer\evaluation\baseline_models.py 69 7 90% 62-63, 158, 160-163, 169
maze_transformer\evaluation\eval_model.py 78 28 64% 46-55, 70-105, 173-186, 251
maze_transformer\evaluation\maze_complexity_evals.py 8 0 100%
maze_transformer\evaluation\path_evals.py 106 9 92% 19, 64-68, 88, 91-95, 167, 177, 190
maze_transformer\evaluation\plot_attention.py 89 89 0% 2-261
maze_transformer\test_helpers\assertions.py 22 1 95% 44
maze_transformer\test_helpers\stub_logger.py 26 2 92% 23, 26
maze_transformer\tokenizer.py 67 11 84% 13, 50-51, 57, 89-90, 96-98, 101-104
maze_transformer\training\__init__.py 0 0 100%
maze_transformer\training\config.py 156 17 89% 80-83, 143, 190, 194, 232, 402, 424-428, 470-471, 484, 493, 499-502, 548, 553
maze_transformer\training\train_model.py 59 6 90% 34, 61-62, 131-138
maze_transformer\training\train_save_files.py 16 0 100%
maze_transformer\training\training.py 80 3 96% 61, 86-89
maze_transformer\training\wandb_logger.py 51 4 92% 56-58, 61
setup.py 3 3 0% 3-6
tests\conftest.py 8 0 100%
tests\integration\test_create_dataset.py 27 0 100%
tests\integration\test_eval_model.py 49 0 100%
tests\integration\test_train_model.py 8 0 100%
tests\integration\test_training.py 62 0 100%
tests\unit\maze_transformer\evaluation\test_baseline_models.py 23 0 100%
tests\unit\maze_transformer\evaluation\test_maze_complexity_evals.py 15 0 100%
tests\unit\maze_transformer\evaluation\test_path_evals.py 48 2 96% 30-32
tests\unit\maze_transformer\test_tokenizers.py 58 0 100%
tests\unit\maze_transformer\training\config\test_base_gpt_config.py 29 0 100%
tests\unit\maze_transformer\training\config\test_cfg_save_load.py 67 0 100%
tests\unit\maze_transformer\training\config\test_config_holder.py 38 5 87% 44-50
tests\unit\maze_transformer\training\config\test_train_cfg_intervals.py 89 0 100%
tests\unit\maze_transformer\training\config\test_train_config.py 33 0 100%
tests\unit\maze_transformer\training\test_dataset.py 53 7 87% 44-59
tests\unit\maze_transformer\training\test_get_dataloader.py 18 0 100%
tests\unit\maze_transformer\training\test_maze_dataset_construction.py 8 0 100%
tests\unit\maze_transformer\training\test_model_loading_old.py 19 0 100%
tests\unit\maze_transformer\training\test_tokenizer.py 13 0 100%
tests\unit\maze_transformer\training\zanj\test_zanj_ht_save_load.py 37 0 100%
-------------------------------------------------------------------------------------------------------
TOTAL 1532 194 87%
4 changes: 2 additions & 2 deletions makefile
@@ -39,7 +39,7 @@ unit:
.PHONY: integration
integration:
@echo "run integration tests"
$(POETRY_RUN_PYTHON) -m pytest -s tests/integration
$(POETRY_RUN_PYTHON) -m pytest tests/integration


.PHONY: convert_notebooks
@@ -62,7 +62,7 @@ test: clean unit integration test_notebooks
.PHONY: cov
cov:
@echo "run tests and generate coverage reports"
$(POETRY_RUN_PYTHON) -m pytest --cov=. -s tests/
$(POETRY_RUN_PYTHON) -m pytest --cov=. tests/
$(POETRY_RUN_PYTHON) -m coverage report -m > $(COVERAGE_REPORTS_DIR)/coverage.txt
$(POETRY_RUN_PYTHON) -m coverage_badge -f -o $(COVERAGE_REPORTS_DIR)/coverage.svg
$(POETRY_RUN_PYTHON) -m coverage html
109 changes: 97 additions & 12 deletions maze_transformer/evaluation/eval_model.py
@@ -4,20 +4,30 @@

import numpy as np
import torch
from maze_dataset import SPECIAL_TOKENS, CoordTup, MazeDataset, MazeDatasetConfig
from jaxtyping import Float
from maze_dataset import (
SPECIAL_TOKENS,
CoordTup,
MazeDataset,
MazeDatasetConfig,
SolvedMaze,
)
from maze_dataset.tokenization.token_utils import (
WhenMissing,
get_context_tokens,
get_path_tokens,
remove_padding_from_token_str,
tokens_to_coords,
)
from muutils.mlutils import chunks
from muutils.statcounter import StatCounter
from transformer_lens import HookedTransformer
from transformer_lens import utils as tl_utils

from maze_transformer.evaluation.path_evals import PathEvalFunction, PathEvals
from maze_transformer.tokenizer import HuggingMazeTokenizer
from maze_transformer.training.config import ConfigHolder
from maze_transformer.training.training import TRAIN_SAVE_FILES
from maze_transformer.training.train_save_files import TRAIN_SAVE_FILES

# pylint: disable=protected-access

@@ -109,10 +119,10 @@ def predict_maze_paths(

# check types
assert isinstance(
tokens_batch, list
tokens_batch, (list, tuple)
), f"tokens_batch must be a list, got {type(tokens_batch)}"
assert all(
isinstance(tokens, list) for tokens in tokens_batch
isinstance(tokens, (list, tuple)) for tokens in tokens_batch
), f"tokens_batch must be a list of lists, got {[type(tokens) for tokens in tokens_batch] = }"
assert all(
isinstance(x, str) for tokens in tokens_batch for x in tokens
@@ -131,6 +141,7 @@ def predict_maze_paths(
max_new_tokens=max_new_tokens,
verbose=verbose,
temperature=temperature,
# use_past_kv_cache=False,
)
assert isinstance(
prediction, str
@@ -154,27 +165,59 @@
return paths


def evaluate_path_predictions(
solved_mazes: list[SolvedMaze],
predictions: list[list[tuple[int, int]]],
path_evals: dict[str, PathEvalFunction],
) -> dict[str, StatCounter]:
path_scores: dict[str, StatCounter] = {
name: StatCounter() for name in path_evals.keys()
}
for name, func in path_evals.items():
path_scores[name].update(
func(
maze=solved_maze.maze,
solution=np.array(solved_maze.solution),
prediction=np.array(prediction),
)
for solved_maze, prediction in zip(solved_mazes, predictions)
)

return path_scores


def evaluate_model(
model: HookedTransformer,
dataset: MazeDataset,
dataset_tokens: list[list[str]] | None = None,
eval_functions: dict[str, PathEvalFunction] | None = None,
max_new_tokens: int = 8,
batch_size: int = 64,
verbose: bool = False,
) -> dict[str, StatCounter]:
"""Run a set of eval functions on a model for a given dataset. Returns a seperate StatCounter for each eval function."""
"""Run a set of eval functions on a model for a given dataset. Returns a seperate StatCounter for each eval function.
if dataset_tokens is provided, we assume that the dataset has already been tokenized and we skip tokenization. MAKE SURE THERE IS NOT A MISMATCH BETWEEN THE DATASET AND DATASET_TOKENS
"""

if not eval_functions:
eval_functions = PathEvals.evals
# TODO: potentially model evals which aren't path evals?
eval_functions = PathEvals.EVALS

score_counters: dict[str, StatCounter] = {
name: StatCounter() for name in eval_functions.keys()
name: StatCounter() for name in eval_functions
}

for maze_batch in chunks(dataset, batch_size):
tokens_batch = [
maze.as_tokens(dataset.cfg.node_token_map) for maze in maze_batch
]
predictions = predict_maze_paths(
if dataset_tokens is None:
dataset_tokens = dataset.as_tokens(join_tokens_individual_maze=False)
else:
assert len(dataset) == len(
dataset_tokens
), f"dataset and dataset_tokens must be the same length and must be from corresponding mazes, got {len(dataset) = } and {len(dataset_tokens) = }"

for batch in chunks(zip(dataset, dataset_tokens), batch_size):
maze_batch, tokens_batch = zip(*batch)
predictions: list[str | list[tuple[int, int]]] = predict_maze_paths(
tokens_batch=tokens_batch,
data_cfg=dataset.cfg,
model=model,
@@ -194,3 +237,45 @@ def evaluate_model(
)

return score_counters


def evaluate_logits(
logits: Float[torch.Tensor, "batch pos d_vocab"],
batch: list[int],
config: ConfigHolder,
tokenizer: HuggingMazeTokenizer,
path_evals: dict[str, PathEvalFunction] | None = None,
) -> dict[str, StatCounter]:
"""Runs a set of eval functions on the provided logits. For path evals, an attempt will be made to extract a predicted path from the logits (it is assumed that the logits are an entire sequence output from training, so they contain the adj_list plus path)"""

raise NotImplementedError(
"evaluate_logits does not function correctly, and at the moment there are only path evals anyway"
)

scores: dict[str, StatCounter] = {}

if path_evals:
# TODO: this is pretty much wrong -- sampling from the logits over the sequence should not produce a valid path
sampled_logits = tl_utils.sample_logits(logits)
prediction_tokens = tokenizer.batch_decode(sampled_logits)
predicted_paths = []
for tokens in prediction_tokens:
# this returns first path_start to end of list. Early in training there may be multiple path_start tokens, so results should be treated with caution
path_tokens = get_path_tokens(tokens.split(" "))
path_coords = tokens_to_coords(
path_tokens, maze_data_cfg=config.dataset_cfg, when_noncoord="skip"
)
predicted_paths.append(cast(list[tuple[int, int]], path_coords))

maze_tokens = [
remove_padding_from_token_str(token_str)
for token_str in tokenizer.batch_decode(batch)
]

solved_mazes = [
SolvedMaze.from_tokens(tokens.split(" "), config.dataset_cfg)
for tokens in maze_tokens
]
scores |= evaluate_path_predictions(solved_mazes, predicted_paths, path_evals)

return scores
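For context on the `dataset_tokens` fast path added above, a small hedged usage sketch; the `model` (HookedTransformer) and `dataset` (MazeDataset) objects are assumed to have been built elsewhere, and the printing at the end is illustrative only.

```python
# Hedged usage sketch for evaluate_model's pre-tokenized path.
from maze_transformer.evaluation.eval_model import evaluate_model

# Tokenize once so repeated eval calls can skip tokenization. The tokens
# must come from this exact dataset, in the same order, or the per-maze
# comparison inside evaluate_model is meaningless.
dataset_tokens = dataset.as_tokens(join_tokens_individual_maze=False)

score_counters = evaluate_model(
    model=model,
    dataset=dataset,
    dataset_tokens=dataset_tokens,
    max_new_tokens=8,
    batch_size=64,
)
for name, counter in score_counters.items():
    print(f"{name}: {counter}")
```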
38 changes: 26 additions & 12 deletions maze_transformer/evaluation/path_evals.py
@@ -38,9 +38,21 @@ def is_adjacent(node1: Coord, node2: Coord) -> bool:
class PathEvals:
"""array path based eval functions"""

evals: dict[str, PathEvalFunction] = {}
# We split evals into fast and slow. Fast ones can be used more frequently during training
fast: dict[str, PathEvalFunction] = {}
slow: dict[str, PathEvalFunction] = {}

@register_method(evals)
PATH_EVALS_MAP: dict[str, dict[str, PathEvalFunction]] = {
"eval_fast": fast,
"eval_slow": slow,
}

@classmethod
@property
def EVALS(cls):
return {**cls.fast, **cls.slow}

@register_method(fast)
@staticmethod
def node_overlap(solution: CoordArray, prediction: CoordArray, **_) -> float:
"""number of shared nodes (any order) / total number of (unique) nodes in solution"""
@@ -57,7 +69,7 @@ def node_overlap(solution: CoordArray, prediction: CoordArray, **_) -> float:

return len(prediction_set & solution_set) / len(solution_set)

@register_method(evals)
@register_method(fast)
@staticmethod
def num_connections_adjacent_lattice(prediction: CoordArray, **_) -> float:
"""number of the connections in prediction which actually connect nodes that are adjacent on the lattice, ignoring if they are adjacent on the maze"""
@@ -68,10 +80,12 @@

return n_adj

@register_method(evals)
@register_method(fast)
@staticmethod
def fraction_connections_adjacent_lattice(prediction: CoordArray, **_) -> float:
"""fraction of the connections in prediction which actually connect nodes that are adjacent on the lattice, ignoring if they are adjacent on the maze"""
if len(prediction) == 0:
return 0

if len(prediction) <= 1:
warnings.warn(
@@ -82,7 +96,7 @@ def fraction_connections_adjacent_lattice(prediction: CoordArray, **_) -> float:

return PathEvals.num_connections_adjacent_lattice(prediction) / len(prediction)

@register_method(evals)
@register_method(fast)
@staticmethod
def num_connections_adjacent(maze: LatticeMaze, prediction: MazePath, **_) -> float:
"""number of connections in prediction which are valid paths on the maze"""
@@ -94,7 +108,7 @@ def num_connections_adjacent(maze: LatticeMaze, prediction: MazePath, **_) -> fl

return n_connected

@register_method(evals)
@register_method(fast)
@staticmethod
def fraction_connections_adjacent(
maze: LatticeMaze, prediction: CoordArray, **_
@@ -106,20 +120,20 @@
num_connections, 1.0
)

@register_method(evals)
@register_method(fast)
@staticmethod
def exact_path_predicted(
solution: CoordArray, prediction: CoordArray, **_
) -> float:
"""Was the maze successfully solved?"""
return float(np.array_equal(solution, prediction))

@register_method(evals)
@register_method(fast)
@staticmethod
def solution_length(solution: CoordArray, **_) -> float:
return float(len(solution))

@register_method(evals)
@register_method(fast)
@staticmethod
def streak_length_until_incorrect(
solution: CoordArray,
@@ -143,7 +157,7 @@ def streak_length_until_incorrect(

return streak_length

@register_method(evals)
@register_method(fast)
@staticmethod
def distance_between_end_nodes(
solution: MazePath, prediction: MazePath, **_
@@ -154,7 +168,7 @@

return np.linalg.norm(solution[-1] - prediction[-1])

@register_method(evals)
@register_method(fast)
@staticmethod
def corner_jumps(prediction: MazePath, **_) -> float:
"""Looks for corner jumps in the predicted path. A corner jump is if the transformer predicts predicts
@@ -168,7 +182,7 @@ def corner_jumps(prediction: MazePath, **_) -> float:
normed_distances = np.linalg.norm(distance_between_nodes, axis=1)
return np.count_nonzero(normed_distances == np.sqrt(2))

@register_method(evals)
@register_method(fast)
@staticmethod
def average_predicted_step_size(prediction: MazePath, **_) -> float:
"""Returns average step size in the predicted path."""
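To close out the `PathEvals` changes, a short hedged sketch of how the new fast/slow registry might be consumed; the names `fast`, `slow`, `PATH_EVALS_MAP`, and `EVALS` come from the diff above, while the selection logic is illustrative only.

```python
# Hedged sketch: selecting eval groups from the registry added above.
from maze_transformer.evaluation.path_evals import PathEvals

# Look a group up by name (e.g. from a config string)...
fast_evals = PathEvals.PATH_EVALS_MAP["eval_fast"]

# ...or take everything; EVALS merges the fast and slow registries.
all_evals = PathEvals.EVALS

for name, fn in fast_evals.items():
    # each fn is a PathEvalFunction taking maze / solution / prediction kwargs
    print(name, fn)
```

The split exists so the training loop can run the cheap checks frequently and reserve anything expensive for rare checkpoints; in the visible diff everything registered so far lands in `fast`.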
(diffs for the remaining changed files not shown)
