From 4b63a7af79cffbcbdeaf8a19d60dead2fc710ba3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 6 Jan 2025 11:14:49 +0100 Subject: [PATCH] update serialization to support initializing from BPE --- .../models/backtracking_bpe/serialization.rs | 50 +++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/tokenizers/src/models/backtracking_bpe/serialization.rs b/tokenizers/src/models/backtracking_bpe/serialization.rs index dabd5f748..f5d015b17 100644 --- a/tokenizers/src/models/backtracking_bpe/serialization.rs +++ b/tokenizers/src/models/backtracking_bpe/serialization.rs @@ -2,6 +2,7 @@ use super::{ super::bpe::Pair, super::OrderedVocabIter, convert_merges_to_hashmap, BacktrackingBpe, BacktrackingBpeBuilder, }; +use regex_syntax::ast::print; use serde::{ de::{Error, MapAccess, Visitor}, ser::SerializeStruct, @@ -86,11 +87,11 @@ impl<'de> Visitor<'de> for BacktrackingBpeVisitor { builder = builder.unk_token(unk); } } - "vocab" => vocab = Some(map.next_value()?), + "vocab" => vocab = Some(map.next_value()?), "merges" => merges = Some(map.next_value()?), "type" => match map.next_value()? { "BacktrackingBpe" => {} - "BPE" => {} + "BPE" => {println!("Type is BPE but initializing a backtracking BPE")} u => { return Err(serde::de::Error::invalid_value( serde::de::Unexpected::Str(u), @@ -98,14 +99,18 @@ impl<'de> Visitor<'de> for BacktrackingBpeVisitor { )) } }, - _ => {} + field => { + println!("Ignoring unused field {:?}", field); // TODO make it into a logger + // Ensure the value is consumed to maintain valid deserialization + let _ = map.next_value::<serde::de::IgnoredAny>()?; + } } } if let (Some(vocab), Some(merges)) = (vocab, merges) { let merges = match merges { MergeType::Tuple(merges) => merges, MergeType::Legacy(merges) => { - convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)? + convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(|e| Error::custom("Error in convert merges to hashmap"))?
} }; builder = builder.vocab_and_merges(vocab, merges); @@ -123,6 +128,29 @@ mod test { #[test] fn test_serialization() { + let bpe_string = r#"{ + "type": "BPE", + "dropout": null, + "unk_token": "", + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": true, + "vocab": { + "a": 1, + "b c d": 2, + "ab c d": 3 + }, + "merges": [ + ["a", "b c d"] + ] + }"#; + let reconstructed: Result<BacktrackingBpe, serde_json::Error> = serde_json::from_str(&bpe_string); + println!("End of my example"); + + + let vocab: Vocab = [ ("a".into(), 1), ("b".into(), 2), @@ -137,6 +165,17 @@ mod test { .build() .unwrap(); + match reconstructed { + Ok(reconstructed) => { + println!("Good"); + assert_eq!(bpe, reconstructed); + } + Err(err) => { + println!("Error deserializing: {:?}", err); + + } + } + let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"","fuse_unk":false,"byte_fallback":false,"vocab":{"a":1,"b":2,"ab":3},"merges":["a b"]}"#; let legacy = serde_json::from_str(legacy); match legacy { @@ -180,5 +219,8 @@ mod test { ); let reconstructed = serde_json::from_str(&data).unwrap(); assert_eq!(bpe, reconstructed); + + + } }