From a3ed0c3e8d08a1ebcf33736a58e70dabb09530ba Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Fri, 3 Jan 2025 17:28:42 +0100
Subject: [PATCH] push latest changes

---
 tokenizers/benches/llama3.rs                  |  10 +-
 tokenizers/src/models/backtracking_bpe/mod.rs |   1 +
 .../src/models/backtracking_bpe/model.rs      |  95 +++++----
 .../models/backtracking_bpe/serialization.rs  | 180 ++++++++++++++++++
 tokenizers/src/models/bpe/mod.rs              |   2 +-
 tokenizers/src/models/mod.rs                  |   4 +-
 6 files changed, 246 insertions(+), 46 deletions(-)
 create mode 100644 tokenizers/src/models/backtracking_bpe/serialization.rs

diff --git a/tokenizers/benches/llama3.rs b/tokenizers/benches/llama3.rs
index f327c0dbb..05efcbd62 100644
--- a/tokenizers/benches/llama3.rs
+++ b/tokenizers/benches/llama3.rs
@@ -29,7 +29,7 @@ pub fn llama3(c: &mut Criterion) {
         let add_special_tokens = false;
         b.iter(|| {
             tokenizer
-                .encode_batch(criterion::black_box(data.clone()), add_special_tokens)
+                .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens)
                 .unwrap()
         })
     });
@@ -52,22 +52,22 @@ pub fn llama3(c: &mut Criterion) {
         let add_special_tokens = false;
         b.iter(|| {
             tokenizer
-                .encode_batch(criterion::black_box(data.clone()), add_special_tokens)
+                .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens)
                 .unwrap()
         })
     });
-    group.bench_function("llama3-offsets", |b| {
+    group.bench_function("llama3-encode_batch_fast", |b| {
         let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
         let data: Vec<_> = data.lines().collect();
         let add_special_tokens = false;
         b.iter(|| {
             tokenizer
-                .encode_batch_char_offsets(criterion::black_box(data.clone()), add_special_tokens)
+                .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens)
                 .unwrap()
         })
     });
-    group.bench_function("llama3-nooffsets", |b| {
+    group.bench_function("llama3-encode_batch", |b| {
         let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
         let data: Vec<_> = data.lines().collect();
         let add_special_tokens = false;
diff --git a/tokenizers/src/models/backtracking_bpe/mod.rs b/tokenizers/src/models/backtracking_bpe/mod.rs
index 0e2aece93..2c17cbf46 100644
--- a/tokenizers/src/models/backtracking_bpe/mod.rs
+++ b/tokenizers/src/models/backtracking_bpe/mod.rs
@@ -1,4 +1,5 @@
 mod bitfield;
 mod model;
+mod serialization;
 
 pub use model::*;
diff --git a/tokenizers/src/models/backtracking_bpe/model.rs b/tokenizers/src/models/backtracking_bpe/model.rs
index c9d813e2e..2a13c4ac2 100644
--- a/tokenizers/src/models/backtracking_bpe/model.rs
+++ b/tokenizers/src/models/backtracking_bpe/model.rs
@@ -1,6 +1,6 @@
 use super::bitfield::BitField;
 use super::{super::bpe::trainer::BpeTrainer, super::bpe::Error, super::OrderedVocabIter};
-use crate::models::bpe::BPE;
+use crate::models::bpe::{MergeMap, Pair, BPE};
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::iter::ResultShunt;
 use aneubeck_daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
@@ -24,6 +24,7 @@
 pub type Vocab = HashMap<String, u32>;
 type VocabR = HashMap<u32, String>;
 pub type Merges = Vec<(String, String)>;
+
 /// This can be thought of as a lazy variation of the dynamic programming approach.
 /// It only computes those states which have to be visited in order to compute the tokenization
 /// for a given input text.
@@ -102,8 +103,9 @@ impl Default for BacktrackingBpeBuilder {
     }
 }
 
+
 /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
-#[derive(Serialize, PartialEq, Clone)]
+#[derive(PartialEq, Clone)]
 pub struct BacktrackingBpe {
     /// All the decoded tokens concatenated into? used to build the aho corasick searchers
     all_tokens: Vec<u8>,
@@ -122,21 +124,18 @@ pub struct BacktrackingBpe {
     //     serialize_with = "serialize_daac",
     //     deserialize_with = "deserialize_daac"
     // )]
-    #[serde(skip)]
     longest_searcher: DoubleArrayAhoCorasick<u32>,
     /// An aho corasick automaton to find ALL tokens in a byte sequence.
     // #[serde(
     //     serialize_with = "serialize_daac",
     //     deserialize_with = "deserialize_daac"
     // )]
-    #[serde(skip)]
     pub(crate) overlapping_searcher: DoubleArrayAhoCorasick<u32>,
     /// An aho corasick automaton to find ALL tokens in a byte sequence which is being processed in reverse order.
     // #[serde(
     //     serialize_with = "serialize_daac",
     //     deserialize_with = "deserialize_daac"
     // )]
-    #[serde(skip)]
     pub(crate) overlapping_searcher_rev: DoubleArrayAhoCorasick<u32>,
     /// Mapping from a token to the next longest prefix token.
     /// This is in principle information represented by the AhoCorasick automaton.
@@ -145,12 +144,14 @@ pub struct BacktrackingBpe {
     next_prefix_match: Vec<u32>,
     /// Hash factor used to prevent hash collisions.
     hash_factor: u64,
-    vocab: Vocab,
-    vocab_r: VocabR,
+    pub vocab: Vocab,
+    pub vocab_r: VocabR,
     unk_token: Option<String>,
+    pub merges: MergeMap,
 }
 
 use std::fmt;
+
 // Manually implement the Debug trait to exclude the `cache` field
 impl fmt::Debug for BacktrackingBpe {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -430,7 +431,7 @@ impl BacktrackingBpe {
     /// to prevent repeating the cost of computing the hash factor and encoding.
     pub fn from_dictionary(
         tokens: impl IntoIterator<Item = Vec<u8>>,
-        merges: Option<Vec<(String, String)>>,
+        merges: Option<Merges>,
         hash_factor: Option<u64>,
     ) -> Self {
         let hash_factor = hash_factor
@@ -440,6 +441,7 @@ impl BacktrackingBpe {
         let mut all_tokens_rev = Vec::new();
         let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
+        let mut merge_map: HashMap<Pair, (u32, u32)> = HashMap::new();
         for (i, token) in tokens.into_iter().enumerate() {
             bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32);
             all_tokens_rev.extend(token.iter().copied().rev());
@@ -463,36 +465,6 @@ impl BacktrackingBpe {
             })
             .collect();
 
-        let mut split_table = vec![];
-        let mut pair_lookup = FnvHashMap::default();
-
-        // Reverse engineer the merge/split table.
-        for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() {
-            let mut token1 = next_prefix_match[id];
-            while token1 != u32::MAX {
-                let rest = &token[token_range(&token_starts, token1).len()..];
-                if let Some(token2) = find_token_by_bytes(
-                    &all_tokens,
-                    &token_starts,
-                    &bytes_hash_to_token,
-                    rest,
-                    hash_factor,
-                ) {
-                    if token1 < id as u32
-                        && token2 < id as u32
-                        && is_valid_token_pair(&pair_lookup, &split_table, token1, token2)
-                    {
-                        pair_lookup.insert((token1, token2), id as u32);
-                        split_table.push((token1, token2));
-                        break;
-                    }
-                }
-                token1 = next_prefix_match[token1 as usize];
-            }
-            if token1 == u32::MAX {
-                split_table.push((id as u32, id as u32));
-            }
-        }
         let vocab: HashMap<String, u32> = token_iter(&all_tokens, &token_starts)
             .enumerate()
             .map(|(id, item)| {
@@ -512,6 +484,52 @@ impl BacktrackingBpe {
             })
             .collect();
 
+        let mut split_table = vec![];
+        let mut pair_lookup = FnvHashMap::default();
+
+        if let Some(ref merges) = merges {
+            for (id, pair) in merges.into_iter().enumerate() {
+                let token1 = vocab[&pair.0.clone()];
+                let token2 = vocab[&pair.1.clone()];
+                pair_lookup.insert((token1, token2), id as u32);
+                split_table.push((token1, token2));
+                merge_map.insert(Pair::from(pair), (id as u32, id as u32)); // TODO wrong
+            };
+        } else {
+            // Reverse engineer the merge/split table.
+            for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() {
+                let mut token1 = next_prefix_match[id];
+                while token1 != u32::MAX {
+                    let rest = &token[token_range(&token_starts, token1).len()..];
+                    if let Some(token2) = find_token_by_bytes(
+                        &all_tokens,
+                        &token_starts,
+                        &bytes_hash_to_token,
+                        rest,
+                        hash_factor,
+                    ) {
+                        if token1 < id as u32
+                            && token2 < id as u32
+                            && is_valid_token_pair(&pair_lookup, &split_table, token1, token2)
+                        {
+                            pair_lookup.insert((token1, token2), id as u32);
+                            split_table.push((token1, token2));
+                            let str_token1 = unsafe { String::from_utf8_unchecked(Vec::from(&all_tokens[token_range(&token_starts, token1)])) };
+                            let str_token2 = unsafe { String::from_utf8_unchecked(Vec::from(&all_tokens[token_range(&token_starts, token2)])) };
+                            merge_map.insert(Pair::from(&(str_token1, str_token2)), (id as u32, id as u32)); // TODO wrong
+                            break;
+                        }
+                    }
+                    token1 = next_prefix_match[token1 as usize];
+                }
+                if token1 == u32::MAX {
+                    split_table.push((id as u32, id as u32));
+                }
+            }
+        };
+
+
+
         let bpe = Self {
             all_tokens,
             token_starts,
@@ -526,6 +544,7 @@ impl BacktrackingBpe {
             unk_token: None,
             vocab,
             vocab_r,
+            merges: merge_map
         };
         for token_id in 0..bpe.num_tokens() as u32 {
             let bytes = bpe.token_bytes(token_id);
diff --git a/tokenizers/src/models/backtracking_bpe/serialization.rs b/tokenizers/src/models/backtracking_bpe/serialization.rs
new file mode 100644
index 000000000..3218faa38
--- /dev/null
+++ b/tokenizers/src/models/backtracking_bpe/serialization.rs
@@ -0,0 +1,180 @@
+use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BacktrackingBpe, BacktrackingBpeBuilder, super::bpe::Pair};
+use serde::{
+    de::{Error, MapAccess, Visitor},
+    ser::SerializeStruct,
+    Deserialize, Deserializer, Serialize, Serializer,
+};
+use std::collections::HashMap;
+
+impl Serialize for BacktrackingBpe {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let mut model = serializer.serialize_struct("BPE", 8)?;
+
+        // Start by small fields
+        model.serialize_field("type", "BPE")?;
+
+        // Then the large ones
+        let mut merges: Vec<(&Pair, &u32)> = self
+            .merges
+            .iter()
+            .map(|(pair, (rank, _))| (pair, rank))
+            .collect();
+        merges.sort_unstable_by_key(|k| *k.1);
+        let merges = merges
+            .into_iter()
+            .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
+            .collect::<Vec<_>>();
+        let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);
+
+        model.serialize_field("vocab", &ordered_vocab)?;
+        model.serialize_field("merges", &merges)?;
+
+        model.end()
+    }
+}
+
+impl<'de> Deserialize<'de> for BacktrackingBpe {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        deserializer.deserialize_struct(
+            "BPE",
+            &[
+                "type",
+                "dropout",
+                "unk_token",
+                "vocab",
+                "merges",
+            ],
+            BPEVisitor,
+        )
+    }
+}
+
+struct BPEVisitor;
+impl<'de> Visitor<'de> for BPEVisitor {
+    type Value = BacktrackingBpe;
+
+    fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(fmt, "struct BPE")
+    }
+
+    fn visit_map<V>(self, mut map: V) -> std::result::Result<Self::Value, V::Error>
+    where
+        V: MapAccess<'de>,
+    {
+        let mut builder = BacktrackingBpeBuilder::new();
+        let mut vocab: Option<HashMap<String, u32>> = None;
+
+        #[derive(Debug, Deserialize)]
+        #[serde(untagged)]
+        enum MergeType {
+            Tuple(Vec<(String, String)>),
+            Legacy(Vec<String>),
+        }
+        let mut merges: Option<MergeType> = None;
+        while let Some(key) = map.next_key::<String>()? {
+            match key.as_ref() {
+                "dropout" => {
+                    if let Some(dropout) = map.next_value()? {
+                        builder = builder.dropout(dropout);
+                    }
+                }
+                "unk_token" => {
+                    if let Some(unk) = map.next_value()? {
+                        builder = builder.unk_token(unk);
+                    }
+                }
+                "vocab" => vocab = Some(map.next_value()?),
+                "merges" => merges = Some(map.next_value()?),
+                "type" => match map.next_value()? {
+                    "BPE" => {}
+                    u => {
+                        return Err(serde::de::Error::invalid_value(
+                            serde::de::Unexpected::Str(u),
+                            &"BacktrackingBpe",
+                        ))
+                    }
+                },
+                _ => {}
+            }
+        }
+        if let (Some(vocab), Some(merges)) = (vocab, merges) {
+            let merges = match merges {
+                MergeType::Tuple(merges) => merges,
+                MergeType::Legacy(merges) => {
+                    convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?
+                }
+            };
+            builder = builder.vocab_and_merges(vocab, merges);
+            Ok(builder.build().map_err(Error::custom)?)
+        } else {
+            Err(Error::custom("Missing vocab/merges"))
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::models::bpe::Vocab;
+
+    #[test]
+    fn test_serialization() {
+        let vocab: Vocab = [
+            ("<unk>".into(), 0),
+            ("a".into(), 1),
+            ("b".into(), 2),
+            ("ab".into(), 3),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+        let bpe = BacktrackingBpeBuilder::default()
+            .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
+            .unk_token("<unk>".to_string())
+            .build()
+            .unwrap();
+
+        let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
+        let legacy = serde_json::from_str(legacy).unwrap();
+        assert_eq!(bpe, legacy);
+
+        let data = serde_json::to_string(&bpe).unwrap();
+        assert_eq!(
+            data,
+            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
+        );
+        let reconstructed = serde_json::from_str(&data).unwrap();
+        assert_eq!(bpe, reconstructed);
+
+        // With a space in the token
+        let vocab: Vocab = [
+            ("<unk>".into(), 0),
+            ("a".into(), 1),
+            ("b c d".into(), 2),
+            ("ab c d".into(), 3),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+        let bpe = BacktrackingBpeBuilder::default()
+            .vocab_and_merges(vocab, vec![("a".to_string(), "b c d".to_string())])
+            .unk_token("<unk>".to_string())
+            .build()
+            .unwrap();
+        let data = serde_json::to_string(&bpe).unwrap();
+        assert_eq!(
+            data,
+            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b c d":2,"ab c d":3},"merges":[["a","b c d"]]}"#
        );
+        let reconstructed = serde_json::from_str(&data).unwrap();
+        assert_eq!(bpe, reconstructed);
+    }
+
+
+}
diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs
index f0d40b2df..6bd51cb1b 100644
--- a/tokenizers/src/models/bpe/mod.rs
+++ b/tokenizers/src/models/bpe/mod.rs
@@ -6,7 +6,7 @@
 mod serialization;
 pub mod trainer;
 mod word;
 
-type Pair = (u32, u32);
+pub(crate) type Pair = (u32, u32);
 /// Errors that can be encountered while using or constructing a `BPE` model.
 #[derive(thiserror::Error, Debug)]
diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs
index ecf9a5423..da8adddcf 100644
--- a/tokenizers/src/models/mod.rs
+++ b/tokenizers/src/models/mod.rs
@@ -101,7 +101,7 @@ impl<'de> Deserialize<'de> for ModelWrapper {
         #[derive(Deserialize)]
         #[serde(untagged)]
         pub enum ModelUntagged {
-            BPE(BPE),
+            BPE(BacktrackingBpe),
             // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility
             // with the versions not including the "type"), since WordLevel is a subset of WordPiece
             WordPiece(WordPiece),
@@ -128,7 +128,7 @@ impl<'de> Deserialize<'de> for ModelWrapper {
             ModelHelper::Legacy(value) => {
                 let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
                 match untagged {
-                    ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe),
+                    ModelUntagged::BPE(bpe) => ModelWrapper::BacktrackingBpe(bpe),
                     ModelUntagged::WordPiece(bpe) => ModelWrapper::WordPiece(bpe),
                     ModelUntagged::WordLevel(bpe) => ModelWrapper::WordLevel(bpe),
                     ModelUntagged::Unigram(bpe) => ModelWrapper::Unigram(bpe),
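
Note for reviewers: the sketch below is not part of the patch; it illustrates how the new serde round-trip for BacktrackingBpe is expected to be exercised, mirroring the test_serialization test added in serialization.rs. It assumes the builder API shown above (BacktrackingBpeBuilder::default(), vocab_and_merges, unk_token, build) and that the backtracking_bpe module is reachable under tokenizers::models; the import path is an assumption, not something this patch establishes.

use std::collections::HashMap;

// Assumed import path; adjust to wherever the crate re-exports the backtracking model.
use tokenizers::models::backtracking_bpe::{BacktrackingBpe, BacktrackingBpeBuilder};

fn main() {
    // Tiny vocab with a single merge "a" + "b" -> "ab", mirroring the test above.
    let vocab: HashMap<String, u32> = [
        ("<unk>".into(), 0),
        ("a".into(), 1),
        ("b".into(), 2),
        ("ab".into(), 3),
    ]
    .into_iter()
    .collect();
    let merges = vec![("a".to_string(), "b".to_string())];

    let bpe: BacktrackingBpe = BacktrackingBpeBuilder::default()
        .vocab_and_merges(vocab, merges)
        .unk_token("<unk>".to_string())
        .build()
        .unwrap();

    // Serialize through the Serialize impl added in serialization.rs; the payload is
    // tagged "type":"BPE", which is what lets legacy BPE files map onto BacktrackingBpe
    // via the ModelUntagged change in models/mod.rs.
    let json = serde_json::to_string(&bpe).unwrap();

    // Deserialize back through BPEVisitor and compare (PartialEq is derived in model.rs).
    let reloaded: BacktrackingBpe = serde_json::from_str(&json).unwrap();
    assert_eq!(bpe, reloaded);
}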