Skip to content

Commit

Permalink
fix imports and notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
valedan committed Mar 27, 2023
1 parent a862459 commit 39a5f7b
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 225 deletions.
11 changes: 3 additions & 8 deletions maze_transformer/evaluation/eval_model.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import json
import typing
from pathlib import Path

import torch
from muutils.tensor_utils import NDArray
from transformer_lens import HookedTransformer

from maze_transformer.generation.latticemaze import (
SPECIAL_TOKENS,
CoordTup,
LatticeMaze,
)
from maze_transformer.evaluation.path_evals import MazePath
from maze_transformer.generation.constants import SPECIAL_TOKENS
from maze_transformer.generation.latticemaze import LatticeMaze
from maze_transformer.training.config import ConfigHolder
from maze_transformer.training.mazedataset import MazeDatasetConfig
from maze_transformer.training.tokenizer import SPECIAL_TOKENS
from maze_transformer.training.training import TRAIN_SAVE_FILES
from maze_transformer.utils.token_utils import decode_maze_tokens_to_coords

Expand Down
2 changes: 1 addition & 1 deletion maze_transformer/evaluation/path_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,4 @@ def all_functions(cls) -> dict[str, PathEvalFunction]:
for name, func in cls.__dict__.items()
if not name.startswith("_") and name not in excluded
}
}
}
8 changes: 2 additions & 6 deletions maze_transformer/generation/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,8 @@

import numpy as np

from maze_transformer.generation.latticemaze import (
NEIGHBORS_MASK,
Coord,
LatticeMaze,
SolvedMaze,
)
from maze_transformer.generation.constants import NEIGHBORS_MASK, Coord
from maze_transformer.generation.latticemaze import LatticeMaze, SolvedMaze


class LatticeMazeGenerators:
Expand Down
8 changes: 8 additions & 0 deletions maze_transformer/generation/latticemaze.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
from muutils.misc import list_split
from muutils.tensor_utils import NDArray

from maze_transformer.generation.constants import (
NEIGHBORS_MASK,
SPECIAL_TOKENS,
Coord,
CoordArray,
CoordTup,
)


def coord_str_to_tuple(coord_str: str) -> CoordTup:
"""convert a coordinate string to a tuple"""
Expand Down
10 changes: 3 additions & 7 deletions maze_transformer/training/mazedataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,11 @@
from muutils.tensor_utils import DTYPE_MAP, ATensor, NDArray
from tqdm import tqdm

from maze_transformer.generation.constants import SPECIAL_TOKENS, CoordArray, CoordTup
from maze_transformer.generation.generators import GENERATORS_MAP, LatticeMazeGenerators
from maze_transformer.generation.latticemaze import (
CoordArray,
CoordTup,
LatticeMaze,
SolvedMaze,
)
from maze_transformer.generation.latticemaze import LatticeMaze, SolvedMaze
from maze_transformer.training.dataset import GPTDataset, GPTDatasetConfig, IndexedArray
from maze_transformer.training.tokenizer import SPECIAL_TOKENS, maze_to_tokens
from maze_transformer.training.tokenizer import maze_to_tokens


@dataclass(kw_only=True)
Expand Down
8 changes: 2 additions & 6 deletions maze_transformer/training/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,8 @@
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils import BatchEncoding

from maze_transformer.generation.latticemaze import (
SPECIAL_TOKENS,
Coord,
CoordTup,
LatticeMaze,
)
from maze_transformer.generation.constants import SPECIAL_TOKENS, Coord, CoordTup
from maze_transformer.generation.latticemaze import LatticeMaze

