Commit a3ed0c3: push latest changes
ArthurZucker committed Jan 3, 2025 (1 parent: d334fb4)
Showing 6 changed files with 246 additions and 46 deletions.
tokenizers/benches/llama3.rs (10 changes: 5 additions & 5 deletions)
@@ -29,7 +29,7 @@ pub fn llama3(c: &mut Criterion) {
         let add_special_tokens = false;
         b.iter(|| {
             tokenizer
-                .encode_batch(criterion::black_box(data.clone()), add_special_tokens)
+                .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens)
                 .unwrap()
         })
     });
@@ -52,22 +52,22 @@ pub fn llama3(c: &mut Criterion) {
         let add_special_tokens = false;
         b.iter(|| {
             tokenizer
-                .encode_batch(criterion::black_box(data.clone()), add_special_tokens)
+                .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens)
                 .unwrap()
         })
     });

-    group.bench_function("llama3-offsets", |b| {
+    group.bench_function("llama3-encode_batch_fast", |b| {
         let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
         let data: Vec<_> = data.lines().collect();
         let add_special_tokens = false;
         b.iter(|| {
             tokenizer
-                .encode_batch_char_offsets(criterion::black_box(data.clone()), add_special_tokens)
+                .encode_batch_fast(criterion::black_box(data.clone()), add_special_tokens)
                 .unwrap()
         })
     });
-    group.bench_function("llama3-nooffsets", |b| {
+    group.bench_function("llama3-encode_batch", |b| {
         let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
         let data: Vec<_> = data.lines().collect();
         let add_special_tokens = false;
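For reviewers who want to sanity-check the rename locally, here is a minimal sketch of the same comparison outside Criterion. It assumes the crate's http feature is enabled (so Tokenizer::from_pretrained works) and that encode_batch_fast shares encode_batch's signature, as the diff implies; the sample strings are illustrative.

use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    // Same setup as the benchmark: gpt2 tokenizer, no special tokens.
    let tokenizer = Tokenizer::from_pretrained("gpt2", None)?;
    let data: Vec<&str> = vec!["Hello world", "The quick brown fox"];
    let add_special_tokens = false;

    // Full pipeline: returns Encodings with offsets populated.
    let full = tokenizer.encode_batch(data.clone(), add_special_tokens)?;
    // Fast path: same token ids, but offset bookkeeping is skipped.
    let fast = tokenizer.encode_batch_fast(data, add_special_tokens)?;

    // The ids should agree even though the fast path drops offset work.
    for (slow, quick) in full.iter().zip(fast.iter()) {
        assert_eq!(slow.get_ids(), quick.get_ids());
    }
    Ok(())
}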
tokenizers/src/models/backtracking_bpe/mod.rs (1 change: 1 addition & 0 deletions)
@@ -1,4 +1,5 @@
+mod bitfield;
 mod model;
 mod serialization;

 pub use model::*;
tokenizers/src/models/backtracking_bpe/model.rs (95 changes: 57 additions & 38 deletions)
@@ -1,6 +1,6 @@
 use super::bitfield::BitField;
 use super::{super::bpe::trainer::BpeTrainer, super::bpe::Error, super::OrderedVocabIter};
-use crate::models::bpe::BPE;
+use crate::models::bpe::{MergeMap, Pair, BPE};
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::iter::ResultShunt;
 use aneubeck_daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
@@ -24,6 +24,7 @@ pub type Vocab = HashMap<String, u32>;
 type VocabR = HashMap<u32, String>;
 pub type Merges = Vec<(String, String)>;

+
 /// This can be thought of as a lazy variation of the dynamic programming approach.
 /// It only computes those states which have to be visited in order to compute the tokenization
 /// for a given input text.
@@ -102,8 +103,9 @@ impl Default for BacktrackingBpeBuilder {
     }
 }

+
 /// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
-#[derive(Serialize, PartialEq, Clone)]
+#[derive(PartialEq, Clone)]
 pub struct BacktrackingBpe {
     /// All the decoded tokens, concatenated; used to build the Aho-Corasick searchers.
     all_tokens: Vec<u8>,
@@ -122,21 +124,18 @@ pub struct BacktrackingBpe {
     // serialize_with = "serialize_daac",
     // deserialize_with = "deserialize_daac"
     // )]
-    #[serde(skip)]
     longest_searcher: DoubleArrayAhoCorasick<u32>,
     /// An Aho-Corasick automaton to find ALL tokens in a byte sequence.
     // #[serde(
     // serialize_with = "serialize_daac",
     // deserialize_with = "deserialize_daac"
     // )]
-    #[serde(skip)]
     pub(crate) overlapping_searcher: DoubleArrayAhoCorasick<u32>,
     /// An Aho-Corasick automaton to find ALL tokens in a byte sequence which is being processed in reverse order.
     // #[serde(
     // serialize_with = "serialize_daac",
     // deserialize_with = "deserialize_daac"
     // )]
-    #[serde(skip)]
     pub(crate) overlapping_searcher_rev: DoubleArrayAhoCorasick<u32>,
     /// Mapping from a token to the next longest prefix token.
     /// This is in principle information represented by the AhoCorasick automaton.
@@ -145,12 +144,14 @@ pub struct BacktrackingBpe {
     next_prefix_match: Vec<u32>,
     /// Hash factor used to prevent hash collisions.
     hash_factor: u64,
-    vocab: Vocab,
-    vocab_r: VocabR,
+    pub vocab: Vocab,
+    pub vocab_r: VocabR,
     unk_token: Option<String>,
+    pub merges: MergeMap,
 }

+use std::fmt;

 // Manually implement the Debug trait to exclude the `cache` field
 impl fmt::Debug for BacktrackingBpe {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -430,7 +431,7 @@ impl BacktrackingBpe {
     /// to prevent repeating the cost of computing the hash factor and encoding.
     pub fn from_dictionary(
         tokens: impl IntoIterator<Item = Vec<u8>>,
-        merges: Option<Vec<(String, String)>>,
+        merges: Option<Merges>,
         hash_factor: Option<u64>,
     ) -> Self {
         let hash_factor = hash_factor
@@ -440,6 +441,7 @@ impl BacktrackingBpe {
         let mut all_tokens_rev = Vec::new();
         let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
+        let mut merge_map: HashMap<Pair, (u32, u32)> = HashMap::new();
         for (i, token) in tokens.into_iter().enumerate() {
             bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32);
             all_tokens_rev.extend(token.iter().copied().rev());
@@ -463,36 +465,6 @@ impl BacktrackingBpe {
             })
             .collect();

-        let mut split_table = vec![];
-        let mut pair_lookup = FnvHashMap::default();
-
-        // Reverse engineer the merge/split table.
-        for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() {
-            let mut token1 = next_prefix_match[id];
-            while token1 != u32::MAX {
-                let rest = &token[token_range(&token_starts, token1).len()..];
-                if let Some(token2) = find_token_by_bytes(
-                    &all_tokens,
-                    &token_starts,
-                    &bytes_hash_to_token,
-                    rest,
-                    hash_factor,
-                ) {
-                    if token1 < id as u32
-                        && token2 < id as u32
-                        && is_valid_token_pair(&pair_lookup, &split_table, token1, token2)
-                    {
-                        pair_lookup.insert((token1, token2), id as u32);
-                        split_table.push((token1, token2));
-                        break;
-                    }
-                }
-                token1 = next_prefix_match[token1 as usize];
-            }
-            if token1 == u32::MAX {
-                split_table.push((id as u32, id as u32));
-            }
-        }
         let vocab: HashMap<String, u32> = token_iter(&all_tokens, &token_starts)
             .enumerate()
             .map(|(id, item)| {
@@ -512,6 +484,52 @@ impl BacktrackingBpe {
             })
             .collect();

+        let mut split_table = vec![];
+        let mut pair_lookup = FnvHashMap::default();
+
+        if let Some(ref merges) = merges {
+            for (id, pair) in merges.into_iter().enumerate() {
+                let token1 = vocab[&pair.0.clone()];
+                let token2 = vocab[&pair.1.clone()];
+                pair_lookup.insert((token1, token2), id as u32);
+                split_table.push((token1, token2));
+                merge_map.insert(Pair::from(pair), (id as u32, id as u32)); // TODO wrong
+            }
+        } else {
+            // Reverse engineer the merge/split table.
+            for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() {
+                let mut token1 = next_prefix_match[id];
+                while token1 != u32::MAX {
+                    let rest = &token[token_range(&token_starts, token1).len()..];
+                    if let Some(token2) = find_token_by_bytes(
+                        &all_tokens,
+                        &token_starts,
+                        &bytes_hash_to_token,
+                        rest,
+                        hash_factor,
+                    ) {
+                        if token1 < id as u32
+                            && token2 < id as u32
+                            && is_valid_token_pair(&pair_lookup, &split_table, token1, token2)
+                        {
+                            pair_lookup.insert((token1, token2), id as u32);
+                            split_table.push((token1, token2));
+                            let str_token1 = unsafe { String::from_utf8_unchecked(Vec::from(&all_tokens[token_range(&token_starts, token1)])) };
+                            let str_token2 = unsafe { String::from_utf8_unchecked(Vec::from(&all_tokens[token_range(&token_starts, token2)])) };
+                            merge_map.insert(Pair::from(&(str_token1, str_token2)), (id as u32, id as u32)); // TODO wrong
+                            break;
+                        }
+                    }
+                    token1 = next_prefix_match[token1 as usize];
+                }
+                if token1 == u32::MAX {
+                    split_table.push((id as u32, id as u32));
+                }
+            }
+        }
+
+
+
         let bpe = Self {
             all_tokens,
             token_starts,
@@ -526,6 +544,7 @@ impl BacktrackingBpe {
             unk_token: None,
             vocab,
             vocab_r,
+            merges: merge_map,
         };
         for token_id in 0..bpe.num_tokens() as u32 {
             let bytes = bpe.token_bytes(token_id);
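The behavioral core of this change is in from_dictionary: a caller can now pass an explicit merge list (the Merges alias, Vec<(String, String)>) instead of having merges reverse-engineered from token prefixes, and the resulting merge map is stored on the model. Below is a rough sketch of the two call shapes; the module path tokenizers::models::backtracking_bpe and the toy data are assumptions for illustration, not part of the diff.

use tokenizers::models::backtracking_bpe::BacktrackingBpe;

fn main() {
    // Toy vocabulary: two byte-level tokens plus one merged token "ab".
    let tokens: Vec<Vec<u8>> = vec![b"a".to_vec(), b"b".to_vec(), b"ab".to_vec()];

    // Pre-existing path, kept as the `None` case: merges are
    // reverse-engineered from the token list.
    let inferred = BacktrackingBpe::from_dictionary(tokens.clone(), None, None);

    // New path: an explicit merge list is supplied and recorded in the
    // model's `merges` field (a MergeMap).
    let merges = vec![("a".to_string(), "b".to_string())];
    let explicit = BacktrackingBpe::from_dictionary(tokens, Some(merges), None);

    // Vocabulary construction is independent of how merges were obtained,
    // so both models should expose the same (now public) vocab.
    assert_eq!(inferred.vocab, explicit.vocab);
}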
(The diffs for the remaining three changed files are not shown.)