Zanj integration: datasets & training (#177)
Ports the datasets and the training loop to use ZANJ, along with many other changes to the dataset code.

# configs

Modifying configs from the command line is now easier!

- `ConfigHolder.get_config_multisource()` takes:
  - one of: a config object, a file to read the config from, or a list of names selecting presets for each of the sub-configs
  - a dotlist-dict to modify any parameters of the config
- `GPTDataset().to_fname()` is used to generate the filename for saving a config (and also to find a matching config to load or download). `MazeDatasetConfig` also implements this in a custom way.
- `MazeDatasetConfig` now has a `maze_ctor_kwargs` field for passing keyword arguments to maze generation (see #183).
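
To illustrate what a dotlist-dict modification means, here is a minimal sketch. It is not the `ConfigHolder` implementation: `apply_dotlist` and the nested-dict layout (`dataset_cfg`, `train_cfg`, `n_mazes`, and so on) are hypothetical names chosen for the example, and the real configs are objects rather than plain dicts.

```python
from copy import deepcopy
from typing import Any


def apply_dotlist(config: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of `config` with dotted-key overrides applied.

    Hypothetical helper: a key like "dataset_cfg.n_mazes" addresses
    config["dataset_cfg"]["n_mazes"].
    """
    result = deepcopy(config)  # leave the caller's config untouched
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = result
        for key in parents:
            if not isinstance(node.get(key), dict):
                raise KeyError(f"unknown config section {key!r} in {dotted_key!r}")
            node = node[key]
        if leaf not in node:
            # refuse to silently create new parameters from a typo
            raise KeyError(f"unknown config parameter {dotted_key!r}")
        node[leaf] = value
    return result


# illustrative config layout, assumed for this example only
base = {
    "dataset_cfg": {"grid_n": 3, "n_mazes": 10},
    "train_cfg": {"batch_size": 32},
}
modified = apply_dotlist(
    base,
    {"dataset_cfg.n_mazes": 100, "train_cfg.batch_size": 8},
)
```

Rejecting unknown keys is a design choice of this sketch: it surfaces typos in command-line overrides instead of adding stray fields.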

# maze dataset

You can now get a `MazeDataset` from just a config -- it will load, download, or generate a dataset on the fly. The various ad hoc ways of storing a dataset we had before are gone: a `MazeDataset` contains a list of `SolvedMaze` objects and returns one of them when you call `__getitem__`. We also added filters and fixed some parallelization issues!

- `GPTDataset().from_config()` is a new, simplified way of getting a dataset: pass a config, and it will attempt to load from a local directory, download, or generate. Any of these steps can be disabled, and kwargs (for things like the number of cores to use) are passed down.
- The canonical representation of the dataset is a list of `SolvedMaze`.
- `mazes_objs`, `mazes_tokens`, and `mazes_array` are now cached properties. They still work, but may be slow because they are not parallelized.
- `MazeDataset.__getitem__()` now returns a `SolvedMaze`.
- `create_dataset()` is deprecated but should still work. Remove this?
- Filtering! You can specify filters in the config under the `applied_filters` field, or you can call `dataset.filter_by.your_filter_func(your_arg=your_val)`. Both work the same way under the hood.
- You can specify in `from_config()` whether to run in parallel (the default is not to). This is useful because parallelization has a large overhead for small datasets. Tests are now much faster.
- There may have been issues with parallelization using the same fixed seed across all processes. This was fixed in #183, but in a hacky way.
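
The claim that config-driven and call-driven filtering share one code path can be sketched as follows. This is a toy model, not the `MazeDataset` code: `ToyDataset`, `FILTERS`, `register_filter`, `apply_filter`, and the `min_solution_length` filter are all hypothetical, mazes are stood in for by plain lists of coordinates, and the `filter_by` attribute namespace is not reproduced.

```python
from dataclasses import dataclass, field
from typing import Any, Callable

# hypothetical registry mapping filter names to functions
FILTERS: dict[str, Callable[..., list]] = {}


def register_filter(func: Callable[..., list]) -> Callable[..., list]:
    """Record a filter under its function name so configs can refer to it."""
    FILTERS[func.__name__] = func
    return func


@register_filter
def min_solution_length(mazes: list, min_length: int) -> list:
    # each "maze" here is just its solution path, a list of coordinates
    return [maze for maze in mazes if len(maze) >= min_length]


@dataclass
class ToyDataset:
    mazes: list
    applied_filters: list[dict[str, Any]] = field(default_factory=list)

    def apply_filter(self, name: str, **kwargs: Any) -> "ToyDataset":
        """Single code path: look the filter up by name and record that it ran."""
        return ToyDataset(
            mazes=FILTERS[name](self.mazes, **kwargs),
            applied_filters=[*self.applied_filters, {"name": name, "kwargs": kwargs}],
        )

    @classmethod
    def from_config(cls, mazes: list, applied_filters: list[dict[str, Any]]) -> "ToyDataset":
        """Config route: replay every filter spec through apply_filter."""
        dataset = cls(mazes)
        for spec in applied_filters:
            dataset = dataset.apply_filter(spec["name"], **spec["kwargs"])
        return dataset


raw = [[(0, 0)], [(0, 0), (0, 1), (1, 1)], [(2, 2), (2, 1)]]
via_call = ToyDataset(raw).apply_filter("min_solution_length", min_length=2)
via_config = ToyDataset.from_config(
    raw, [{"name": "min_solution_length", "kwargs": {"min_length": 2}}]
)
```

Recording each applied filter on the dataset means the filtered dataset can be described by its config alone, which is what makes loading a matching dataset by filename possible in a design like this.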

# training

Models are now saved as ZANJ objects, and the command-line interface is improved.

- `train()` now:
  - saves models as ZANJ
  - returns the trained `ZanjHookedTransformer`
- `train_model()`:
  - now returns a `TrainingResult`, which contains the output path and the model, and may eventually include logging info
  - for the config, inherits its interface from `ConfigHolder.get_config_multisource()`; kwargs are passed as the modification dict

---------

Co-authored-by: mivanit <[email protected]>
Co-authored-by: Dan Valentine <[email protected]>
Co-authored-by: Can Rager <[email protected]>
Co-authored-by: canrager <[email protected]>
4 people authored Apr 28, 2023
1 parent bf8605f commit 06c8181
Showing 44 changed files with 2,358 additions and 878 deletions.
1 change: 1 addition & 0 deletions makefile
@@ -51,6 +51,7 @@ convert_notebooks:

.PHONY: test_notebooks
test_notebooks: convert_notebooks
@echo "run tests on converted notebooks in $(CONVERTED_NOTEBOOKS_TEMP_DIR) using $(HELPERS_DIR)/run_notebook_tests.py"
python $(HELPERS_DIR)/run_notebook_tests.py --notebooks-dir=$(NOTEBOOKS_DIR) --converted-notebooks-temp-dir=$(CONVERTED_NOTEBOOKS_TEMP_DIR)


4 changes: 2 additions & 2 deletions maze_transformer/evaluation/baseline_models.py
@@ -60,7 +60,7 @@ def _predict_next_step(
unvisited_neighbors = [coord for coord in neighbors if coord not in path]

# if the current path is already as long as the solution, there can be no correct next step
correct_step = solution[len(path)] if len(solution) > len(path) else None
correct_step = tuple(solution[len(path)]) if len(solution) > len(path) else None

if len(unvisited_neighbors) == 0:
return SPECIAL_TOKENS["path_end"]
@@ -89,7 +89,7 @@ def _generate_path(
maze = LatticeMaze.from_tokens(tokens)
origin_coord = self.config.dataset_cfg.token_node_map[get_origin_token(tokens)]
target_coord = self.config.dataset_cfg.token_node_map[get_target_token(tokens)]
solution = maze.find_shortest_path(origin_coord, target_coord)
solution = maze.find_shortest_path(origin_coord, target_coord).tolist()

existing_path = tokens_to_coords(
get_path_tokens(tokens), self.config.dataset_cfg
14 changes: 5 additions & 9 deletions maze_transformer/evaluation/eval_model.py
@@ -10,7 +10,6 @@

from maze_transformer.evaluation.path_evals import PathEvalFunction, PathEvals
from maze_transformer.generation.constants import SPECIAL_TOKENS
from maze_transformer.generation.lattice_maze import SolvedMaze
from maze_transformer.training.config import ConfigHolder
from maze_transformer.training.maze_dataset import MazeDataset, MazeDatasetConfig
from maze_transformer.training.training import TRAIN_SAVE_FILES
@@ -150,15 +149,12 @@ def evaluate_model(
name: StatCounter() for name in eval_functions.keys()
}

for batch in chunks(dataset.mazes_tokens, batch_size):
# TODO: This won't be needed after #124, then we can call mazes_objs instead
# https://github.com/orgs/AISC-understanding-search/projects/1/views/1?pane=issue&itemId=23879308
solved_mazes: SolvedMaze = [
SolvedMaze.from_tokens(tokens, dataset.cfg) for tokens in batch
for maze_batch in chunks(dataset, batch_size):
tokens_batch = [
maze.as_tokens(dataset.cfg.node_token_map) for maze in maze_batch
]

predictions = predict_maze_paths(
tokens_batch=batch,
tokens_batch=tokens_batch,
data_cfg=dataset.cfg,
model=model,
max_new_tokens=max_new_tokens,
@@ -173,7 +169,7 @@ def evaluate_model(
prediction=np.array(prediction),
model=model,
)
for sm, prediction in zip(solved_mazes, predictions)
for sm, prediction in zip(maze_batch, predictions)
)

return score_counters
12 changes: 12 additions & 0 deletions maze_transformer/evaluation/maze_complexity_evals.py
@@ -0,0 +1,12 @@
import typing

from maze_transformer.generation.lattice_maze import SolvedMaze
from maze_transformer.utils.utils import register_method

MAZE_COMPLEXITY_EVALS: dict[str, typing.Callable[[SolvedMaze], float]] = dict()


class MazeComplexityEvals:
@register_method(MAZE_COMPLEXITY_EVALS)
def solution_length(maze: SolvedMaze) -> float:
return len(maze.solution)
15 changes: 7 additions & 8 deletions maze_transformer/evaluation/path_evals.py
@@ -1,27 +1,26 @@
from typing import Iterable, Optional, Protocol, TypeAlias
import typing

import numpy as np
from jaxtyping import Int

from maze_transformer.generation.constants import Coord, CoordArray, CoordTup
from maze_transformer.generation.lattice_maze import LatticeMaze
from maze_transformer.utils.utils import register_method

# pylint: disable=unused-argument
MazePath: TypeAlias = Int[np.ndarray, "node x_y_pos"]
MazePath = CoordArray


class PathEvalFunction(Protocol):
class PathEvalFunction(typing.Protocol):
def __call__(
self,
maze: Optional[LatticeMaze] = None,
solution: Optional[CoordArray] = None,
prediction: Optional[CoordArray] = None,
maze: LatticeMaze | None = None,
solution: CoordArray | None = None,
prediction: CoordArray | None = None,
) -> float:
...


def path_as_segments_iter(path: CoordArray) -> Iterable[tuple]:
def path_as_segments_iter(path: CoordArray) -> typing.Iterable[tuple]:
"""
Iterate over the segments of a path (ie each consecutive pair).
"""
92 changes: 66 additions & 26 deletions maze_transformer/generation/generators.py
@@ -1,13 +1,14 @@
import random
import warnings
from typing import Any, Callable

import numpy as np

from maze_transformer.generation.constants import CoordArray
from maze_transformer.generation.lattice_maze import (
NEIGHBORS_MASK,
ConnectionList,
Coord,
CoordTup,
LatticeMaze,
SolvedMaze,
)
@@ -18,9 +19,11 @@ class LatticeMazeGenerators:

@staticmethod
def gen_dfs(
grid_shape: Coord | CoordTup,
start_coord: Coord | None = None,
grid_shape: Coord,
lattice_dim: int = 2,
n_accessible_cells: int | None = None,
max_tree_depth: int | None = None,
start_coord: Coord | None = None,
) -> LatticeMaze:
"""generate a lattice maze using depth first search, iterative
@@ -35,28 +38,39 @@ def gen_dfs(
4. Mark the chosen cell as visited and push it to the stack
"""

grid_shape = np.array(grid_shape)

# initialize the maze with no connections
connection_list: np.ndarray = np.zeros(
(lattice_dim, grid_shape[0], grid_shape[1]), dtype=bool
)

# Default values if no constraints have been passed
grid_shape: Coord = np.array(grid_shape)
n_total_cells: int = np.prod(grid_shape)
if n_accessible_cells is None:
n_accessible_cells = n_total_cells
if max_tree_depth is None:
max_tree_depth = (
2 * n_total_cells
) # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
if start_coord is None:
start_coord: Coord = (
random.randint(0, grid_shape[0] - 1),
random.randint(0, grid_shape[1] - 1),
start_coord: Coord = np.random.randint(
0,
np.maximum(grid_shape - 1, 1),
size=2,
)
else:
start_coord = np.array(start_coord)

# print(f"{grid_shape = } {start_coord = }")
# initialize the maze with no connections
connection_list: ConnectionList = np.zeros(
(lattice_dim, grid_shape[0], grid_shape[1]), dtype=np.bool_
)

# initialize the stack with the target coord
visited_cells: set[tuple[int, int]] = set()
visited_cells.add(tuple(start_coord))
stack: list[Coord] = [start_coord]

# loop until the stack is empty
while stack:
# initialize tree_depth_counter
current_tree_depth: int = 1

# loop until the stack is empty or n_connected_cells is reached
while stack and (len(visited_cells) < n_accessible_cells):
# get the current coord from the stack
current_coord: Coord = stack.pop()

@@ -73,7 +87,10 @@ def gen_dfs(
)
]

if unvisited_neighbors_deltas:
# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
if unvisited_neighbors_deltas and (
current_tree_depth <= max_tree_depth / 2
):
stack.append(current_coord)

# choose one of the unvisited neighbors
@@ -92,22 +109,24 @@ def gen_dfs(
visited_cells.add(tuple(chosen_neighbor))
stack.append(chosen_neighbor)

# Update current tree depth
current_tree_depth += 1
else:
current_tree_depth -= 1

return LatticeMaze(
connection_list=connection_list,
generation_meta=dict(
func_name="gen_dfs",
grid_shape=grid_shape,
start_coord=start_coord,
visited_cells=visited_cells,
n_accessible_cells=n_accessible_cells,
max_tree_depth=max_tree_depth,
fully_connected=(len(visited_cells) == n_accessible_cells),
),
)

@classmethod
def gen_dfs_with_solution(cls, grid_shape: Coord) -> SolvedMaze:
maze: LatticeMaze = cls.gen_dfs(grid_shape)
solution: CoordArray = np.array(maze.generate_random_path())

return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution)

@staticmethod
def gen_wilson(
grid_shape: Coord,
@@ -137,9 +156,9 @@ def neighbor(current: Coord, direction: int) -> Coord:

# A connection list only contains two elements: one boolean matrix indicating all the
# downwards connections in the maze, and one boolean matrix indicating the rightwards connections.
connection_list: np.ndarray = np.zeros((2, rows, cols), dtype=bool)
connection_list: np.ndarray = np.zeros((2, rows, cols), dtype=np.bool_)

connected = np.zeros(grid_shape, dtype=bool)
connected = np.zeros(grid_shape, dtype=np.bool_)
direction_matrix = np.zeros(grid_shape, dtype=int)

# Mark a random cell as connected
@@ -198,12 +217,33 @@ def neighbor(current: Coord, direction: int) -> Coord:
generation_meta=dict(
func_name="gen_wilson",
grid_shape=grid_shape,
fully_connected=True,
),
)

@classmethod
def gen_dfs_with_solution(cls, grid_shape: Coord):
warnings.warn(
"gen_dfs_with_solution is deprecated, use get_maze_with_solution instead",
DeprecationWarning,
)
return get_maze_with_solution("gen_dfs", grid_shape)


# TODO: use the thing @valedan wrote for the evals function to make this automatic?
GENERATORS_MAP: dict[str, Callable[[Coord, Any], "LatticeMaze"]] = {
"gen_dfs": LatticeMazeGenerators.gen_dfs,
"gen_wilson": LatticeMazeGenerators.gen_wilson,
}


def get_maze_with_solution(
gen_name: str,
grid_shape: Coord,
maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
if maze_ctor_kwargs is None:
maze_ctor_kwargs = dict()
maze: LatticeMaze = GENERATORS_MAP[gen_name](grid_shape, **maze_ctor_kwargs)
solution: CoordArray = np.array(maze.generate_random_path())
return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution)
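
The `start_coord` default in `gen_dfs` above switches from `random.randint` to `np.random.randint`. The two treat the upper bound differently: the standard-library call includes it, the NumPy call excludes it. The quick check below is not part of the commit; the 3x3 `grid_shape` and the sample count are illustrative assumptions.

```python
import random

import numpy as np

grid_shape = np.array([3, 3])  # illustrative grid, valid indices are 0, 1, 2
n_samples = 2000

# Old form: random.randint(a, b) draws from the closed interval [a, b],
# so the last row index (2) can be chosen.
old_rows = {random.randint(0, grid_shape[0] - 1) for _ in range(n_samples)}

# New form: np.random.randint(low, high) draws from the half-open interval
# [low, high), so with high = grid_shape - 1 the last index is never drawn.
new_rows = {
    int(np.random.randint(0, np.maximum(grid_shape - 1, 1), size=2)[0])
    for _ in range(n_samples)
}

print("old form row indices:", sorted(old_rows))
print("new form row indices:", sorted(new_rows))
```

If sampling the full grid is intended, passing `grid_shape` itself as `high` to `np.random.randint` would cover every index; whether the narrower range is deliberate is not stated in the commit.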