if TYPE_CHECKING:
from maze_transformer.training.config import ConfigHolder, MazeDatasetConfig
Expand Down
208 changes: 16 additions & 192 deletions notebooks/eval_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -28,13 +28,13 @@
"from maze_transformer.utils.notebook_utils import configure_notebook\n",
"from maze_transformer.generation.latticemaze import LatticeMaze\n",
"from maze_transformer.evaluation.plot_maze import plot_multi_paths, PathFormat\n",
"from maze_transformer.evaluation.eval_model import MazePath, ArrMazePath, load_model_with_configs, predict_maze_path\n",
"from maze_transformer.evaluation.pathdist import MazeEvalFunction, ArrMazeEvalFunction, MazeEvalFuncs, ArrMazeEvalFuncs"
"from maze_transformer.evaluation.eval_model import load_model_with_configs, predict_maze_path\n",
"from maze_transformer.evaluation.path_evals import MazePath, PathEvals, PathEvalFunction"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {
"collapsed": false,
"pycharm": {
Expand All @@ -53,13 +53,13 @@
"# this should point towards a directory containing a run. \n",
"# If you don't have any runs, you can create a dataset with `poetry run python scripts/create_dataset.py create ./data/maze 10 --grid_n=4`\n",
"# Then train a model with poetry run python scripts/train_model.py ./data/maze/g4-n10`\n",
"run_path = Path(\"../data/maze/g6-n5M/trained_model\")\n",
"run_path = Path(\"../data/maze/g4-n100\")\n",
"\n",
"assert run_path.exists(), f\"Run path {run_path.as_posix()} does not exist\"\n",
"model_path = list(sorted(run_path.glob(\"**/model.final.pt\"), key=os.path.getmtime))[\n",
"\t-1\n",
"].resolve()\n",
"maze_path = run_path.parent / \"maze_tokens.jsonl\""
"maze_path = run_path / \"maze_tokens.jsonl\""
]
},
{
Expand Down Expand Up @@ -151,20 +151,8 @@
},
"outputs": [],
"source": [
"EvalFuncTuple = tuple[typing.Literal[\"arr\", \"list\"], MazeEvalFunction|ArrMazeEvalFunction]\n",
"\n",
"ALL_PATHDIST_FUNCS: dict[str, EvalFuncTuple] = {\n",
"\t**{\n",
"\t\tname: (\"arr\", func)\n",
"\t\tfor name, func in ArrMazeEvalFuncs.__dict__.items()\n",
"\t\tif not name.startswith(\"_\")\n",
"\t},\n",
"\t**{\n",
"\t\tname: (\"list\", func)\n",
"\t\tfor name, func in MazeEvalFuncs.__dict__.items()\n",
"\t\tif not name.startswith(\"_\")\n",
"\t},\n",
"}\n",
"\n",
"ALL_PATHDIST_FUNCS: dict[str, PathEvalFunction] = PathEvals.all_functions()\n",
"\n",
"print(ALL_PATHDIST_FUNCS)"
]
Expand All @@ -183,7 +171,7 @@
"def evaluate_model_pathdist_scores(\n",
"\t\tmodel_path: Path,\n",
"\t\tmaze_tokens_path: Path,\n",
"\t\tpathdist_functions: dict[str, EvalFuncTuple]|None = ALL_PATHDIST_FUNCS,\n",
"\t\tpathdist_functions: dict[str, PathEvalFunction]|None = ALL_PATHDIST_FUNCS,\n",
"\t\tn_tokens_pred: int = 8,\n",
"\t\tn_mazes: int = 64,\n",
"\t\tverbose: bool = False,\n",
Expand Down Expand Up @@ -228,18 +216,10 @@
"\n",
"\t# evaluate\n",
"\tpathdist_scores: dict[str, StatCounter] = dict()\n",
"\tfor name, (pathdist_type, pathdist_func) in pathdist_functions.items():\n",
"\t\tif pathdist_type == \"list\":\t\t\tpathdist_scores[name] = StatCounter(\n",
"\t\t\t\tpathdist_func(maze, p_true, p_pred)\n",
"\t\t\t\tfor maze, p_true, p_pred in mazes_solved\n",
"\t\t\t)\n",
"\t\telif pathdist_type == \"arr\":\n",
"\t\t\tpathdist_scores[name] = StatCounter(\n",
"\t\t\t\tpathdist_func(maze, p_true, p_pred)\n",
"\t\t\t\tfor maze, p_true, p_pred in mazes_solved_arrpath\n",
"\t\t\t)\n",
"\t\telse:\n",
"\t\t\traise ValueError(f\"Invalid pathdist_type: {pathdist_type}\")\n",
"\tfor name, pathdist_func in pathdist_functions.items():\n",
"\t\tpathdist_scores[name] = StatCounter(\n",
"\t\t\tpathdist_func(maze, p_true, p_pred)\n",
"\t\t\tfor maze, p_true, p_pred in mazes_solved_arrpath)\n",
"\n",
"\treturn pathdist_scores\n",
"\n",
Expand All @@ -261,7 +241,7 @@
"\t\trun_path: Path, # Path to run, not model.final.pt or checkpoints\n",
"\t\tmaze_tokens_path: Path,\n",
"\t\tcheckpoint_indices: list[int]|None = None,\n",
"\t\tpathdist_functions: dict[str, EvalFuncTuple]|None = ALL_PATHDIST_FUNCS,\n",
"\t\tpathdist_functions: dict[str, PathEvalFunction]|None = ALL_PATHDIST_FUNCS,\n",
"\t\tskip_every_nth: int = 1,\n",
"\t\tn_tokens_pred: int = 8,\n",
"\t\tn_mazes: int = 10,\n",
Expand Down Expand Up @@ -313,7 +293,7 @@
"\tmaze_tokens_path = maze_path,\n",
"\tn_mazes = 10,\n",
"\t# skip_every_nth=10,\n",
"\tcheckpoint_idxs=[64,640064],\n",
"\t# checkpoint_idxs=[64,640064],\n",
"\t# verbose = True,\n",
")"
]
Expand Down Expand Up @@ -386,162 +366,6 @@
"source": [
"plot_pathdist_scores(data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down Expand Up @@ -571,4 +395,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
5 changes: 3 additions & 2 deletions notebooks/plot_attention.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
"from maze_transformer.utils.notebook_utils import configure_notebook\n",
"from maze_transformer.generation.latticemaze import LatticeMaze, SolvedMaze\n",
"from maze_transformer.generation.generators import LatticeMazeGenerators\n",
"from maze_transformer.training.tokenizer import maze_to_tokens, SPECIAL_TOKENS, HuggingMazeTokenizer\n",
"from maze_transformer.training.tokenizer import maze_to_tokens, HuggingMazeTokenizer\n",
"from maze_transformer.evaluation.plot_maze import plot_multi_paths, PathFormat\n",
"from maze_transformer.utils.token_utils import decode_maze_tokens_to_coords\n",
"from maze_transformer.evaluation.eval_model import load_model_with_configs"
"from maze_transformer.evaluation.eval_model import load_model_with_configs\n",
"from maze_transformer.generation.constants import SPECIAL_TOKENS\n"
]
},
{
Expand Down
Loading

0 comments on commit 39a5f7b

Please sign in to comment.