Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zanj integration: datasets & training #177

Merged
merged 89 commits into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from 65 commits
Commits
Show all changes
89 commits
Select commit Hold shift + click to select a range
25d745d
wip
mivanit Mar 28, 2023
da2e05f
Merge branch 'zanj-integration' into zanj-integration-2
mivanit Mar 28, 2023
7fdbdb0
wip
mivanit Mar 28, 2023
f7abcb0
bump muutils to 0.3.3, some zanj tests working with that
mivanit Mar 28, 2023
a31d4ba
misc
mivanit Mar 29, 2023
705e1f6
something with layernorm is causing the tensor elements not to match up
mivanit Mar 30, 2023
34a62fc
???
mivanit Mar 30, 2023
a6a5b32
exact loading of model works!
mivanit Apr 1, 2023
0181b02
ugh not quite, only working if layernorm folding disabled
mivanit Apr 1, 2023
9e2fe97
wip
mivanit Apr 1, 2023
07aa160
zanj save/load tests passing?
mivanit Apr 2, 2023
e1b28b4
fixed some unit tests, test_eval_model still fails >:(
mivanit Apr 2, 2023
84d3ae8
so confused, test only fails when model generated via training?
mivanit Apr 3, 2023
2019ed4
merge with main (and bump muutils to 0.3.6)
mivanit Apr 6, 2023
570c2b1
fixed folding issue
mivanit Apr 6, 2023
1db5c61
Merge branch 'add-notebook-testing' into zanj-integration-2
mivanit Apr 6, 2023
075ff2b
bump muutils to 0.3.7
mivanit Apr 6, 2023
808e333
updated poetry.lock
mivanit Apr 6, 2023
04b9d09
prelim to/from ascii and pixels methods, might need to be moved
mivanit Apr 6, 2023
9ab36f7
run notebook
mivanit Apr 6, 2023
4548296
merge with add-notebook-testing
mivanit Apr 9, 2023
377724a
wip
mivanit Apr 9, 2023
2406dea
wip
mivanit Apr 9, 2023
70e99f5
this was some of the most paintful debugging ive ever done
mivanit Apr 10, 2023
a8a52af
format
mivanit Apr 10, 2023
8ab6e79
bump muutils
mivanit Apr 10, 2023
6bf592b
merge with main
mivanit Apr 10, 2023
820f0b3
fixes?
mivanit Apr 10, 2023
ecb1872
format
mivanit Apr 10, 2023
b650af9
update poetry lock
mivanit Apr 10, 2023
525c719
fixes
mivanit Apr 10, 2023
93a31aa
format
mivanit Apr 10, 2023
94c675d
reworked mazeplot init
mivanit Apr 10, 2023
e612f09
wip
mivanit Apr 11, 2023
3cf9041
add unit length parameter to MazePlot
canrager Apr 11, 2023
40f4efd
misspelled folder??
mivanit Apr 11, 2023
ea7a66a
wip, but unit tests passing!
mivanit Apr 11, 2023
b09e707
wip
mivanit Apr 12, 2023
e1b774f
incomprehensible upstream issue in muutils
mivanit Apr 12, 2023
e2d3799
reworking training script
mivanit Apr 12, 2023
16b5665
wip
mivanit Apr 12, 2023
c3a9d69
test_train_model working!
mivanit Apr 12, 2023
a8f8934
wip
mivanit Apr 13, 2023
5d8bd00
test_eval_model passing
mivanit Apr 13, 2023
5238158
format
mivanit Apr 13, 2023
56ce56d
wip refactor
mivanit Apr 14, 2023
bb04c45
SolvedMaze now inherits from TargetedLatticeMaze
mivanit Apr 14, 2023
09876b1
Really dumb bug tracked down, path would overwrite endpoints in as_pi…
mivanit Apr 14, 2023
cdb9ea7
format
mivanit Apr 14, 2023
ea20e9a
Merge branch 'add-maze-from-ascii' of https://github.com/AISC-underst…
mivanit Apr 14, 2023
20436ab
remove MazePlot.show()
mivanit Apr 14, 2023
f65abbe
aaaaA
mivanit Apr 15, 2023
134e0ea
wip
mivanit Apr 15, 2023
fe4eae6
merge
mivanit Apr 15, 2023
f248e5a
wip
mivanit Apr 15, 2023
22518df
wip filtering
mivanit Apr 15, 2023
2c0728e
more filtering wip
mivanit Apr 15, 2023
360c940
wip filters
mivanit Apr 15, 2023
1ae7d6e
filters working!
mivanit Apr 15, 2023
1742ee4
filteringgit add maze_transformer/ notebooks/!
mivanit Apr 15, 2023
41223af
removed debug printing
mivanit Apr 15, 2023
52c2042
format
mivanit Apr 15, 2023
92eae14
simplified decorator, minor change to notebook
mivanit Apr 15, 2023
2180d19
filtering improvements
mivanit Apr 16, 2023
f1e304c
format
mivanit Apr 16, 2023
cee6204
bump muutils to v0.3.9
mivanit Apr 18, 2023
f491f32
Add tests for MazeDataset
valedan Apr 19, 2023
e64119e
Test custom filters
valedan Apr 19, 2023
a8fd1e5
test dataset filters
valedan Apr 19, 2023
a7148e9
fixed minor bugs in tests from zanj-integration-datasets, needs to be…
mivanit Apr 20, 2023
da56b52
initial version of maze complexity evals
mivanit Apr 20, 2023
2c13e51
fixed bug in cut_percentile_shortest and ran formatting
mivanit Apr 20, 2023
a510d41
merging in from main
mivanit Apr 20, 2023
990dbb0
format, resolved a forgotten merge conflict
mivanit Apr 20, 2023
2d91858
MazePath dissapeared again???
mivanit Apr 20, 2023
45e75dd
format (removed jaxtyping import)
mivanit Apr 20, 2023
135435a
added a TODO of something to implement for constrained dfs kwargs
mivanit Apr 25, 2023
88002f6
dumb bug that probably doesnt matter since we will remove TargetedLat…
mivanit Apr 26, 2023
e0cd326
Revert "dumb bug that probably doesnt matter since we will remove Tar…
mivanit Apr 26, 2023
15070b6
Zanj datasets getitem (#182)
valedan Apr 26, 2023
88402cd
format
mivanit Apr 28, 2023
e8b7196
format
mivanit Apr 28, 2023
04486da
Constrained dfs, dataset modifications (#184)
canrager Apr 28, 2023
6d942ef
Merge branch 'zanj-integration-datasets' of https://github.com/AISC-u…
mivanit Apr 28, 2023
54ff5a0
fixed maze dataset config hash usage, removed print from parallel wor…
mivanit Apr 28, 2023
c99f652
format
mivanit Apr 28, 2023
e30f3f0
fixed notebook test
mivanit Apr 28, 2023
e58c348
bumpy pytest to 7.3.1 to resolve missing 'mocker' fixture
mivanit Apr 28, 2023
e2f9039
fix biased baseline
valedan Apr 28, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions maze_transformer/evaluation/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def find_config(folder: Path) -> Path | tuple[Path, Path] | None:
def load_model_with_configs(
model_path: Path,
verbose: bool = False,
fold_ln: bool = True,
) -> tuple[HookedTransformer, ConfigHolder]:
"""
Load a model and associated config files from a path.
Expand Down Expand Up @@ -89,7 +90,7 @@ def load_model_with_configs(
# will complain about the fact that we deleted layernorm from the state_dict
# NOTE temporary fix until https://github.com/neelnanda-io/TransformerLens/issues/219 is resolved

model.process_weights_(fold_ln=True)
model.process_weights_(fold_ln=fold_ln)
model.setup() # Re-attach layernorm hooks by calling setup
model.eval()

Expand Down Expand Up @@ -152,8 +153,9 @@ def evaluate_model(
for batch in chunks(dataset.mazes_tokens, batch_size):
# TODO: This won't be needed after #124, then we can call mazes_objs instead
# https://github.com/orgs/AISC-understanding-search/projects/1/views/1?pane=issue&itemId=23879308
solved_mazes = [SolvedMaze.from_tokens(tokens, dataset.cfg) for tokens in batch]
mazes, solutions = zip(*solved_mazes)
solved_mazes: SolvedMaze = [
SolvedMaze.from_tokens(tokens, dataset.cfg) for tokens in batch
]

predictions = predict_maze_paths(
tokens_batch=batch,
Expand All @@ -166,12 +168,12 @@ def evaluate_model(
for name, func in eval_functions.items():
score_counters[name].update(
func(
maze=maze,
solution=np.array(solution),
maze=sm.maze,
solution=np.array(sm.solution),
prediction=np.array(prediction),
model=model,
)
for maze, solution, prediction in zip(mazes, solutions, predictions)
for sm, prediction in zip(solved_mazes, predictions)
)

return score_counters
34 changes: 18 additions & 16 deletions maze_transformer/evaluation/path_evals.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
from typing import Iterable, Optional, Protocol, TypeAlias
from typing import Iterable, Optional, Protocol

import numpy as np
from jaxtyping import Int

from maze_transformer.generation.constants import Coord, CoordTup
from maze_transformer.generation.constants import Coord, CoordArray, CoordTup
from maze_transformer.generation.lattice_maze import LatticeMaze
from maze_transformer.utils.utils import register_method

# pylint: disable=unused-argument
MazePath: TypeAlias = Int[np.ndarray, "node x_y_pos"]


class PathEvalFunction(Protocol):
def __call__(
self,
maze: Optional[LatticeMaze] = None,
solution: Optional[MazePath] = None,
prediction: Optional[MazePath] = None,
solution: Optional[CoordArray] = None,
prediction: Optional[CoordArray] = None,
) -> float:
...


def path_as_segments_iter(path: MazePath) -> Iterable[tuple]:
def path_as_segments_iter(path: CoordArray) -> Iterable[tuple]:
"""
Iterate over the segments of a path (ie each consecutive pair).
"""
Expand All @@ -44,7 +42,7 @@ class PathEvals:

@register_method(evals)
@staticmethod
def node_overlap(solution: MazePath, prediction: MazePath, **_) -> float:
def node_overlap(solution: CoordArray, prediction: CoordArray, **_) -> float:
"""number of shared nodes (any order) / total number of (unique) nodes in solution"""

solution_set = {tuple(coord) for coord in solution}
Expand All @@ -54,7 +52,7 @@ def node_overlap(solution: MazePath, prediction: MazePath, **_) -> float:

@register_method(evals)
@staticmethod
def num_connections_adjacent_lattice(prediction: MazePath, **_) -> float:
def num_connections_adjacent_lattice(prediction: CoordArray, **_) -> float:
"""number of the connections in prediction which actually connect nodes that are adjacent on the lattice, ignoring if they are adjacent on the maze"""
n_adj: float = 0.0
for step_start, step_end in path_as_segments_iter(prediction):
Expand All @@ -65,14 +63,16 @@ def num_connections_adjacent_lattice(prediction: MazePath, **_) -> float:

@register_method(evals)
@staticmethod
def fraction_connections_adjacent_lattice(prediction: MazePath, **_) -> float:
def fraction_connections_adjacent_lattice(prediction: CoordArray, **_) -> float:
"""fraction of the connections in prediction which actually connect nodes that are adjacent on the lattice, ignoring if they are adjacent on the maze"""

return PathEvals.num_connections_adjacent_lattice(prediction) / len(prediction)

@register_method(evals)
@staticmethod
def num_connections_adjacent(maze: LatticeMaze, prediction: MazePath, **_) -> float:
def num_connections_adjacent(
maze: LatticeMaze, prediction: CoordArray, **_
) -> float:
"""number of connections in prediction which are are valid paths on the maze"""
n_connected: float = 0.0
for step_start, step_end in path_as_segments_iter(prediction):
Expand All @@ -84,7 +84,7 @@ def num_connections_adjacent(maze: LatticeMaze, prediction: MazePath, **_) -> fl
@register_method(evals)
@staticmethod
def fraction_connections_adjacent(
maze: LatticeMaze, prediction: MazePath, **_
maze: LatticeMaze, prediction: CoordArray, **_
) -> float:
"""fraction of connections in prediction which are are valid paths on the maze"""

Expand All @@ -95,20 +95,22 @@ def fraction_connections_adjacent(

@register_method(evals)
@staticmethod
def exact_path_predicted(solution: MazePath, prediction: MazePath, **_) -> float:
def exact_path_predicted(
solution: CoordArray, prediction: CoordArray, **_
) -> float:
"""Was the maze successfully solved?"""
return float(np.array_equal(solution, prediction))

@register_method(evals)
@staticmethod
def solution_length(solution: MazePath, **_) -> float:
def solution_length(solution: CoordArray, **_) -> float:
return float(len(solution))

@register_method(evals)
@staticmethod
def streak_length_until_incorrect(
solution: MazePath,
prediction: MazePath,
solution: CoordArray,
prediction: CoordArray,
**_,
) -> float:
"""How many moves until the predicted path deviates from the solution"""
Expand Down
84 changes: 44 additions & 40 deletions maze_transformer/evaluation/plot_maze.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Float
from jaxtyping import Bool, Float
from matplotlib.cm import ScalarMappable
from matplotlib.colors import ListedColormap, Normalize
from muutils.tensor_utils import NDArray

from maze_transformer.generation.constants import Coord, CoordArray, CoordList
from maze_transformer.generation.lattice_maze import Coord, CoordArray, LatticeMaze
from maze_transformer.generation.lattice_maze import (
Coord,
CoordArray,
LatticeMaze,
SolvedMaze,
TargetedLatticeMaze,
)

MAX_NODE_VALUE_EPSILON: float = 1e-10

Expand Down Expand Up @@ -113,16 +119,16 @@ class MazePlot:
"slategrey",
]

def __init__(self, maze: LatticeMaze) -> None:
def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
"""
UNIT_LENGTH: Set ratio between node size and wall thickness in image.
Wall thickness is fixed to 1px
A "unit" consists of a single node and the right and lower connection/wall.
Example: ul = 14 yields 13:1 ratio between node size and wall thickness
"""
self.unit_length: int = 14
self.unit_length: int = unit_length
self.maze: LatticeMaze = maze
self.true_path: StyledPath = None
self.true_path: StyledPath | None = None
self.predicted_paths: list[StyledPath] = []
self.node_values: Float[np.ndarray, "grid_n grid_n"] = None
self.custom_node_value_flag: bool = False
Expand All @@ -131,6 +137,23 @@ def __init__(self, maze: LatticeMaze) -> None:
self.target_token_coord: Coord = None
self.preceding_tokens_coords: CoordArray = None

if isinstance(maze, TargetedLatticeMaze):
self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

if isinstance(maze, SolvedMaze):
self.add_true_path(maze.solution)

@property
def solved_maze(self) -> SolvedMaze:
if self.true_path is None:
raise ValueError(
"Cannot return SolvedMaze object without true path. Add true path with add_true_path method."
)
return SolvedMaze.from_lattice_maze(
lattice_maze=self.maze,
solution=self.true_path.path,
)

def add_true_path(
self,
path: CoordList | CoordArray | StyledPath,
Expand Down Expand Up @@ -227,12 +250,6 @@ def plot(self, dpi: int = 100, title: str = "") -> MazePlot:
self.ax.set_ylabel("row")
self.fig.suptitle(title)

def show(self, dpi: int = 100, title: str = "") -> None:
"""Plot the maze and paths and show the plot. DONT USE THIS IN NOTEBOOKS WHICH NEED TO BE TESTED IN CI!!!"""
self.plot(dpi=dpi, title=title)
plt.show()
return self

def _rowcol_to_coord(self, point: Coord) -> NDArray:
"""Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis."""
point = np.array([point[1], point[0]])
Expand Down Expand Up @@ -294,7 +311,10 @@ def _plot_maze(self) -> None:

self.ax.imshow(img, cmap=cmap, vmin=-1, vmax=1)

def _lattice_maze_to_img(self) -> NDArray["row col", bool]:
def _lattice_maze_to_img(
self,
connection_val_scale: float = 0.93,
) -> Bool[np.ndarray, "row col"]:
"""
Build an image to visualise the maze.
Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
Expand All @@ -318,7 +338,7 @@ def _lattice_maze_to_img(self) -> NDArray["row col", bool]:
# Set node and connection values
if self.node_values is None:
scaled_node_values = np.ones(self.maze.grid_shape)
connection_values = scaled_node_values * 0.93
connection_values = scaled_node_values * connection_val_scale
else:
# Normalizing node colors to match color_map running in (-1, 1) (defined in ._plot_maze()).
scaled_node_values = self.node_values / self.max_node_value
Expand Down Expand Up @@ -403,30 +423,14 @@ def _plot_path(self, path_format: PathFormat) -> None:
ms=10,
)

def as_ascii(self, start=None, end=None):
"""
Returns an ASCII visualization of the maze.
Courtesy of ChatGPT
"""
wall_char = "#"
path_char = " "
self.unit_length = 2

# Determine the size of the maze
maze = self._lattice_maze_to_img()
n_rows, n_cols = maze.shape
maze_str = ""

# Iterate through each element of the maze and print the appropriate symbol
for i in range(n_rows):
for j in range(n_cols):
if start is not None and start[0] == i - 1 and start[1] == j - 1:
maze_str += "S"
elif end is not None and end[0] == i - 1 and end[1] == j - 1:
maze_str += "E"
elif maze[i, j] == -1:
maze_str += wall_char
else:
maze_str += path_char
maze_str += "\n" # Start a new line after each row
return maze_str
def to_ascii(
self,
show_endpoints: bool = True,
show_solution: bool = True,
) -> str:
if self.true_path:
return self.solved_maze.as_ascii(
show_endpoints=show_endpoints, show_solution=show_solution
)
else:
return self.maze.as_ascii(show_endpoints=show_endpoints)
44 changes: 44 additions & 0 deletions maze_transformer/evaluation/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
from muutils.zanj.torchutil import ConfigMismatchException, assert_model_cfg_equality

from maze_transformer.training.config import ZanjHookedTransformer


def assert_model_output_equality(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should live in tests/helpers

model_a: ZanjHookedTransformer, model_b: ZanjHookedTransformer
):
try:
assert_model_cfg_equality(model_a, model_b)
except ConfigMismatchException as e:
if e.diff == {
"model_cfg": {"are_weights_processed": {"self": False, "other": True}}
} or e.diff == {
"model_cfg": {
"are_layernorms_folded": {"self": False, "other": True},
"are_weights_processed": {"self": False, "other": True},
}
}:
pass
else:
raise e

# Random input tokens
dataset_cfg = model_a.zanj_model_config.dataset_cfg
input_sequence = torch.randint(
low=0,
high=len(dataset_cfg.token_arr),
size=(1, min(dataset_cfg.seq_len_max, 10)),
)

# (copied from `test_eval_model.py`)
# Check for equality in argsort (absolute values won't be equal due to centering the unembedding weight matrix)
assert torch.all(
model_a(input_sequence.clone()).argsort()
== model_b(input_sequence.clone()).argsort()
)
# apply normalization (e.g. softmax) and check with atol v-small
# (roughly 1E-7 for float error on logexp I think)
output_a = torch.nn.functional.softmax(model_a(input_sequence.clone()), dim=-1)
output_b = torch.nn.functional.softmax(model_b(input_sequence.clone()), dim=-1)

assert torch.allclose(output_a, output_b, atol=1e-7)
10 changes: 5 additions & 5 deletions maze_transformer/generation/constants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import numpy as np
from muutils.tensor_utils import NDArray
from jaxtyping import Int8

Coord = NDArray["x y", np.int8]
Coord = Int8[np.ndarray, "x y"]
CoordTup = tuple[int, int]
CoordArray = NDArray["coords", np.int8]
CoordArray = Int8[np.ndarray, "coord x y"]
CoordList = list[CoordTup]

SPECIAL_TOKENS: dict[str, str] = dict(
Expand All @@ -20,7 +20,7 @@
padding="<PADDING>",
)

DIRECTIONS_MAP: NDArray["direction axes", int] = np.array(
DIRECTIONS_MAP: Int8[np.ndarray, "direction axes"] = np.array(
[
[0, 1], # down
[0, -1], # up
Expand All @@ -30,7 +30,7 @@
)


NEIGHBORS_MASK: NDArray["coord point", int] = np.array(
NEIGHBORS_MASK: Int8[np.ndarray, "coord point"] = np.array(
[
[0, 1], # down
[0, -1], # up
Expand Down
Loading