Skip to content

Commit

Permalink
Add experiments, part 2 (#197)
Browse files Browse the repository at this point in the history
mega-PR, adding a bunch of experiment notebooks and the required code for them.

broad overview:
- added some example models to `examples/`
- reworked eval code, needs big changes -- see #200 
- many modifications to mechinterp code
- had to enforce transformerlens 1.6.1 due to tokenizer changes (it tried to get our custom tokenizer from huggingface?)
- exported some code to muutils
- notebooks added:
  - `eval_tasks_table.ipynb`: evaluate on a bunch of single token tasks. should be merged with other evals notebook
  - `appendix_figures.ipynb`: junk and duplicates of code in other notebooks :/
  - `generate_rollouts.ipynb`: what the name says, a simple notebook

comment history:

* trying to see if wandb model loading is working right

* moved dict shapes to muutils (it's on an unmerged branch though)

* better loading of models from wandb

* wip????????????????

* way more testing for loading wandb models

* aaaa

* ???

* hallway run

* update muutils dep to 0.5.3

* updated TL and maze-dataset dep

* type hint

* notebook runs?

* wip runs

* cleared notebooks?

* exported eval plots

* format

* many fixes and changes sorry

* wip

* poetry lock

* minor adjustment to make model names cleaner

* exported single token tasks

* refactored baseline model, allowed return of multiple options

going to be useful for plot_logits

* more baseline model refactor

* format

* dep?

* train_model test was trying to train on 3M samples lol

* separate appendix figures notebooks, better logits plotting

logits plotting now allows for adding other categories to the histogram besides
correct / incorrect, which we can use the baseline model for

* misc

* rename original hallway model

need to fix refs to it later lol

* WE'RE SO BACK, ADJACENCY HEADS ARE HERE

check the dla notebook!!!

* correlation of attention and distance

* misc

* ok no more figures for now

* temp notebooks, for experiments. move these to paper repo later

* eval tasks table

* final before unireps submit

* misc fixes??

* added padding functionality and batched predictions

* wip

* wip

* wip

* added attention animation plotter

* format

* update deps

* transformerlens 1.6.1 due to issues :/

* cleaning up notebooks

latest versions of some were in experiments repo

* fix up some notebooks, eval_model is still broken

* providing hallway model

* fixing eval_model issues with baseline solver

batching was not working at all, had to add a hack to recursively
call .generate() on RandomBaseline

return type was list[str] instead of tensor or list[list[str]] so
had to fix that as well

* update dep to muutils 0.5.5 (poetry not recognizing it yet)

* format

* poetry lock

* changed model used to hallway

* changed model paths, no jirpy

* update embedding structure nb

* updated plot attention for better cbar

* fix up eval tasks table notebook

* fix when cbar is none

* ran notebook
  • Loading branch information
mivanit authored Dec 8, 2023
1 parent 56947c6 commit 10eac86
Show file tree
Hide file tree
Showing 32 changed files with 5,842 additions and 2,480 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ wandb
tests/_temp/**
tests/**/_temp/**
notebooks/data/**
notebooks/plots/**

.coverage
htmlcov/
Expand Down
File renamed without changes.
Binary file added examples/model.hallway-jvq.final.zanj
Binary file not shown.
148 changes: 121 additions & 27 deletions maze_transformer/evaluation/baseline_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
from maze_transformer.training.config import ConfigHolder


class InvalidTaskForRandomBaselineError(Exception):
    """Raised when a task is not coordinate prediction.

    Such a task is not valid for a random baseline "model".
    """


class RandomBaseline(HookedTransformer):
"""
A model that chooses valid paths and never backtracks, but makes random decisions at forks.
Expand All @@ -48,21 +54,27 @@ def _get_coord_neighbors(
# This conversion won't be needed after https://github.com/understanding-search/maze-transformer/issues/154
return [tuple(arr.tolist()) for arr in neighbors]

def _predict_next_step(
def _get_all_valid_next_steps(
self,
solved_maze: SolvedMaze,
target: CoordTup,
path: list[CoordTup],
path: list[CoordTup] | None = None,
pad_eos: bool = False,
) -> CoordTup | str:
"""returns a tuple coordinate or a special token"""
) -> tuple[CoordTup | str | None, list[CoordTup | str]]:
"""returns a tuple of (correct_step, incorrect_steps)"""

if path is None or len(path) == 0:
return (tuple(solved_maze.start_pos), [])

path_end_return: tuple[str, list[str]] = (SPECIAL_TOKENS.PATH_END, [])

current_position: CoordTup = path[-1]
# pad with eos up to max_new_tokens to avoid ragged tensors
if pad_eos:
if current_position in [target, SPECIAL_TOKENS.PATH_END]:
return SPECIAL_TOKENS.PATH_END
return path_end_return
if current_position == target:
return SPECIAL_TOKENS.PATH_END
return path_end_return

neighbors: list[CoordTup] = self._get_coord_neighbors(
solved_maze, current_position
Expand All @@ -80,29 +92,60 @@ def _predict_next_step(

if len(unvisited_neighbors) == 0:
# break out if dead end
return SPECIAL_TOKENS.PATH_END
return path_end_return
else:
if correct_step not in unvisited_neighbors:
return random.choice(unvisited_neighbors)
return (None, unvisited_neighbors)

incorrect_steps = unvisited_neighbors[:]
incorrect_steps = unvisited_neighbors[:] # what is this doing?
incorrect_steps.remove(correct_step)

prob_of_incorrect = (len(incorrect_steps) / len(unvisited_neighbors)) * (
1 - self.bias
)
return (correct_step, incorrect_steps)

will_choose_correctly = random.random() > prob_of_incorrect
if will_choose_correctly:
return correct_step
else:
return random.choice(incorrect_steps)
def _predict_next_step(
self,
solved_maze: SolvedMaze,
target: CoordTup,
path: list[CoordTup],
pad_eos: bool = False,
) -> CoordTup | str:
"""returns a tuple coordinate or a special token"""

def _generate_path(
correct_step: CoordTup | str | None
incorrect_steps: list[CoordTup | str]
correct_step, incorrect_steps = self._get_all_valid_next_steps(
solved_maze=solved_maze,
target=target,
path=path,
pad_eos=pad_eos,
)

# if only one option, return that
if len(incorrect_steps) == 0:
assert correct_step is not None
return correct_step

# if no correct choice (no backtracking, towards target), return random choice
if correct_step is None:
assert len(incorrect_steps) > 0
return random.choice(incorrect_steps)

# if there is a correct choice, choose randomly between correct and incorrect
n_unvisited_neighbors: int = len(incorrect_steps) + 1
prob_of_incorrect = (len(incorrect_steps) / n_unvisited_neighbors) * (
1 - self.bias
)

will_choose_correctly = random.random() > prob_of_incorrect
if will_choose_correctly:
return correct_step
else:
return random.choice(incorrect_steps)

def _tokens_to_maze_and_path(
self,
tokens: list[str],
steps_to_predict: int,
) -> list[str]:
) -> tuple[SolvedMaze, list[Coord]]:
# assemble the maze from the tokens
maze: LatticeMaze = LatticeMaze.from_tokens(
tokens, self.tokenizer._maze_tokenizer
Expand All @@ -120,6 +163,19 @@ def _generate_path(
when_noncoord="except",
)

return (solved_maze, context_existing_path)

def _generate_path(
self,
tokens: list[str],
steps_to_predict: int,
) -> list[str]:
solved_maze: SolvedMaze
context_existing_path: list[Coord]
solved_maze, context_existing_path = self._tokens_to_maze_and_path(tokens)
origin_coord: CoordTup = tuple(solved_maze.start_pos.tolist())
target_coord: CoordTup = tuple(solved_maze.end_pos.tolist())

# assemble our predicted path
predictions: list[Coord] = list()

Expand All @@ -143,18 +199,15 @@ def _generate_path(
predictions, when_noncoord="include"
)

def generate(
def _process_context(
self,
context: str | list[str] | Float[torch.Tensor, "pos"],
max_new_tokens: int,
**_,
) -> str:
# convert input to a list of tokens
) -> list[str]:
tokens: list[str]
if isinstance(context, torch.Tensor):
tokens = self.to_str_tokens(context)
tokens = self.to_str_tokens(context, prepend_bos=False)
elif isinstance(context, list):
if all(isinstance(x, str) for x in tokens):
if all(isinstance(x, str) for x in context):
tokens = context
else:
raise TypeError(
Expand All @@ -165,6 +218,28 @@ def generate(
else:
raise TypeError(f"Expected list[str], str, or tensor, got {type(context)}")

return tokens

def generate(
self,
context: str | list[str] | Float[torch.Tensor, "pos"],
max_new_tokens: int,
**_,
) -> str:
# hack for more than one batch
if isinstance(context, torch.Tensor):
if context.ndim == 2:
return [
self.generate(
context[i],
max_new_tokens,
)
for i in range(context.shape[0])
]

# convert input to a list of tokens
tokens: list[str] = self._process_context(context)

# generate path
generated_path: list[str] = self._generate_path(
tokens,
Expand All @@ -180,3 +255,22 @@ def generate(

# output: Float[torch.Tensor, "batch pos_plus_new_tokens"] = self.tokenizer(solved_maze, is_split_into_words=True)["input_ids"]
return output

def get_valid_next_steps(
    self,
    context: str | list[str] | Float[torch.Tensor, "pos"],
) -> tuple[CoordTup | str | None, list[CoordTup | str]]:
    """Return the valid continuations for the path encoded in `context`.

    `context` is turned into tokens by `self._process_context`, the tokens
    are turned into a maze and an existing path by
    `self._tokens_to_maze_and_path`, and the result of
    `self._get_all_valid_next_steps` is returned unchanged. The target
    handed to that call is the maze's `end_pos` converted to a tuple.
    """
    maze, walked = self._tokens_to_maze_and_path(self._process_context(context))

    # `end_pos` exposes `.tolist()`; the helper is given a plain tuple instead
    goal: CoordTup = tuple(maze.end_pos.tolist())

    return self._get_all_valid_next_steps(
        solved_maze=maze,
        target=goal,
        path=walked,
    )
Loading

0 comments on commit 10eac86

Please sign in to comment.