Refactor tokenization (#191)
Refactor to be compatible with `maze-dataset` version `0.2.1` and later.

See PRs:
- [`maze_dataset` PR #5](understanding-search/maze-dataset#5)
- [`maze_dataset` PR #6](understanding-search/maze-dataset#6)

See related issues:
- #164 
- #163 
- #77 

These changes also revert the changes in #118, so that underscores appear only once in the special tokens.
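
For orientation, here is a condensed sketch of the post-refactor API, assembled from the diffs below rather than copied from the repo; the helper name `decode_predicted_path` and its arguments are illustrative:

```python
from maze_dataset import SPECIAL_TOKENS, LatticeMaze
from maze_dataset.tokenization import MazeTokenizer
from maze_dataset.tokenization.token_utils import get_path_tokens, strings_to_coords


def decode_predicted_path(tokens: list[str], maze_tokenizer: MazeTokenizer):
    """illustrative only: shows the post-refactor calling conventions"""
    # special tokens are now accessed as attributes (e.g. SPECIAL_TOKENS.PATH_END,
    # SPECIAL_TOKENS.ADJLIST_START) instead of dict-style SPECIAL_TOKENS["path_end"]
    assert SPECIAL_TOKENS.PATH_END in tokens

    # LatticeMaze.from_tokens() now takes the MazeTokenizer explicitly
    maze: LatticeMaze = LatticeMaze.from_tokens(tokens, maze_tokenizer)

    # strings_to_coords() replaces tokens_to_coords(..., maze_data_cfg=...),
    # so no dataset config is needed for token <-> coordinate conversion
    path_coords = strings_to_coords(
        get_path_tokens(tokens, trim_end=True), when_noncoord="skip"
    )
    return maze, path_coords
```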

# commit history:

* test_cfg_post_init working

* migrated SPECIAL_TOKENS usage

* wip

* wip

* wip, all but 3 in tok tests passing

* test_tokenizers passing

* unit tests passing (but need to update maze_dataset dep)

* poetry lock

* format

* remove deprecated kwarg to process_weights_

Upgrading `transformer_lens` to 1.4.0 means
`HookedTransformer.process_weights_()` no longer accepts
the keyword arg `move_state_dict_to_device`.

However, I'm not sure this kwarg was important in the first place.
If any issues come up, move the state dict to the device manually in
`ZanjHookedTransformer._load_state_dict_wrapper()`, which is where all
of this was happening anyway.
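
For reference, a minimal sketch of what that manual move could look like (the helper name, and the assumption that `_load_state_dict_wrapper` has the state dict and target device in hand, are mine, not the repo's):

```python
import torch


def _move_state_dict_to_device(
    state_dict: dict[str, torch.Tensor], device: torch.device | str
) -> dict[str, torch.Tensor]:
    """hypothetical helper: replicates the old `move_state_dict_to_device=True`
    behavior by moving every tensor before the state dict is loaded/processed"""
    return {key: tensor.to(device) for key, tensor in state_dict.items()}
```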

* fixed MazeTokenizer not being passed to as_tokens() in some spots

* updated the changed dataset config key

since we removed tokenizer stuff from the dataset

* fixed eval_model nb, added ZanjHookedTransformer.config ref

The `eval_model.ipynb` notebook has a function `testdata_plot_predicted_path`
which used `model.zanj_model_config` to get the tokenizer, an attribute
missing from the `RandomBaseline` class since that class only inherits
from `HookedTransformer`.

To fix this:

- `ZanjHookedTransformer` now has a `config` property which just
  accesses the `zanj_model_config` used by the parent `ConfiguredModel`
- `testdata_plot_predicted_path` now uses `model.config` everywhere
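
Roughly, that property amounts to the following (a sketch only, not copied from `config.py`; the simplified class here just stands in for the real `ZanjHookedTransformer`):

```python
from maze_transformer.training.config import ConfigHolder


class ZanjHookedTransformerSketch:
    """sketch of just the new property; the real class subclasses the zanj
    ConfiguredModel (which provides `zanj_model_config`) and HookedTransformer"""

    zanj_model_config: ConfigHolder

    @property
    def config(self) -> ConfigHolder:
        # alias to the parent ConfiguredModel's zanj_model_config, so evaluation
        # code (e.g. testdata_plot_predicted_path) can rely on `model.config`
        return self.zanj_model_config
```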

* lock after update maze-dataset to 0.2.1

* fixed minor import issue

* update config refs in train_model notebook

* lock poetry, re-run notebook

* format

* update coverage
mivanit authored Aug 6, 2023
1 parent d905cc0 commit 5452d13
Showing 27 changed files with 1,709 additions and 1,553 deletions.
4 changes: 2 additions & 2 deletions examples/coverage/coverage.svg
37 changes: 19 additions & 18 deletions examples/coverage/coverage.txt
@@ -2,40 +2,41 @@ Name Stmts
-------------------------------------------------------------------------------------------------------
maze_transformer\__init__.py 0 0 100%
maze_transformer\evaluation\__init__.py 0 0 100%
maze_transformer\evaluation\baseline_models.py 69 7 90% 62-63, 158, 160-163, 169
maze_transformer\evaluation\eval_model.py 78 28 64% 46-55, 70-105, 173-186, 251
maze_transformer\evaluation\baseline_models.py 69 7 90% 62-63, 155, 157-160, 166
maze_transformer\evaluation\eval_model.py 78 28 64% 46-55, 70-105, 172-185, 252
maze_transformer\evaluation\maze_complexity_evals.py 8 0 100%
maze_transformer\evaluation\path_evals.py 106 9 92% 19, 64-68, 88, 91-95, 167, 177, 190
maze_transformer\evaluation\plot_attention.py 89 89 0% 2-261
maze_transformer\evaluation\path_evals.py 106 4 96% 19, 64-68, 88
maze_transformer\evaluation\plot_attention.py 89 89 0% 2-263
maze_transformer\test_helpers\assertions.py 22 1 95% 44
maze_transformer\test_helpers\stub_logger.py 26 2 92% 23, 26
maze_transformer\tokenizer.py 67 11 84% 13, 50-51, 57, 89-90, 96-98, 101-104
maze_transformer\test_helpers\stub_logger.py 28 2 93% 23, 26
maze_transformer\tokenizer.py 68 8 88% 13, 90-91, 97-99, 102-105
maze_transformer\training\__init__.py 0 0 100%
maze_transformer\training\config.py 156 17 89% 80-83, 143, 190, 194, 232, 402, 424-428, 470-471, 484, 493, 499-502, 548, 553
maze_transformer\training\train_model.py 59 6 90% 34, 61-62, 131-138
maze_transformer\training\config.py 179 22 88% 81-84, 144, 191, 195, 233, 377-383, 444, 457, 471-475, 517-518, 531, 540, 546-549, 573, 599, 604
maze_transformer\training\train_model.py 59 6 90% 34, 61-62, 138-145
maze_transformer\training\train_save_files.py 16 0 100%
maze_transformer\training\training.py 80 3 96% 61, 86-89
maze_transformer\training\training.py 86 6 93% 30, 40-41, 73, 98-101
maze_transformer\training\wandb_logger.py 51 4 92% 56-58, 61
setup.py 3 3 0% 3-6
tests\conftest.py 8 0 100%
tests\integration\test_create_dataset.py 27 0 100%
tests\integration\test_eval_model.py 49 0 100%
tests\integration\test_train_model.py 8 0 100%
tests\integration\test_eval_model.py 52 1 98% 95
tests\integration\test_train_model.py 9 0 100%
tests\integration\test_training.py 62 0 100%
tests\unit\maze_transformer\evaluation\test_baseline_models.py 23 0 100%
tests\unit\maze_transformer\evaluation\test_baseline_models.py 24 0 100%
tests\unit\maze_transformer\evaluation\test_maze_complexity_evals.py 15 0 100%
tests\unit\maze_transformer\evaluation\test_path_evals.py 48 2 96% 30-32
tests\unit\maze_transformer\test_tokenizers.py 58 0 100%
tests\unit\maze_transformer\test_tokenizers.py 71 0 100%
tests\unit\maze_transformer\training\config\test_base_gpt_config.py 29 0 100%
tests\unit\maze_transformer\training\config\test_cfg_post_init.py 9 0 100%
tests\unit\maze_transformer\training\config\test_cfg_save_load.py 67 0 100%
tests\unit\maze_transformer\training\config\test_config_holder.py 38 5 87% 44-50
tests\unit\maze_transformer\training\config\test_train_cfg_intervals.py 89 0 100%
tests\unit\maze_transformer\training\config\test_train_config.py 33 0 100%
tests\unit\maze_transformer\training\test_dataset.py 53 7 87% 44-59
tests\unit\maze_transformer\training\test_get_dataloader.py 18 0 100%
tests\unit\maze_transformer\training\test_maze_dataset_construction.py 8 0 100%
tests\unit\maze_transformer\training\test_dataset.py 49 7 86% 38-53
tests\unit\maze_transformer\training\test_get_dataloader.py 21 0 100%
tests\unit\maze_transformer\training\test_maze_dataset_construction.py 12 0 100%
tests\unit\maze_transformer\training\test_model_loading_old.py 19 0 100%
tests\unit\maze_transformer\training\test_tokenizer.py 13 0 100%
tests\unit\maze_transformer\training\test_tokenizer.py 17 0 100%
tests\unit\maze_transformer\training\zanj\test_zanj_ht_save_load.py 37 0 100%
-------------------------------------------------------------------------------------------------------
TOTAL 1532 194 87%
TOTAL 1598 195 88%
39 changes: 18 additions & 21 deletions maze_transformer/evaluation/baseline_models.py
@@ -12,11 +12,10 @@
SolvedMaze,
)
from maze_dataset.tokenization.token_utils import (
coords_to_tokens,
get_origin_token,
get_origin_tokens,
get_path_tokens,
get_target_token,
tokens_to_coords,
get_target_tokens,
strings_to_coords,
)
from transformer_lens import HookedTransformer

@@ -56,13 +55,14 @@ def _predict_next_step(
path: list[CoordTup],
pad_eos: bool = False,
) -> CoordTup | str:
"""returns a tuple coordinate or a special token"""
current_position: CoordTup = path[-1]
# pad with eos up to max_new_tokens to avoid ragged tensors
if pad_eos:
if current_position in [target, SPECIAL_TOKENS["path_end"]]:
return SPECIAL_TOKENS["path_end"]
if current_position in [target, SPECIAL_TOKENS.PATH_END]:
return SPECIAL_TOKENS.PATH_END
if current_position == target:
return SPECIAL_TOKENS["path_end"]
return SPECIAL_TOKENS.PATH_END

neighbors: list[CoordTup] = self._get_coord_neighbors(
solved_maze, current_position
@@ -80,7 +80,7 @@

if len(unvisited_neighbors) == 0:
# break out if dead end
return SPECIAL_TOKENS["path_end"]
return SPECIAL_TOKENS.PATH_END
else:
if correct_step not in unvisited_neighbors:
return random.choice(unvisited_neighbors)
@@ -104,22 +104,19 @@ def _generate_path(
steps_to_predict: int,
) -> list[str]:
# assemble the maze from the tokens
maze: LatticeMaze = LatticeMaze.from_tokens(tokens)
origin_coord: CoordTup = self.config.dataset_cfg.token_node_map[
get_origin_token(tokens)
]
target_coord: CoordTup = self.config.dataset_cfg.token_node_map[
get_target_token(tokens)
]
maze: LatticeMaze = LatticeMaze.from_tokens(
tokens, self.tokenizer._maze_tokenizer
)
origin_coord: CoordTup = strings_to_coords(get_origin_tokens(tokens))[0]
target_coord: CoordTup = strings_to_coords(get_target_tokens(tokens))[0]
solution: CoordArray = maze.find_shortest_path(origin_coord, target_coord)
solved_maze: SolvedMaze = SolvedMaze.from_lattice_maze(maze, solution)
assert (solved_maze.start_pos == np.array(origin_coord)).all()
assert (solved_maze.end_pos == np.array(target_coord)).all()

# get the path so far
context_existing_path: list[Coord] = tokens_to_coords(
tokens=get_path_tokens(tokens, trim_end=True),
maze_data_cfg=self.config.dataset_cfg,
context_existing_path: list[Coord] = strings_to_coords(
get_path_tokens(tokens, trim_end=True),
when_noncoord="except",
)

@@ -139,11 +136,11 @@ def _generate_path(
path=path,
)
)
if predictions[-1] == SPECIAL_TOKENS["path_end"]:
if predictions[-1] == SPECIAL_TOKENS.PATH_END:
break

return coords_to_tokens(
predictions, self.config.dataset_cfg, when_noncoord="include"
return self.tokenizer._maze_tokenizer.coords_to_strings(
predictions, when_noncoord="include"
)

def generate(
15 changes: 7 additions & 8 deletions maze_transformer/evaluation/eval_model.py
@@ -17,7 +17,7 @@
get_context_tokens,
get_path_tokens,
remove_padding_from_token_str,
tokens_to_coords,
strings_to_coords,
)
from muutils.mlutils import chunks
from muutils.statcounter import StatCounter
@@ -136,7 +136,7 @@ def predict_maze_paths(
# predict tokens
prediction: str = model.generate(
context,
eos_token_id=data_cfg.tokenizer_map[SPECIAL_TOKENS["path_end"]],
eos_token_id=model.tokenizer._tokenizer_map[SPECIAL_TOKENS.PATH_END],
stop_at_eos=True,
max_new_tokens=max_new_tokens,
verbose=verbose,
@@ -153,9 +153,8 @@
paths: list[list[tuple[int, int]]] = []
for pred_tokens in prediction_batch:
path_tokens: list[str] = get_path_tokens(pred_tokens, trim_end=True)
path_coords: list[str | CoordTup] = tokens_to_coords(
path_coords: list[str | CoordTup] = strings_to_coords(
path_tokens,
maze_data_cfg=data_cfg,
when_noncoord=when_noncoord,
)
# This is the correct type when using "skip"
@@ -209,7 +208,9 @@ def evaluate_model(
}

if dataset_tokens is None:
dataset_tokens = dataset.as_tokens(join_tokens_individual_maze=False)
dataset_tokens = dataset.as_tokens(
model.zanj_model_config.maze_tokenizer, join_tokens_individual_maze=False
)
else:
assert len(dataset) == len(
dataset_tokens
@@ -262,9 +263,7 @@ def evaluate_logits(
for tokens in prediction_tokens:
# this returns first path_start to end of list. Early in training there may be multiple path_start tokens, so results should be treated with caution
path_tokens = get_path_tokens(tokens.split(" "))
path_coords = tokens_to_coords(
path_tokens, maze_data_cfg=config.dataset_cfg, when_noncoord="skip"
)
path_coords = strings_to_coords(path_tokens, when_noncoord="skip")
predicted_paths.append(cast(list[tuple[int, int]], path_coords))

maze_tokens = [
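
A note on the `when_noncoord` argument that appears in the two files above: my reading of the three modes, inferred from how they are used in these diffs (treat the exact semantics as an assumption rather than documentation):

```python
from maze_dataset.tokenization.token_utils import strings_to_coords

# illustrative input: two coordinate tokens plus one special token
tokens = ["(0,0)", "(0,1)", "<PATH_END>"]

strings_to_coords(tokens, when_noncoord="skip")        # drop non-coordinate tokens
strings_to_coords(tokens, when_noncoord="include")     # keep them as plain strings
strings_to_coords(tokens[:2], when_noncoord="except")  # "except" raises on any non-coordinate token
```
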
8 changes: 5 additions & 3 deletions maze_transformer/evaluation/plot_attention.py
@@ -15,8 +15,8 @@
from circuitsvis.tokens import colored_tokens_multi
from jaxtyping import Float
from maze_dataset import CoordTup, MazeDataset, MazeDatasetConfig, SolvedMaze
from maze_dataset.maze.lattice_maze import coord_str_to_tuple_noneable
from maze_dataset.plotting import MazePlot
from maze_dataset.tokenization.token_utils import coord_str_to_tuple_noneable

# Utilities
from muutils.json_serialize import SerializableDataclass, serializable_dataclass
@@ -52,12 +52,14 @@ def from_model_and_dataset(
for i in range(n_mazes):
# get the maze from the dataset and process into tokens
solved_maze: SolvedMaze = dataset[i]
tokens: list[str] = solved_maze.as_tokens(dataset.cfg.node_token_map)
tokens: list[str] = solved_maze.as_tokens(
model.zanj_model_config.maze_tokenizer
)
tokens_context: list[str]

if context_maze_only:
assert context_maze_fn is None
path_start_index: int = tokens.index(SPECIAL_TOKENS["path_start"])
path_start_index: int = tokens.index(SPECIAL_TOKENS.PATH_END)
tokens_context = tokens[: path_start_index + 1]
else:
assert context_maze_fn is not None
8 changes: 4 additions & 4 deletions maze_transformer/test_helpers/assertions.py
@@ -4,7 +4,7 @@
from jaxtyping import Int
from zanj.torchutil import ConfigMismatchException, assert_model_cfg_equality

from maze_transformer.training.config import BaseGPTConfig, ZanjHookedTransformer
from maze_transformer.training.config import ZanjHookedTransformer


def _check_except_config_equality_modulo_weight_processing(
@@ -44,11 +44,11 @@ def assert_model_output_equality(
raise e

# Random input tokens
dataset_cfg: BaseGPTConfig = model_a.zanj_model_config.dataset_cfg
tokenizer = model_a.zanj_model_config.tokenizer
input_sequence: Int[torch.Tensor, "1 test_sequence_length"] = torch.randint(
low=0,
high=len(dataset_cfg.token_arr),
size=(1, min(dataset_cfg.seq_len_max, test_sequence_length)),
high=len(tokenizer._token_arr),
size=(1, min(tokenizer._seq_len_max, test_sequence_length)),
)

# (copied from `test_eval_model.py`)
61 changes: 31 additions & 30 deletions maze_transformer/tokenizer.py
@@ -1,65 +1,66 @@
# Avoid circular import from training/config.py
from typing import TYPE_CHECKING, Union # need Union as "a" | "b" doesn't work
from typing import TYPE_CHECKING, Sequence # need Union as "a" | "b" doesn't work

import torch
from maze_dataset import SPECIAL_TOKENS, LatticeMaze
from maze_dataset.dataset.dataset import GPTDatasetConfig
from maze_dataset.plotting import MazePlot
from maze_dataset.tokenization import MazeTokenizer
from muutils.tensor_utils import ATensor, NDArray
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils import BatchEncoding

if TYPE_CHECKING:
from maze_transformer.training.config import ConfigHolder
pass

# pylint: disable=unused-import, abstract-method


class HuggingMazeTokenizer(PreTrainedTokenizer):
"""extension of PreTrainedTokenizer for mazes"""

vocab: dict[str, int] # map of token_ids to strings

bos_token: str = SPECIAL_TOKENS["adj_list_start"]
eos_token: str = SPECIAL_TOKENS["path_end"]
pad_token: str = SPECIAL_TOKENS["padding"]
bos_token: str = SPECIAL_TOKENS.ADJLIST_START
eos_token: str = SPECIAL_TOKENS.PATH_END
pad_token: str = SPECIAL_TOKENS.PADDING
unk_token: str = "<UNK>"

vocab_size: int = 0
additional_special_tokens: list[str] = [
x for x in SPECIAL_TOKENS.values() if x not in [SPECIAL_TOKENS["padding"]]
x for x in SPECIAL_TOKENS.values() if x not in [SPECIAL_TOKENS.PADDING]
]

# Overwrite class attributes
padding_side = "left"
truncation_side = "left" #! strange choice, but it's what we did in pad_sequence

name_or_path = "maze_tokenizer"
name_or_path = "hugging_maze_tokenizer"

# TODO: this should just take seq_len_max and max grid n
def __init__(
self,
cfg: Union["ConfigHolder", "GPTDatasetConfig", None] = None,
token_arr: list[str] | None = None,
seq_len_max: int | None = None,
seq_len_max: int,
maze_tokenizer: MazeTokenizer,
**kwargs,
) -> None:
"""takes either a cfg, or a token_arr and seq_len_max. also, kwargs are passed to super `PreTrainedTokenizer`"""

if cfg is None:
assert token_arr is not None
assert seq_len_max is not None
else:
assert token_arr is None
assert seq_len_max is None
# Avoid isinstance() because of circular import
if type(cfg).__name__ == "ConfigHolder":
cfg = cfg.dataset_cfg
"""extension of PreTrainedTokenizer for mazes. takes maximum sequence length and maze_tokenizer. also, kwargs are passed to super `PreTrainedTokenizer`"""
super().__init__(max_len=seq_len_max, **kwargs)

seq_len_max = cfg.seq_len_max
token_arr = cfg.token_arr
self._maze_tokenizer: MazeTokenizer = maze_tokenizer
token_arr: list[str] = maze_tokenizer.token_arr
self._token_arr: list[str] = token_arr
self._seq_len_max: int = seq_len_max
self._vocab_size: int = maze_tokenizer.vocab_size
self.vocab_size = self._vocab_size
self._tokenizer_map = maze_tokenizer.tokenizer_map

assert isinstance(
seq_len_max, int
), f"seq_len_max must be an int, got {seq_len_max = } {type(seq_len_max) = }"
assert isinstance(
token_arr, Sequence
), f"token_arr must be a Sequence, got {token_arr = } {type(token_arr) = }"
assert isinstance(
len(token_arr), int
), f"token_arr must be a Sequence, got {token_arr = } {type(token_arr) = }"

super().__init__(max_len=seq_len_max, **kwargs)
# We are having to do evil things here
vocab: dict[str, int] = {token: i for i, token in enumerate(token_arr)}
vocab[self.unk_token] = len(vocab)
@@ -90,7 +91,7 @@ def __call__(self, text, **kwargs) -> BatchEncoding:
raise NotImplementedError(
f"Caught an error during tokenization - probably because you are trying to encode a token not present in the tokenizer's vocabulary",
f"text: '{text}'",
)
) from e

def _tokenize(self, text: str, **kwargs) -> list[str]:
assert len(kwargs) == 0, f"kwargs not supported: {kwargs}"
@@ -130,5 +131,5 @@ def to_ascii(
sequence = sequence[sequence != self.pad_token_id]
str_sequence = self.batch_decode(sequence)

lattice_maze = LatticeMaze.from_tokens(str_sequence)
lattice_maze = LatticeMaze.from_tokens(str_sequence, self._maze_tokenizer)
return MazePlot(lattice_maze).to_ascii()
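
For reference, a hedged usage sketch of the new `HuggingMazeTokenizer` constructor contract (the `seq_len_max=512` default is a placeholder, and obtaining the `MazeTokenizer` from a loaded model's `zanj_model_config.maze_tokenizer` is just one option, as seen in `eval_model.py` above):

```python
from maze_dataset.tokenization import MazeTokenizer

from maze_transformer.tokenizer import HuggingMazeTokenizer


def build_hf_tokenizer(
    maze_tokenizer: MazeTokenizer, seq_len_max: int = 512
) -> HuggingMazeTokenizer:
    """illustrative wrapper; in the repo this wiring presumably happens inside ConfigHolder"""
    return HuggingMazeTokenizer(
        seq_len_max=seq_len_max,        # placeholder default; use the config's real value
        maze_tokenizer=maze_tokenizer,  # the MazeTokenizer from maze-dataset >= 0.2.1
    )
```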