Support maze dataset tokenizers update #214

Open · wants to merge 58 commits into base: main
Changes from 15 commits

58 commits
0ff866e
Add check on <UNK> token for `maze-dataset` update
aaron-sandoval Apr 18, 2024
ac2c09c
Update dependencies, including `maze-dataset = "^1.0.0"`
aaron-sandoval Apr 25, 2024
a52baca
maze-dataset PR #37 moved token_utils.py and util.py to a different d…
aaron-sandoval May 17, 2024
404e49f
Update mostly just type hints for `MazeTokenizerModular`. No updates …
aaron-sandoval Jun 29, 2024
8424609
Updated unit tests to incorporate `MazeTokenizerModular`. Not run yet
aaron-sandoval Jun 29, 2024
58b6ef5
made a comment
aaron-sandoval Jul 22, 2024
283bdd0
wip making a make recipe to run tests with a user-provided branch of …
aaron-sandoval Jul 22, 2024
eab11f4
wip, bogged down in Windows vs Linux crap
aaron-sandoval Jul 22, 2024
f83e169
wip, still stuck
aaron-sandoval Jul 22, 2024
c0a48a2
m-d git branch environment specified in maze-dataset_test directory
aaron-sandoval Jul 26, 2024
d34e6f5
Environment was broken in subdirectory. Move it to the main environment
aaron-sandoval Jul 26, 2024
f4f9303
Merge branch 'update-maze-dataset-tokenizers-step2' into add-maze-dat…
aaron-sandoval Jul 26, 2024
667c6ce
Merge pull request #216 from understanding-search/add-maze-dataset-br…
aaron-sandoval Jul 26, 2024
dc30b32
Merge branch 'update-maze-dataset-tokenizers-step2' of https://github…
aaron-sandoval Jul 26, 2024
6711a7e
Small edits to get unit tests to collect
aaron-sandoval Jul 26, 2024
fb0b1da
Merge branch 'main' into update-maze-dataset-tokenizers-step2
mivanit Jul 26, 2024
02537b8
bump maze-dataset
mivanit Jul 26, 2024
1b39086
run format
mivanit Jul 26, 2024
ec748ef
fix imports, unit tests collect
mivanit Jul 26, 2024
74df787
upstream mmtokenizer summary() fix
mivanit Jul 26, 2024
34e9e74
?????????
mivanit Jul 26, 2024
24ca2c4
run format
mivanit Jul 26, 2024
7138677
legacy mt was loaded as mmt by mistake
mivanit Jul 26, 2024
6c3cce2
re-run nb
mivanit Jul 26, 2024
3c244b0
fix loading maze tokenizers
mivanit Jul 26, 2024
ced372d
update dep
mivanit Jul 26, 2024
af3f953
`test_tokenization_encoding` passing
aaron-sandoval Aug 1, 2024
e2f94ac
`test_tokenizer_inside_hooked_transformer` passing
aaron-sandoval Aug 1, 2024
0ddfa23
`test_cfg_post_init` passing
aaron-sandoval Aug 1, 2024
fa825c0
Everything in `test_config_holder.py` passing
aaron-sandoval Aug 1, 2024
2c1d19e
`test_random_baseline` passing. 2 zanj tests are the only ones still …
aaron-sandoval Aug 1, 2024
6e585d3
format
aaron-sandoval Aug 1, 2024
be99a06
zanj save load tests with multiple tokenizers
mivanit Aug 13, 2024
4157519
Merge branch 'main' into update-maze-dataset-tokenizers-step2
mivanit Aug 20, 2024
9e7b888
poetry update
mivanit Aug 20, 2024
3a589b1
fix failing model loading tests
mivanit Aug 20, 2024
3956867
integration test where too many vocab elements caused argsort fail
mivanit Aug 20, 2024
0650190
trained new demo model
mivanit Aug 20, 2024
abc3dcc
replaced demo model path in tests, changed notebook cfg to test
mivanit Aug 20, 2024
8311eb8
move training tests to test_train_model.py
mivanit Aug 20, 2024
0064898
trying to fix pytest hang issue by closing wandb run
mivanit Aug 20, 2024
7cf67a7
return logger in TrainingResult object from train_model
mivanit Aug 21, 2024
fc42faa
update maze-dataset, new version should maybe fix wandb issues?
mivanit Aug 21, 2024
104e004
re-run notebook to get new model with fixed keys
mivanit Aug 21, 2024
3e254b2
changed cfg back to test in train nb, re-run
mivanit Aug 21, 2024
9f15e7e
format
mivanit Aug 21, 2024
4df4a24
update maze-dataset dep
mivanit Aug 22, 2024
4b3e8ca
ok this bug is incomprehensible
mivanit Aug 22, 2024
98a2083
fixed bug - passing configs passed by ref and modified
mivanit Aug 22, 2024
85ee878
format
mivanit Aug 22, 2024
1f46532
fix paths in notebooks
mivanit Aug 22, 2024
3f795f7
remove old stuff from makefile
mivanit Aug 22, 2024
096f6df
update dep
mivanit Aug 23, 2024
97fba7b
update dep
mivanit Aug 27, 2024
0d05191
update dep
mivanit Aug 27, 2024
04039a4
update deps??
mivanit Aug 27, 2024
31442ec
update dep
mivanit Aug 27, 2024
16e60e4
update dep to maze-dataset 1.0.0
mivanit Aug 28, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -7,6 +7,7 @@ tests/_temp/**
tests/**/_temp/**
notebooks/data/**
notebooks/plots/**
# maze-dataset_test/**

.coverage
htmlcov/
6 changes: 6 additions & 0 deletions makefile
@@ -79,6 +79,12 @@ clean:
python -Bc "import pathlib; [p.rmdir() for p in pathlib.Path('.').rglob('__pycache__')]"


.PHONY: test_with_branch
test_with_branch:
@echo "creating test environment"
cp pyproject.toml poetry.lock maze-dataset_test


# listing targets, from stackoverflow
# https://stackoverflow.com/questions/4219255/how-do-you-get-the-list-of-targets-in-a-makefile
.PHONY: help
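
The `test_with_branch` recipe above is still work in progress (see the wip commits); a rough Python sketch of what it seems to be aiming for is below. This is editor-added illustration, not code from this PR: the repository URL, the pip-based install, and the pytest invocation are assumptions, while the real recipe copies pyproject.toml and poetry.lock into maze-dataset_test and manages that environment with poetry.

import subprocess
import sys


def test_with_branch(branch: str) -> int:
    """install a user-provided maze-dataset branch, then run the test suite"""
    # assumed repository URL, used only for illustration
    repo = f"git+https://github.com/understanding-search/maze-dataset.git@{branch}"
    subprocess.run([sys.executable, "-m", "pip", "install", repo], check=True)
    return subprocess.run([sys.executable, "-m", "pytest", "tests/"]).returncode


if __name__ == "__main__":
    branch = sys.argv[1] if len(sys.argv) > 1 else "main"
    raise SystemExit(test_with_branch(branch))
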
2 changes: 1 addition & 1 deletion maze_transformer/evaluation/baseline_models.py
@@ -11,7 +11,7 @@
LatticeMaze,
SolvedMaze,
)
from maze_dataset.tokenization.token_utils import (
from maze_dataset.token_utils import (
get_origin_tokens,
get_path_tokens,
get_target_tokens,
7 changes: 4 additions & 3 deletions maze_transformer/evaluation/eval_model.py
@@ -14,8 +14,9 @@
MazeDatasetConfig,
SolvedMaze,
)
from maze_dataset.tokenization import MazeTokenizer
from maze_dataset.tokenization.token_utils import (
from maze_dataset.tokenization import MazeTokenizer, MazeTokenizerModular
from maze_dataset.token_utils import (
WhenMissing,
get_context_tokens,
get_path_tokens,
remove_padding_from_token_str,
@@ -143,7 +144,7 @@ def predict_maze_paths(
smart_max_new_tokens
), "if max_new_tokens is None, smart_max_new_tokens must be True"

maze_tokenizer: MazeTokenizer = model.tokenizer._maze_tokenizer
maze_tokenizer: MazeTokenizer | MazeTokenizerModular = model.config.maze_tokenizer

contexts_lists: list[list[str]] = [
get_context_tokens(tokens) for tokens in tokens_batch
6 changes: 3 additions & 3 deletions maze_transformer/evaluation/eval_single_token_tasks.py
@@ -12,7 +12,7 @@
# Our Code
# dataset stuff
from maze_dataset import MazeDataset
from maze_dataset.tokenization import MazeTokenizer
from maze_dataset.tokenization import MazeTokenizer, MazeTokenizerModular
from muutils.json_serialize import SerializableDataclass, serializable_dataclass

# TransformerLens imports
@@ -47,7 +47,7 @@ class TaskEvalResult(SerializableDataclass):

def get_task_prompts_targets(
dataset: MazeDataset,
maze_tokenizer: MazeTokenizer,
maze_tokenizer: MazeTokenizer | MazeTokenizerModular,
tasks: dict[str, DLAProtocolFixed] = LOGIT_ATTRIB_TASKS,
) -> dict[str, TaskPrompt]:
dataset_tokens: list[list[str]] = dataset.as_tokens(
@@ -63,7 +63,7 @@ def eval_model_task(
task: TaskPrompt,
do_cache: bool = False,
) -> TaskEvalResult:
maze_tokenizer: MazeTokenizer = model.tokenizer._maze_tokenizer
maze_tokenizer: MazeTokenizer | MazeTokenizerModular = model.config.maze_tokenizer

prompts_joined: list[str] = [" ".join(prompt) for prompt in task.prompts]

4 changes: 2 additions & 2 deletions maze_transformer/mechinterp/direct_logit_attribution.py
@@ -13,7 +13,7 @@

# maze-datset stuff
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.tokenization import MazeTokenizer
from maze_dataset.tokenization import MazeTokenizer, MazeTokenizerModular

# TransformerLens imports
from transformer_lens import ActivationCache
@@ -226,7 +226,7 @@ def create_report(
# model and tokenizer
if not isinstance(model, ZanjHookedTransformer):
model = ZanjHookedTransformer.read(model)
tokenizer: MazeTokenizer = model.zanj_model_config.maze_tokenizer
tokenizer: MazeTokenizer | MazeTokenizerModular = model.zanj_model_config.maze_tokenizer

# dataset cfg
if dataset_cfg_source is None:
8 changes: 4 additions & 4 deletions maze_transformer/mechinterp/plot_attention.py
@@ -18,8 +18,8 @@
from maze_dataset.plotting import MazePlot
from maze_dataset.plotting.plot_tokens import plot_colored_text
from maze_dataset.plotting.print_tokens import color_tokens_cmap
from maze_dataset.tokenization import MazeTokenizer
from maze_dataset.tokenization.util import coord_str_to_tuple_noneable
from maze_dataset.tokenization import MazeTokenizer, MazeTokenizerModular
from maze_dataset.token_utils import coord_str_to_tuple_noneable

# Utilities
from muutils.json_serialize import SerializableDataclass, serializable_dataclass
@@ -377,7 +377,7 @@ def mazeplot_attention(
def plot_attn_dist_correlation(
tokens_context: list[list[str]],
tokens_dist_to: list[str], # either current or target token for each maze
tokenizer: MazeTokenizer,
tokenizer: MazeTokenizer | MazeTokenizerModular,
attention: Float[np.ndarray, "n_mazes n_tokens"],
ax: plt.Axes | None = None,
respect_topology: bool = False, # manhattan distance if False
@@ -480,7 +480,7 @@ def plot_attention_final_token(
prompts: list[list[str]],
targets: list[str],
mazes: list[SolvedMaze],
tokenizer: MazeTokenizer,
tokenizer: MazeTokenizer | MazeTokenizerModular,
n_mazes: int = 5,
last_n_tokens: int = 20,
# exponentiate_scores: bool = False,
6 changes: 3 additions & 3 deletions maze_transformer/mechinterp/plot_logits.py
@@ -6,7 +6,7 @@
from maze_dataset import CoordTup

# Our Code
from maze_dataset.tokenization import MazeTokenizer
from maze_dataset.tokenization import MazeTokenizer, MazeTokenizerModular

_DEFAULT_SUBPLOTS_KWARGS: dict = dict(
figsize=(20, 20),
@@ -86,7 +86,7 @@ def plot_logit_histograms(

def get_baseline_incorrect_group(
prompts: list[list[str]],
tokenizer: MazeTokenizer,
tokenizer: MazeTokenizer | MazeTokenizerModular,
baseline: "RandomBaseline",
) -> Bool[torch.Tensor, "n_mazes d_vocab"]:
"""
@@ -116,7 +116,7 @@ def get_baseline_incorrect_group(
def plot_logits(
last_tok_logits: Float[torch.Tensor, "n_mazes d_vocab"],
target_idxs: Int[torch.Tensor, "n_mazes"],
tokenizer: MazeTokenizer,
tokenizer: MazeTokenizer | MazeTokenizerModular,
n_bins: int = 50,
mark_incorrect: bool = False,
mark_correct: bool = True,
12 changes: 6 additions & 6 deletions maze_transformer/mechinterp/residual_stream_structure.py
@@ -11,8 +11,8 @@

# maze_dataset
from maze_dataset.constants import _SPECIAL_TOKENS_ABBREVIATIONS
from maze_dataset.tokenization import MazeTokenizer
from maze_dataset.tokenization.util import strings_to_coords
from maze_dataset.tokenization import MazeTokenizer, MazeTokenizerModular
from maze_dataset.token_utils import strings_to_coords

# scipy
from scipy.spatial.distance import pdist, squareform
@@ -52,7 +52,7 @@ def coordinate_to_color(
)


def process_tokens_for_pca(tokenizer: MazeTokenizer) -> list[TokenPlottingInfo]:
def process_tokens_for_pca(tokenizer: MazeTokenizer | MazeTokenizerModular) -> list[TokenPlottingInfo]:
tokens_coords: list[str | tuple[int, int]] = strings_to_coords(
tokenizer.token_arr, when_noncoord="include"
)
@@ -227,7 +227,7 @@ def abs_dot_product(u, v):

def compute_distances_and_correlation(
embedding_matrix: Float[np.ndarray, "d_vocab d_model"],
tokenizer: MazeTokenizer,
tokenizer: MazeTokenizer | MazeTokenizerModular,
embedding_metric: str = "cosine",
coordinate_metric: str = "euclidean",
show: bool = True,
@@ -277,7 +277,7 @@ def compute_distances_and_correlation(

def plot_distances_matrix(
embedding_distances_matrix: Float[np.ndarray, "n_coord_tokens n_coord_tokens"],
tokenizer: MazeTokenizer,
tokenizer: MazeTokenizer | MazeTokenizerModular,
embedding_metric: str,
show: bool = True,
**kwargs,
@@ -313,7 +313,7 @@ def plot_distances_matrix(

def compute_grid_distances(
embedding_distances_matrix: Float[np.ndarray, "n_coord_tokens n_coord_tokens"],
tokenizer: MazeTokenizer,
tokenizer: MazeTokenizer | MazeTokenizerModular,
) -> Float[np.ndarray, "n n n n"]:
n: int = tokenizer.max_grid_size
grid_distances: Float[np.ndarray, "n n n n"] = np.full((n, n, n, n), np.nan)
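
A note on the relocated helper used above: `strings_to_coords` (now imported from `maze_dataset.token_utils`) is what `process_tokens_for_pca` relies on to separate coordinate tokens from special tokens. A small editor-added example follows; the token strings and the printed result are illustrative only, not taken from a real tokenizer.

from maze_dataset.token_utils import strings_to_coords

# illustrative token strings; a real token_arr comes from the maze tokenizer
tokens = ["<PATH_START>", "(0,0)", "(0,1)", "<PATH_END>"]

# with when_noncoord="include", non-coordinate tokens pass through as strings
# while coordinate tokens are parsed into (row, col) tuples
print(strings_to_coords(tokens, when_noncoord="include"))
# e.g. ['<PATH_START>', (0, 0), (0, 1), '<PATH_END>']
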
11 changes: 6 additions & 5 deletions maze_transformer/tokenizer.py
@@ -4,7 +4,7 @@
import torch
from maze_dataset import SPECIAL_TOKENS, LatticeMaze
from maze_dataset.plotting import MazePlot
from maze_dataset.tokenization import MazeTokenizer
from maze_dataset.tokenization import MazeTokenizer, MazeTokenizerModular
from muutils.tensor_utils import ATensor, NDArray
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils import BatchEncoding
@@ -46,13 +46,13 @@ def apply_overrides(self) -> None:
def __init__(
self,
seq_len_max: int,
maze_tokenizer: MazeTokenizer,
maze_tokenizer: MazeTokenizer | MazeTokenizerModular,
**kwargs,
) -> None:
"""extension of PreTrainedTokenizer for mazes. takes maximum sequence length and maze_tokenizer. also, kwargs are passed to super `PreTrainedTokenizer`"""
super().__init__(max_len=seq_len_max, **kwargs)

self._maze_tokenizer: MazeTokenizer = maze_tokenizer
self._maze_tokenizer: MazeTokenizer | MazeTokenizerModular = maze_tokenizer
token_arr: list[str] = maze_tokenizer.token_arr
self._token_arr: list[str] = token_arr
self._seq_len_max: int = seq_len_max
@@ -81,8 +81,9 @@ def __init__(

# We are having to do evil things here
vocab: dict[str, int] = {token: i for i, token in enumerate(token_arr)}
vocab[self.unk_token] = len(vocab)
self.vocab: dict[str, int] = vocab
if self.unk_token not in vocab: # maze-dataset ^1.0.0 includes <UNK> already
vocab[self.unk_token] = len(vocab)
self.vocab: dict[str, int] = vocab

special_tokens = list(SPECIAL_TOKENS.values())
normal_tokens = [x for x in token_arr if x not in special_tokens]
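
The guard added above avoids appending a second `<UNK>` entry now that maze-dataset ^1.0.0 ships the token in `token_arr` itself. A standalone, editor-added sketch of the same logic, using a made-up token list rather than a real tokenizer:

unk_token: str = "<UNK>"
# illustrative token_arr; maze-dataset ^1.0.0 tokenizers already include <UNK>
token_arr: list[str] = ["<ADJLIST_START>", "<ADJLIST_END>", "(0,0)", "<UNK>"]

vocab: dict[str, int] = {token: i for i, token in enumerate(token_arr)}
if unk_token not in vocab:  # only legacy token lists need the extra entry
    vocab[unk_token] = len(vocab)

# the id stays stable and no duplicate entry is created
assert vocab[unk_token] == token_arr.index(unk_token)
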
25 changes: 10 additions & 15 deletions maze_transformer/training/config.py
@@ -10,7 +10,7 @@
import torch
from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS
from maze_dataset.dataset.dataset import GPTDatasetConfig
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode, MazeTokenizerModular
from muutils.dictmagic import kwargs_to_nested_dict
from muutils.json_serialize import (
JSONitem,
@@ -370,17 +370,14 @@ def summary(self) -> dict:
}


def _load_maze_tokenizer(data: dict) -> MazeTokenizer:
def _load_maze_tokenizer(data: dict) -> MazeTokenizerModular:
"""load the maze tokenizer, including vocab size from a legacy config"""
if "maze_tokenizer" in data:
# new style tokenizer
return load_item_recursive(data["maze_tokenizer"], path=tuple("maze_tokenizer"))
return MazeTokenizerModular.from_legacy(load_item_recursive(data["maze_tokenizer"], path=tuple("maze_tokenizer")))
else:
if "token_arr" in data["dataset_cfg"]:
output: MazeTokenizer = MazeTokenizer(
tokenization_mode=TokenizationMode.AOTP_UT_rasterized,
max_grid_size=None,
)
output: MazeTokenizerModular = MazeTokenizerModular()
else:
raise ValueError("Could not find vocab size in legacy config")

@@ -405,7 +402,7 @@ class ConfigHolder(SerializableDataclass):
pretrainedtokenizer_kwargs: dict[str, JSONitem] | None = serializable_field(
default=None
)
maze_tokenizer: MazeTokenizer | None = serializable_field(
maze_tokenizer: MazeTokenizer | MazeTokenizerModular | None = serializable_field(
default_factory=lambda: None,
loading_fn=_load_maze_tokenizer,
)
@@ -434,24 +431,22 @@ def n_heads(self) -> int:
return self.model_cfg.n_heads

def _set_tok_gridsize_from_dataset(self):
self.maze_tokenizer.max_grid_size = self.dataset_cfg.max_grid_n
self.maze_tokenizer.clear_cache()
if isinstance(self.maze_tokenizer, MazeTokenizer):
self.maze_tokenizer.max_grid_size = self.dataset_cfg.max_grid_n
self.maze_tokenizer.clear_cache()

def __post_init__(self):
# fallback to default maze tokenizer if no kwargs are provided
if self.pretrainedtokenizer_kwargs is None:
if self.maze_tokenizer is None:
# TODO: is this the right default? maybe set it to AOTP_UT_rasterized
# since thats what legacy models are likely to be?
self.maze_tokenizer = MazeTokenizer(
tokenization_mode=TokenizationMode.AOTP_UT_uniform,
max_grid_size=None,
)
self.maze_tokenizer = MazeTokenizerModular()

# update the config of the maze tokenizer if there is no grid size
# since we need the token array for the vocab size of the model
if self.maze_tokenizer is not None:
if self.maze_tokenizer.max_grid_size is None:
if getattr(self.maze_tokenizer, "max_grid_size", None) is None:
self._set_tok_gridsize_from_dataset()

def summary(self) -> str:
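
A minimal editor-added sketch of the upgrade path `_load_maze_tokenizer` now takes, assuming maze-dataset ^1.0.0. The constructor arguments mirror the legacy defaults shown in this diff; that the modular tokenizer may not expose `max_grid_size` is an inference from the `getattr` fallback in `__post_init__`, not a documented guarantee.

from maze_dataset.tokenization import (
    MazeTokenizer,
    MazeTokenizerModular,
    TokenizationMode,
)

# legacy-style tokenizer, as older configs would have serialized it
legacy: MazeTokenizer = MazeTokenizer(
    tokenization_mode=TokenizationMode.AOTP_UT_uniform,
    max_grid_size=None,
)

# converted to the new modular tokenizer when a config is loaded
upgraded: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy)

# legacy tokenizers carry max_grid_size, the modular one may not, hence getattr
print(getattr(legacy, "max_grid_size", None), getattr(upgraded, "max_grid_size", None))
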
2 changes: 1 addition & 1 deletion maze_transformer/training/train_save_files.py
@@ -6,7 +6,7 @@
from maze_transformer.training.config import ConfigHolder


@freeze
Comment from the pull request author:
It looks like freeze can only act on objects, not types. It raises an exception when it recursively calls freeze(TRAIN_SAVE_FILES.__dict__), since it can't act on a mappingproxy.

# @freeze
class TRAIN_SAVE_FILES:
"""namespace for filenames/formats for saving training data"""

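
To illustrate the failure mode described in the comment above: a class `__dict__` is exposed as a read-only `mappingproxy`, so a recursive freeze over it has nothing mutable to act on. The snippet below is an editor-added sketch with a hypothetical stand-in class and filename; muutils itself is not imported, only the mappingproxy behaviour is demonstrated.

from types import MappingProxyType


class TrainSaveFilesDemo:  # hypothetical stand-in for TRAIN_SAVE_FILES
    config_holder: str = "config.json"  # illustrative filename


# a class namespace is exposed as a read-only mappingproxy
assert isinstance(TrainSaveFilesDemo.__dict__, MappingProxyType)

# so writing into it (as a recursive freeze over __dict__ would) raises TypeError
try:
    TrainSaveFilesDemo.__dict__["_frozen"] = True  # type: ignore[index]
except TypeError as err:
    print(f"cannot freeze a class namespace: {err}")
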
4 changes: 2 additions & 2 deletions maze_transformer/training/training.py
@@ -5,7 +5,7 @@
import torch
from jaxtyping import Float
from maze_dataset import MazeDataset, SolvedMaze
from maze_dataset.tokenization import MazeTokenizer
from maze_dataset.tokenization import MazeTokenizer, MazeTokenizerModular
from muutils.statcounter import StatCounter
from torch.utils.data import DataLoader
from transformer_lens.HookedTransformer import SingleLoss
@@ -19,7 +19,7 @@
from maze_transformer.training.wandb_logger import WandbLogger


def collate_batch(batch: list[SolvedMaze], maze_tokenizer: MazeTokenizer) -> list[str]:
def collate_batch(batch: list[SolvedMaze], maze_tokenizer: MazeTokenizer | MazeTokenizerModular) -> list[str]:
return [" ".join(maze.as_tokens(maze_tokenizer)) for maze in batch]


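
A hedged, editor-added sketch of how the updated `collate_batch` might be wired into a DataLoader. The dataset config values, the default `MazeTokenizerModular()`, and the batch size are illustrative assumptions, not code from this PR; real runs use the configs in MAZE_DATASET_CONFIGS.

from functools import partial

from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.tokenization import MazeTokenizerModular
from torch.utils.data import DataLoader

from maze_transformer.training.training import collate_batch

# small illustrative dataset config
cfg: MazeDatasetConfig = MazeDatasetConfig(name="demo", grid_n=5, n_mazes=8)
dataset: MazeDataset = MazeDataset.from_config(cfg)
tokenizer: MazeTokenizerModular = MazeTokenizerModular()

loader: DataLoader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=partial(collate_batch, maze_tokenizer=tokenizer),
)

batch: list[str] = next(iter(loader))
# each element is one solved maze serialized as a space-separated token string
print(batch[0][:80])
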
4 changes: 2 additions & 2 deletions notebooks/appendix_figures.ipynb
@@ -23,7 +23,7 @@
"# dataset stuff\n",
"from maze_dataset import MazeDataset, MazeDatasetConfig, SolvedMaze, LatticeMaze, SPECIAL_TOKENS, LatticeMazeGenerators\n",
"from maze_dataset.plotting import MazePlot, PathFormat\n",
"from maze_dataset.tokenization import MazeTokenizer, TokenizationMode\n",
"from maze_dataset.tokenization import MazeTokenizer, TokenizationMode, MazeTokenizerModular\n",
"from maze_dataset.plotting.print_tokens import color_maze_tokens_AOTP\n",
"\n",
"# model stuff\n",
@@ -139,7 +139,7 @@
"\tfig.savefig(plot_dir / \"rollouts.pdf\", bbox_inches=\"tight\")\n",
"\tplt.show()\n",
"\n",
"\ttokenizer: MazeTokenizer = model.zanj_model_config.maze_tokenizer\n",
"\ttokenizer: MazeTokenizer | MazeTokenizerModular = model.zanj_model_config.maze_tokenizer\n",
"\ttask_prompts_targets: dict[str, TaskPrompt] = get_task_prompts_targets(\n",
"\t\tdataset=dataset,\n",
"\t\tmaze_tokenizer=tokenizer,\n",
4 changes: 2 additions & 2 deletions notebooks/demo_dataset.ipynb
@@ -531,7 +531,7 @@
"source": [
"\n",
"from maze_dataset.plotting import MazePlot\n",
"from maze_dataset.tokenization import MazeTokenizer, TokenizationMode\n",
"from maze_dataset.tokenization import MazeTokenizer, TokenizationMode, MazeTokenizerModular\n",
"from maze_dataset.plotting.print_tokens import display_color_maze_tokens_AOTP, color_maze_tokens_AOTP\n",
"\n",
"maze: SolvedMaze = dataset[0]\n",
@@ -549,7 +549,7 @@
"# as tokens\n",
"\n",
"# first, initialize a tokenizer -- more about this in the `notebooks/demo_tokenization.ipynb` notebook\n",
"tokenizer: MazeTokenizer = MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_UT_rasterized, max_grid_size=100)\n",
"tokenizer: MazeTokenizerModular = MazeTokenizerModular()\n",
"maze_tok = maze.as_tokens(maze_tokenizer=tokenizer)\n",
"\n",
"# you can view the tokens directly\n",
4 changes: 2 additions & 2 deletions notebooks/direct_logit_attribution.ipynb
@@ -89,7 +89,7 @@
"# Our Code\n",
"# dataset stuff\n",
"from maze_dataset import MazeDataset, MazeDatasetConfig, SolvedMaze, LatticeMaze, SPECIAL_TOKENS, LatticeMazeGenerators\n",
"from maze_dataset.tokenization import MazeTokenizer, TokenizationMode\n",
"from maze_dataset.tokenization import MazeTokenizer, TokenizationMode, MazeTokenizerModular\n",
"from maze_dataset.plotting.print_tokens import color_maze_tokens_AOTP\n",
"\n",
"# model stuff\n",
@@ -287,7 +287,7 @@
}
],
"source": [
"TOKENIZER: MazeTokenizer = MODEL.zanj_model_config.maze_tokenizer\n",
"TOKENIZER: MazeTokenizer | MazeTokenizerModular = MODEL.zanj_model_config.maze_tokenizer\n",
"DATASET_TOKENS: list[list[str]] = DATASET.as_tokens(TOKENIZER, join_tokens_individual_maze=False)\n",
"\n",
"# print some info\n",