Add tests for tokenization methods (TransformerLensOrg#280)
* simplify tokenizer initialization

* no need for the "jank token setup" when using a manual set_tokenizer() call; use the method from __init__ and move the special tokenizer init code into the method

* add tests for tokenization methods (TransformerLensOrg#100)

* add a README section about formatting

* VS Code format-on-save already uses black, but not isort; run `make format` manually to fix the imports

* fix type errors
Aprillion authored May 22, 2023
1 parent c268a71 commit 2c0eea5
Showing 4 changed files with 189 additions and 28 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -80,6 +80,13 @@ If adding a feature, please add unit tests for it to the tests folder, and check
 - Unit tests only via `make unit-test`
 - Acceptance tests only via `make acceptance-test`
+
+### Formatting
+
+This project uses `pycln`, `isort` and `black` for formatting; pull requests are checked in GitHub Actions.
+
+- Format all files via `make format`
+- Only check the formatting via `make check-format`
 
 ### Demos
 
 If adding a feature, please add it to the demo notebook in the `demos` folder, and check that it works in the demo format. This can be tested by replacing `pip install git+https://github.com/JayBaileyCS/TransformerLens.git` with `pip install git+https://github.com/<YOUR_USERNAME_HERE>/TransformerLens.git` in the demo notebook, and running it in a fresh environment.
8 changes: 0 additions & 8 deletions tests/acceptance/test_tokenizer_special_tokens.py
@@ -30,14 +30,6 @@ def test_d_vocab_from_tokenizer():
     model = HookedTransformer(
         cfg=cfg, tokenizer=AutoTokenizer.from_pretrained(tokenizer_name)
     )
-    # Jank token setup
-    # Perhaps we should write a wrapper around the tokenizer
-    if model.tokenizer.eos_token is None:
-        model.tokenizer.eos_token = "<|endoftext|>"
-    if model.tokenizer.pad_token is None:
-        model.tokenizer.pad_token = model.tokenizer.eos_token
-    if model.tokenizer.bos_token is None:
-        model.tokenizer.bos_token = model.tokenizer.eos_token
 
     tokens_with_bos = model.to_tokens(test_string)
     tokens_without_bos = model.to_tokens(test_string, prepend_bos=False)
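With `set_tokenizer` now applying these defaults itself (see the `HookedTransformer.py` hunk below), the hand-patching above became redundant. A minimal sketch of what the test now relies on, assuming the GPT-2 tokenizer (which ships with eos/bos set but no pad token):

from transformers import AutoTokenizer
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
# set_tokenizer ran inside __init__, so the missing pad token has
# already been defaulted to the eos token:
assert model.tokenizer.pad_token == model.tokenizer.eos_token == "<|endoftext|>"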
158 changes: 158 additions & 0 deletions tests/unit/test_tokenization_methods.py
@@ -0,0 +1,158 @@
import pytest
from torch import Size, equal, tensor
from transformers import AutoTokenizer

from transformer_lens import HookedTransformer, HookedTransformerConfig

model = HookedTransformer.from_pretrained("solu-1l")


def test_set_tokenizer_during_initialization():
    assert (
        model.tokenizer is not None
        and model.tokenizer.name_or_path == "NeelNanda/gpt-neox-tokenizer-digits"
    ), "initialized with expected tokenizer"
    assert model.cfg.d_vocab == 48262, "expected d_vocab"


def test_set_tokenizer_lazy():
    cfg = HookedTransformerConfig(1, 10, 1, 1, act_fn="relu", d_vocab=50256)
    model2 = HookedTransformer(cfg)
    original_tokenizer = model2.tokenizer
    assert original_tokenizer is None, "initialize without tokenizer"
    model2.set_tokenizer(AutoTokenizer.from_pretrained("gpt2"))
    tokenizer = model2.tokenizer
    assert tokenizer is not None and tokenizer.name_or_path == "gpt2", "set tokenizer"
    assert model2.to_single_token(" SolidGoldMagikarp") == 43453, "glitch token"


def test_to_tokens_default():
    s = "Hello, world!"
    tokens = model.to_tokens(s)
    assert equal(
        tokens, tensor([[1, 11765, 14, 1499, 3]])
    ), "creates a tensor of tokens with BOS"


def test_to_tokens_without_bos():
    s = "Hello, world!"
    tokens = model.to_tokens(s, prepend_bos=False)
    assert equal(tokens, tensor([[11765, 14, 1499, 3]])), "creates a tensor without BOS"


def test_to_tokens_device():
    s = "Hello, world!"
    tokens1 = model.to_tokens(s, move_to_device=False)
    tokens2 = model.to_tokens(s, move_to_device=True)
    assert equal(
        tokens1, tokens2
    ), "move to device has no effect when running tests on CPU"


def test_to_tokens_truncate():
    assert model.cfg.n_ctx == 1024, "verify assumed context length"
    s = "@ " * 1025
    tokens1 = model.to_tokens(s)
    tokens2 = model.to_tokens(s, truncate=False)
    assert len(tokens1[0]) == 1024, "truncated by default"
    assert len(tokens2[0]) == 1027, "not truncated"


def test_to_string_from_to_tokens_without_bos():
    s = "Hello, world!"
    tokens = model.to_tokens(s, prepend_bos=False)
    s2 = model.to_string(tokens[0])
    assert s == s2, "same string when converted back to string"


def test_to_string_multiple():
    s_list = model.to_string(tensor([[1, 11765], [43453, 28666]]))
    assert s_list == [
        "<|BOS|>Hello",
        "Charlie Planet",
    ], "can handle list of lists"


def test_to_str_tokens_default():
    s_list = model.to_str_tokens(" SolidGoldMagikarp")
    assert s_list == [
        "<|BOS|>",
        " Solid",
        "Gold",
        "Mag",
        "ik",
        "arp",
    ], "not a glitch token"


def test_to_str_tokens_without_bos():
    s_list = model.to_str_tokens(" SolidGoldMagikarp", prepend_bos=False)
    assert s_list == [
        " Solid",
        "Gold",
        "Mag",
        "ik",
        "arp",
    ], "without BOS"


def test_to_single_token():
    token = model.to_single_token("biomolecules")
    assert token == 31847, "single token"


def test_to_single_str_token():
    s = model.to_single_str_token(31847)
    assert s == "biomolecules"


def test_get_token_position_not_found():
    single = "biomolecules"
    input = "There were some biomolecules"
    with pytest.raises(AssertionError) as exc_info:
        model.get_token_position(single, input)
    assert (
        str(exc_info.value) == "The token does not occur in the prompt"
    ), "assertion error"


def test_get_token_position_str():
    single = " some"
    input = "There were some biomolecules"
    pos = model.get_token_position(single, input)
    assert pos == 3, "first position"


def test_get_token_position_str_without_bos():
    single = " some"
    input = "There were some biomolecules"
    pos = model.get_token_position(single, input, prepend_bos=False)
    assert pos == 2, "without BOS"


def test_get_token_position_int_pos():
    single = 2
    input = tensor([2.0, 3, 4])
    pos1 = model.get_token_position(single, input)
    pos2 = model.get_token_position(single, input, prepend_bos=False)
    assert pos1 == 0, "first position"
    assert pos2 == 0, "no effect from BOS when using tensor as input"


def test_get_token_position_int_pos_last():
    single = 2
    input = tensor([2.0, 3, 4, 2, 5])
    pos1 = model.get_token_position(single, input, mode="last")
    assert pos1 == 3, "last position"


def test_get_token_position_int_1_pos():
    single = 2
    input = tensor([[2.0, 3, 4]])
    pos = model.get_token_position(single, input)
    assert pos == 0, "first position"


def test_tokens_to_residual_directions():
    res_dir = model.tokens_to_residual_directions(model.to_tokens(""))
    assert res_dir.shape == Size([512]), "d_model-sized residual direction"
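Taken together, the tests above pin down a simple round trip. A minimal usage sketch of the same methods (assuming `transformer_lens` is installed and the `solu-1l` weights are downloadable):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("solu-1l")
tokens = model.to_tokens("Hello, world!")  # shape [1, 5]: BOS is prepended by default
str_tokens = model.to_str_tokens("Hello, world!", prepend_bos=False)  # per-token strings
round_trip = model.to_string(model.to_tokens("Hello, world!", prepend_bos=False)[0])
assert round_trip == "Hello, world!"  # mirrors test_to_string_from_to_tokens_without_bos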
44 changes: 24 additions & 20 deletions transformer_lens/HookedTransformer.py
Expand Up @@ -9,7 +9,7 @@
 import tqdm.auto as tqdm
 from fancy_einsum import einsum
 from jaxtyping import Float, Int
-from transformers import AutoTokenizer, PreTrainedTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
 from typeguard import typeguard_ignore
 from typing_extensions import Literal

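The import swap above matters because `PreTrainedTokenizerBase` is the common ancestor of both `PreTrainedTokenizer` (slow) and `PreTrainedTokenizerFast`, so the isinstance check in `set_tokenizer` below accepts either. A quick illustration (assumes `transformers` is installed):

from transformers import AutoTokenizer, PreTrainedTokenizerBase

tok = AutoTokenizer.from_pretrained("gpt2")  # returns a fast tokenizer by default
assert isinstance(tok, PreTrainedTokenizerBase)  # passes for slow and fast variants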
@@ -89,33 +89,23 @@ def __init__(
), "If n_devices > 1, must move_to_device"

if tokenizer is not None:
self.tokenizer = tokenizer
self.set_tokenizer(tokenizer)
elif self.cfg.tokenizer_name is not None:
# If we have a tokenizer name, we can load it from HuggingFace
if "llama" in self.cfg.tokenizer_name:
# llama tokenizer requires special handling
print("Warning: LLaMA tokenizer not loaded. Please load manually.")
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
if self.tokenizer.eos_token is None:
self.tokenizer.eos_token = "<|endoftext|>"
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
if self.tokenizer.bos_token is None:
self.tokenizer.bos_token = self.tokenizer.eos_token
self.set_tokenizer(
AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
)
else:
# If no tokenizer name is provided, we assume we're training on an algorithmic task and will pass in tokens
# directly. In this case, we don't need a tokenizer.
self.tokenizer = None

if self.cfg.d_vocab == -1:
# If we have a tokenizer, vocab size can be inferred from it.
assert (
self.tokenizer is not None
self.cfg.d_vocab != -1
), "Must provide a tokenizer if d_vocab is not provided"
self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
if self.cfg.d_vocab_out == -1:
self.cfg.d_vocab_out = self.cfg.d_vocab
self.tokenizer = None

self.embed = Embed(self.cfg)
self.hook_embed = HookPoint() # [batch, pos, d_model]
@@ -446,9 +436,23 @@ def set_tokenizer(self, tokenizer):
         Sets the tokenizer to use for this model.
         tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer
         """
-        assert isinstance(tokenizer, PreTrainedTokenizer)
+        assert isinstance(
+            tokenizer, PreTrainedTokenizerBase
+        ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"
         self.tokenizer = tokenizer
-        self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        if self.tokenizer.eos_token is None:
+            self.tokenizer.eos_token = "<|endoftext|>"
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+        if self.tokenizer.bos_token is None:
+            self.tokenizer.bos_token = self.tokenizer.eos_token
+
+        # Infer vocab size from tokenizer
+        if self.cfg.d_vocab == -1:
+            self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
+        if self.cfg.d_vocab_out == -1:
+            self.cfg.d_vocab_out = self.cfg.d_vocab
 
     def to_tokens(
         self,
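Because `set_tokenizer` now owns the special-token defaults and the vocab-size inference, the eager and lazy construction paths behave identically. A minimal sketch of the lazy path (argument order follows the unit test above; assumes network access for the GPT-2 tokenizer):

from transformers import AutoTokenizer
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(1, 10, 1, 1, act_fn="relu", d_vocab=50257)
model = HookedTransformer(cfg)  # no tokenizer yet, so d_vocab had to be given
model.set_tokenizer(AutoTokenizer.from_pretrained("gpt2"))
assert model.tokenizer.pad_token == model.tokenizer.eos_token  # defaulted in set_tokenizer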
@@ -768,7 +772,7 @@ def from_pretrained(
         n_devices=1,
         move_state_dict_to_device=True,
         **model_kwargs,
-    ):
+    ) -> "HookedTransformer":
         """Class method to load in a pretrained model weights to the HookedTransformer format and optionally to do some
         processing to make the model easier to interpret. Currently supports loading from most autoregressive
         HuggingFace models (GPT2, GPTNeo, GPTJ, OPT) and from a range of toy models and SoLU models trained by me (Neel Nanda).
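The new `-> "HookedTransformer"` annotation on `from_pretrained` is purely for static tooling. A small illustration of what it buys (hypothetical snippet; any supported model name works):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("solu-1l")
# Type checkers such as mypy now infer `model: HookedTransformer`,
# so calls like this are verified statically instead of typed as `Any`:
model.to_tokens("Hello, world!")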
