Fixing the stream by removing the read_index altogether. #1716

Merged · 5 commits · Jan 9, 2025
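For context, the streaming decode API this PR touches is driven the same way as in the new test file tokenizers/tests/stream.rs: build a `DecodeStream` from a tokenizer and feed it one id at a time. A minimal usage sketch (the fixture path and the id 32 decoding to "A" are taken from that test and assume the llama-3 tokenizer):

```rust
use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    // Same fixture the new tests load; downloaded by the Makefile target added below.
    let tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json")?;

    // `step` returns Ok(None) until the accumulated ids decode to a complete
    // UTF-8 chunk, then Ok(Some(text)) with the newly produced suffix.
    let mut stream = tokenizer.decode_stream(false);
    if let Some(chunk) = stream.step(32)? {
        println!("{chunk}"); // prints "A" with the llama-3 tokenizer
    }
    Ok(())
}
```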
4 changes: 2 additions & 2 deletions bindings/python/Makefile
@@ -14,8 +14,8 @@ style:
# Check the source code is formatted correctly
check-style:
python stub.py --check
ruff check examples py_src/tokenizers tests
ruff format --check examples py_src/tokenizers tests
ruff check $(check_dirs)
ruff format --check $(check_dirs)

TESTS_RESOURCES = $(DATA_DIR)/small.txt $(DATA_DIR)/roberta.json

2 changes: 1 addition & 1 deletion bindings/python/py_src/tokenizers/tools/visualizer.py
@@ -241,7 +241,7 @@ def consecutive_chars_to_html(
# In this case we are looking at a group/single char that is not tokenized.
# e.g. white space
css_classes.append("non-token")
css = f'''class="{' '.join(css_classes)}"'''
css = f'''class="{" ".join(css_classes)}"'''
data = ""
for key, val in data_items.items():
data += f' data-{key}="{val}"'
7 changes: 0 additions & 7 deletions bindings/python/src/decoders.rs
@@ -646,11 +646,6 @@ pub struct PyDecodeStream {
/// The index within the ids corresponding to the prefix so we can drain
/// correctly
prefix_index: usize,
/// We need to keep 2 prefixes.
/// Prefix is the second one that was already emitted to discard the part
/// of the text of all the ids
/// read is the prefix kept only for starting side effects of the prefix
read_index: usize,
}

#[pymethods]
@@ -663,7 +658,6 @@ impl PyDecodeStream {
ids: vec![],
prefix: "".to_string(),
prefix_index: 0,
read_index: 0,
}
}

@@ -676,7 +670,6 @@
&mut self.ids,
&mut self.prefix,
&mut self.prefix_index,
&mut self.read_index,
))
.into()
}
6 changes: 5 additions & 1 deletion tokenizers/Makefile
@@ -4,7 +4,7 @@ TESTS_DIR = tests

dir_guard=@mkdir -p $(@D)

SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt $(DATA_DIR)/albert-base-v1-tokenizer.json
SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt $(DATA_DIR)/albert-base-v1-tokenizer.json $(DATA_DIR)/llama-3-tokenizer.json
BENCHMARK_RESOURCES = $(SHARED_RESOURCES)
TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt $(DATA_DIR)/roberta.json $(DATA_DIR)/tokenizer-wiki.json $(DATA_DIR)/bert-wiki.json

@@ -79,3 +79,7 @@ $(DATA_DIR)/tokenizer-wiki.json :
$(DATA_DIR)/bert-wiki.json :
$(dir_guard)
wget https://s3.amazonaws.com/models.huggingface.co/bert/anthony/doc-pipeline/tokenizer.json -O $@

$(DATA_DIR)/llama-3-tokenizer.json :
$(dir_guard)
wget https://huggingface.co/hf-internal-testing/llama3-tokenizer/resolve/main/tokenizer.json -O $@
4 changes: 2 additions & 2 deletions tokenizers/src/models/bpe/word.rs
@@ -199,9 +199,9 @@ impl Word {

// Make sure we are not processing an expired queue entry
let target_new_pair = (self.symbols[top.pos].c, right.c);
if !merges
if merges
.get(&target_new_pair)
.map_or(false, |(_, new_id)| *new_id == top.new_id)
.is_none_or(|(_, new_id)| *new_id != top.new_id)
{
continue;
}
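The bpe/word.rs hunk above is a behavior-preserving cleanup: negating `map_or(false, p)` on an `Option` is the same as `is_none_or` with the negated predicate (`Option::is_none_or` is a recent standard-library addition). A self-contained sketch of the equivalence, using a hypothetical id-matching predicate in place of the real merge-table lookup:

```rust
fn main() {
    // For every case, the old form `!opt.map_or(false, p)` agrees with the
    // new form `opt.is_none_or(|x| !p(x))`.
    let check = |opt: Option<u32>, expected: u32| {
        let old = !opt.map_or(false, |id| id == expected);
        let new = opt.is_none_or(|id| id != expected);
        assert_eq!(old, new);
    };
    check(None, 7);    // missing entry: both forms say "skip this queue item"
    check(Some(7), 7); // up-to-date entry: both forms say "process it"
    check(Some(3), 7); // stale entry: both forms say "skip this queue item"
}
```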
2 changes: 1 addition & 1 deletion tokenizers/src/processors/template.rs
@@ -441,7 +441,7 @@ impl TemplateProcessingBuilder {
let exist = self
.special_tokens
.as_ref()
.map_or(false, |map| map.0.contains_key(sp));
.is_some_and(|map| map.0.contains_key(sp));

match exist {
false => Some(sp),
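The template.rs hunk is the positive counterpart of the same cleanup: `map_or(false, p)` is exactly `is_some_and(p)`. A short sketch with a hypothetical special-tokens map (the names are illustrative, not the builder's actual fields):

```rust
use std::collections::HashMap;

fn main() {
    let special_tokens: Option<HashMap<&str, u32>> = Some(HashMap::from([("[CLS]", 101)]));
    // Two spellings of "the map exists and contains the key".
    let old = special_tokens.as_ref().map_or(false, |m| m.contains_key("[CLS]"));
    let new = special_tokens.as_ref().is_some_and(|m| m.contains_key("[CLS]"));
    assert_eq!(old, new);
    assert!(new);
}
```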
118 changes: 0 additions & 118 deletions tokenizers/src/tokenizer/mod.rs
@@ -1035,11 +1035,6 @@ pub struct DecodeStream<'tok, M, N, PT, PP, D> {
/// The index within the ids corresponding to the prefix so we can drain
/// correctly
prefix_index: usize,
/// We need to keep 2 prefixes.
/// Prefix is the second one that was already emitted to discard the part
/// of the text of all the ids
/// read is the prefix kept only for starting side effects of the prefix
read_index: usize,
}

#[derive(thiserror::Error, Debug)]
@@ -1063,7 +1058,6 @@
skip_special_tokens,
prefix: "".to_string(),
prefix_index: 0,
read_index: 0,
}
}

@@ -1076,7 +1070,6 @@
&mut self.ids,
&mut self.prefix,
&mut self.prefix_index,
&mut self.read_index,
)
}
}
@@ -1089,7 +1082,6 @@ pub fn step_decode_stream<M, N, PT, PP, D>(
ids: &mut Vec<u32>,
prefix: &mut String,
prefix_index: &mut usize,
read_index: &mut usize,
) -> Result<Option<String>>
where
M: Model,
@@ -1108,7 +1100,6 @@
let new_prefix_index = ids.len() - *prefix_index;
*ids = ids.drain(*prefix_index..).collect();
*prefix = tokenizer.decode(ids, skip_special_tokens)?;
*read_index = *prefix_index;
*prefix_index = new_prefix_index;
Ok(Some(new_text.to_string()))
} else {
@@ -1563,112 +1554,3 @@
Ok(())
}
}

#[cfg(test)]
mod test {
#[cfg(feature = "http")]
#[test]
fn test_decoding_with_added_bpe() {
use crate::{
normalizers,
pre_tokenizers::split::{Split, SplitPattern},
AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer,
};

let mut tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
tokenizer.normalizer = Some(NormalizerWrapper::from(normalizers::ByteLevel::new()));
tokenizer.pre_tokenizer = Some(PreTokenizerWrapper::Split(
Split::new(
SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".into()),
SplitDelimiterBehavior::Isolated,
false,
)
.unwrap(),
));
tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]);
let encoded = tokenizer
.encode("Hey! how is this token: 嗎", false)
.unwrap();
assert_eq!(
encoded.get_ids(),
[19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256]
);
assert_eq!(
encoded.get_tokens(),
["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "嗎"]
);

let decoded = tokenizer.decode(encoded.get_ids(), false);
assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎");

tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]);
let encoded = tokenizer
.encode("Hey! how is this token: д", false)
.unwrap();
assert_eq!(
encoded.get_ids(),
[19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257]
);
assert_eq!(
encoded.get_tokens(),
["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"]
);
let decoded = tokenizer.decode(encoded.get_ids(), false);
assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
}

#[cfg(feature = "http")]
#[test]
fn test_decode_stream_step_no_panic() {
use std::panic;

use crate::Tokenizer;

let tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();

// "A B C D E F G H I J"
let mut decode_stream = tokenizer.decode_stream(false);
let output_tokens = vec![32, 426, 356, 423, 469, 435, 480, 473, 358, 622];
let expected_outputs = vec![
Some("A".to_string()),
Some(" B".to_string()),
Some(" C".to_string()),
Some(" D".to_string()),
Some(" E".to_string()),
Some(" F".to_string()),
Some(" G".to_string()),
Some(" H".to_string()),
Some(" I".to_string()),
Some(" J".to_string()),
];
for (i, &token) in output_tokens.iter().enumerate() {
let maybe_panic =
panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
assert!(maybe_panic.is_ok());
let result = maybe_panic.unwrap();
assert!(result.is_ok());
assert_eq!(result.unwrap(), expected_outputs[i]);
}

// "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
let mut decode_stream = tokenizer.decode_stream(false);
let output_tokens = vec![80690, 98, 167, 121, 243, 102457, 113];
let expected_outputs = vec![
None,
Some("삥".to_string()),
None,
None,
Some("뽕".to_string()),
None,
Some("빵".to_string()),
];
for (i, &token) in output_tokens.iter().enumerate() {
let maybe_panic =
panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
assert!(maybe_panic.is_ok());
let result = maybe_panic.unwrap();
assert!(result.is_ok());
assert_eq!(result.unwrap(), expected_outputs[i]);
}
}
}
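For readers following the mod.rs hunks: after this PR the stream's state is just `ids`, `prefix`, and `prefix_index`. Below is a minimal sketch of that bookkeeping with a hypothetical `decode` closure standing in for `Tokenizer::decode`; the real `step_decode_stream` additionally threads errors through `Result`, as its signature above shows.

```rust
// Sketch only: mirrors the retained logic, not the library's exact code.
fn step_sketch(
    decode: &impl Fn(&[u32]) -> String,
    ids: &mut Vec<u32>,
    prefix: &mut String,
    prefix_index: &mut usize,
    id: u32,
) -> Option<String> {
    ids.push(id);
    let text = decode(ids.as_slice());
    // Emit only once the decoded text strictly extends the previous prefix and
    // does not end in U+FFFD (which would mean the newest ids leave a UTF-8
    // sequence incomplete).
    if text.len() > prefix.len() && !text.ends_with('\u{FFFD}') {
        // Invariant carried over from the real code: `text` starts with `prefix`.
        let new_text = text[prefix.len()..].to_string();
        let new_prefix_index = ids.len() - *prefix_index;
        // Drop the ids before the old prefix_index; what remains decodes to the
        // prefix that the next call strips from its output.
        *ids = ids.drain(*prefix_index..).collect();
        *prefix = decode(ids.as_slice());
        *prefix_index = new_prefix_index;
        Some(new_text)
    } else {
        None
    }
}
```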
78 changes: 78 additions & 0 deletions tokenizers/tests/stream.rs
@@ -0,0 +1,78 @@
use tokenizers::{
normalizers,
pre_tokenizers::split::{Split, SplitPattern},
AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer,
};

#[test]
fn test_decoding_with_added_bpe() {
let mut tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json").unwrap();
tokenizer.with_normalizer(Some(NormalizerWrapper::from(normalizers::ByteLevel::new())));
tokenizer.with_pre_tokenizer(Some(PreTokenizerWrapper::Split(
Split::new(
SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".into()),
SplitDelimiterBehavior::Isolated,
false,
)
.unwrap(),
)));
tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]);
let encoded = tokenizer
.encode("Hey! how is this token: 嗎", false)
.unwrap();
assert_eq!(
encoded.get_ids(),
[19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256]
);
assert_eq!(
encoded.get_tokens(),
["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "嗎"]
);

let decoded = tokenizer.decode(encoded.get_ids(), false);
assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎");

tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]);
let encoded = tokenizer
.encode("Hey! how is this token: д", false)
.unwrap();
assert_eq!(
encoded.get_ids(),
[19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257]
);
assert_eq!(
encoded.get_tokens(),
["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"]
);
let decoded = tokenizer.decode(encoded.get_ids(), false);
assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
}

#[test]
fn test_decode_stream_step_no_panic() {
let tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json").unwrap();

// "A B C D E F G H I J"
let mut decode_stream = tokenizer.decode_stream(false);
assert_eq!(decode_stream.step(32).unwrap(), Some("A".to_string()));
assert_eq!(decode_stream.step(426).unwrap(), Some(" B".to_string()));
assert_eq!(decode_stream.step(356).unwrap(), Some(" C".to_string()));
assert_eq!(decode_stream.step(423).unwrap(), Some(" D".to_string()));
assert_eq!(decode_stream.step(469).unwrap(), Some(" E".to_string()));
assert_eq!(decode_stream.step(435).unwrap(), Some(" F".to_string()));
assert_eq!(decode_stream.step(480).unwrap(), Some(" G".to_string()));
assert_eq!(decode_stream.step(473).unwrap(), Some(" H".to_string()));
assert_eq!(decode_stream.step(358).unwrap(), Some(" I".to_string()));
assert_eq!(decode_stream.step(622).unwrap(), Some(" J".to_string()));
// for (i, &token) in output_tokens.iter().enumerate() {}

// "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
let mut decode_stream = tokenizer.decode_stream(false);
assert_eq!(decode_stream.step(80690).unwrap(), None);
assert_eq!(decode_stream.step(98).unwrap(), Some("삥".to_string()));
assert_eq!(decode_stream.step(167).unwrap(), None);
assert_eq!(decode_stream.step(121).unwrap(), None);
assert_eq!(decode_stream.step(243).unwrap(), Some("뽕".to_string()));
assert_eq!(decode_stream.step(102457).unwrap(), None);
assert_eq!(decode_stream.step(113).unwrap(), Some("빵".to_string()));
}