Skip to content

Commit

Permalink
update serialization to support initializing from BPE
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Jan 6, 2025
1 parent ee18ba9 commit 4b63a7a
Showing 1 changed file with 46 additions and 4 deletions.
50 changes: 46 additions & 4 deletions tokenizers/src/models/backtracking_bpe/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::{
super::bpe::Pair, super::OrderedVocabIter, convert_merges_to_hashmap, BacktrackingBpe,
BacktrackingBpeBuilder,
};
use regex_syntax::ast::print;
use serde::{
de::{Error, MapAccess, Visitor},
ser::SerializeStruct,
Expand Down Expand Up @@ -86,26 +87,30 @@ impl<'de> Visitor<'de> for BacktrackingBpeVisitor {
builder = builder.unk_token(unk);
}
}
"vocab" => vocab = Some(map.next_value()?),
"vocab" => vocab = Some(map.next_value()?),
"merges" => merges = Some(map.next_value()?),
"type" => match map.next_value()? {
"BacktrackingBpe" => {}
"BPE" => {}
"BPE" => {println!("Type is BPE but initializing a backtracking BPE")}
u => {
return Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Str(u),
&"BacktrackingBpe should have been found",
))
}
},
_ => {}
field => {
println!("Ignoring unused field {:?}", field); // TODO make it into a logger
// Ensure the value is consumed to maintain valid deserialization
let _ = map.next_value::<serde::de::IgnoredAny>()?;
}
}
}
if let (Some(vocab), Some(merges)) = (vocab, merges) {
let merges = match merges {
MergeType::Tuple(merges) => merges,
MergeType::Legacy(merges) => {
convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?
convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(|e| Error::custom("Error in convert merges to hashmap"))?
}
};
builder = builder.vocab_and_merges(vocab, merges);
Expand All @@ -123,6 +128,29 @@ mod test {

#[test]
fn test_serialization() {
let bpe_string = r#"{
"type": "BPE",
"dropout": null,
"unk_token": "<unk>",
"continuing_subword_prefix": null,
"end_of_word_suffix": null,
"fuse_unk": false,
"byte_fallback": false,
"ignore_merges": true,
"vocab": {
"a": 1,
"b c d": 2,
"ab c d": 3
},
"merges": [
["a", "b c d"]
]
}"#;
let reconstructed: Result<BacktrackingBpe, serde_json::Error> = serde_json::from_str(&bpe_string);
println!("End of my example");



let vocab: Vocab = [
("a".into(), 1),
("b".into(), 2),
Expand All @@ -137,6 +165,17 @@ mod test {
.build()
.unwrap();

match reconstructed {
Ok(reconstructed) => {
println!("Good");
assert_eq!(bpe, reconstructed);
}
Err(err) => {
println!("Error deserializing: {:?}", err);

}
}

let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","fuse_unk":false,"byte_fallback":false,"vocab":{"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
let legacy = serde_json::from_str(legacy);
match legacy {
Expand Down Expand Up @@ -180,5 +219,8 @@ mod test {
);
let reconstructed = serde_json::from_str(&data).unwrap();
assert_eq!(bpe, reconstructed);



}
}

0 comments on commit 4b63a7a

Please sign in to comment.