From e42ac9e6a4c701424be91323fe39d998f00adfc8 Mon Sep 17 00:00:00 2001
From: jonvet
Date: Sun, 5 Jan 2025 22:46:46 +0000
Subject: [PATCH] Fixed Length Pre-Tokenizer

---
 .../tokenizers/pre_tokenizers/__init__.py     |   1 +
 .../tokenizers/pre_tokenizers/__init__.pyi    |  51 ++++++++
 bindings/python/src/pre_tokenizers.rs         |  37 ++++++
 .../tests/bindings/test_pre_tokenizers.py     |  16 +++
 tokenizers/src/pre_tokenizers/fixed_length.rs | 116 ++++++++++++++++++
 tokenizers/src/pre_tokenizers/mod.rs          |  13 ++
 6 files changed, 234 insertions(+)
 create mode 100644 tokenizers/src/pre_tokenizers/fixed_length.rs

diff --git a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py
index 48277f0d2..db8ddc208 100644
--- a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py
+++ b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py
@@ -6,6 +6,7 @@
 ByteLevel = pre_tokenizers.ByteLevel
 CharDelimiterSplit = pre_tokenizers.CharDelimiterSplit
 Digits = pre_tokenizers.Digits
+FixedLength = pre_tokenizers.FixedLength
 Metaspace = pre_tokenizers.Metaspace
 Punctuation = pre_tokenizers.Punctuation
 Sequence = pre_tokenizers.Sequence
diff --git a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
index a583945fc..cf21df28b 100644
--- a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
@@ -258,6 +258,57 @@ class Digits(PreTokenizer):
     """
     pass
 
+class FixedLength(PreTokenizer):
+    """
+    This pre-tokenizer splits the text into fixed-length chunks
+
+    Args:
+        length (:obj:`int`, `optional`, defaults to :obj:`5`):
+            The length of the chunks to split the text into.
+
+        Strings are split on the character level rather than the byte level to avoid
+        splitting unicode characters consisting of multiple bytes.
+    """
+    def __init__(self, length=5):
+        pass
+
+    def pre_tokenize(self, pretok):
+        """
+        Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place
+
+        This method allows to modify a :class:`~tokenizers.PreTokenizedString` to
+        keep track of the pre-tokenization, and leverage the capabilities of the
+        :class:`~tokenizers.PreTokenizedString`. If you just want to see the result of
+        the pre-tokenization of a raw string, you can use
+        :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str`
+
+        Args:
+            pretok (:class:`~tokenizers.PreTokenizedString`):
+                The pre-tokenized string on which to apply this
+                :class:`~tokenizers.pre_tokenizers.PreTokenizer`
+        """
+        pass
+
+    def pre_tokenize_str(self, sequence):
+        """
+        Pre tokenize the given string
+
+        This method provides a way to visualize the effect of a
+        :class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the
+        alignment, nor does it provide all the capabilities of the
+        :class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use
+        :meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize`
+
+        Args:
+            sequence (:obj:`str`):
+                A string to pre-tokenize
+
+        Returns:
+            :obj:`List[Tuple[str, Offsets]]`:
+                A list of tuples with the pre-tokenized parts and their offsets
+        """
+        pass
+
 class Metaspace(PreTokenizer):
     """
     Metaspace pre-tokenizer
diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs
index fdc862302..09848be49 100644
--- a/bindings/python/src/pre_tokenizers.rs
+++ b/bindings/python/src/pre_tokenizers.rs
@@ -11,6 +11,7 @@ use tk::pre_tokenizers::bert::BertPreTokenizer;
 use tk::pre_tokenizers::byte_level::ByteLevel;
 use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
 use tk::pre_tokenizers::digits::Digits;
+use tk::pre_tokenizers::fixed_length::FixedLength;
 use tk::pre_tokenizers::metaspace::{Metaspace, PrependScheme};
 use tk::pre_tokenizers::punctuation::Punctuation;
 use tk::pre_tokenizers::split::Split;
@@ -113,6 +114,12 @@ impl PyPreTokenizer {
                     .into_any()
                     .into()
             }
+            PreTokenizerWrapper::FixedLength(_) => {
+                Py::new(py, (PyFixedLength {}, base))?
+                    .into_pyobject(py)?
+                    .into_any()
+                    .into()
+            }
         },
     }
 }
@@ -627,6 +634,35 @@ impl PyDigits {
     }
 }
 
+/// This pre-tokenizer splits the text into fixed-length chunks
+///
+/// Args:
+///     length (:obj:`int`, `optional`, defaults to :obj:`5`):
+///         The length of the chunks to split the text into.
+///
+///     Strings are split on the character level rather than the byte level to avoid
+///     splitting unicode characters consisting of multiple bytes.
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "FixedLength")]
+pub struct PyFixedLength {}
+#[pymethods]
+impl PyFixedLength {
+    #[getter]
+    fn get_length(self_: PyRef<Self>) -> usize {
+        getter!(self_, FixedLength, length)
+    }
+
+    #[setter]
+    fn set_length(self_: PyRef<Self>, length: usize) {
+        setter!(self_, FixedLength, length, length);
+    }
+
+    #[new]
+    #[pyo3(signature = (length = 5), text_signature = "(self, length=5)")]
+    fn new(length: usize) -> (Self, PyPreTokenizer) {
+        (PyFixedLength {}, FixedLength::new(length).into())
+    }
+}
+
 /// This pre-tokenizer splits on characters that belong to different language family
 /// It roughly follows https://github.com/google/sentencepiece/blob/master/data/Scripts.txt
 /// Actually Hiragana and Katakana are fused with Han, and 0x30FC is Han too.
@@ -792,6 +828,7 @@ pub fn pre_tokenizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<PySequence>()?;
     m.add_class::<PyDigits>()?;
     m.add_class::<PyUnicodeScripts>()?;
+    m.add_class::<PyFixedLength>()?;
     Ok(())
 }
diff --git a/bindings/python/tests/bindings/test_pre_tokenizers.py b/bindings/python/tests/bindings/test_pre_tokenizers.py
index 80086f42e..a0fd830d9 100644
--- a/bindings/python/tests/bindings/test_pre_tokenizers.py
+++ b/bindings/python/tests/bindings/test_pre_tokenizers.py
@@ -8,6 +8,7 @@
     ByteLevel,
     CharDelimiterSplit,
     Digits,
+    FixedLength,
     Metaspace,
     PreTokenizer,
     Punctuation,
@@ -195,6 +196,21 @@ def test_can_modify(self):
         assert pretok.individual_digits == True
 
 
+class TestFixedLength:
+    def test_instantiate(self):
+        assert FixedLength() is not None
+        assert isinstance(FixedLength(), PreTokenizer)
+        assert isinstance(FixedLength(), FixedLength)
+        assert isinstance(pickle.loads(pickle.dumps(FixedLength())), FixedLength)
+
+    def test_can_modify(self):
+        pretok = FixedLength(length=5)
+        assert pretok.length == 5
+
+        pretok.length = 10
+        assert pretok.length == 10
+
+
 class TestUnicodeScripts:
     def test_instantiate(self):
         assert UnicodeScripts() is not None
diff --git a/tokenizers/src/pre_tokenizers/fixed_length.rs b/tokenizers/src/pre_tokenizers/fixed_length.rs
new file mode 100644
index 000000000..30456854a
--- /dev/null
+++ b/tokenizers/src/pre_tokenizers/fixed_length.rs
@@ -0,0 +1,116 @@
+use crate::normalizer::Range;
+use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result};
+use serde::{Deserialize, Serialize};
+
+use crate::utils::macro_rules_attribute;
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+#[macro_rules_attribute(impl_serde_type!)]
+pub struct FixedLength {
+    #[serde(default = "default_length")]
+    pub length: usize,
+}
+
+impl FixedLength {
+    pub fn new(length: usize) -> Self {
+        Self { length }
+    }
+}
+
+fn default_length() -> usize {
+    5
+}
+
+impl PreTokenizer for FixedLength {
+    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
+        pretokenized.split(|_, normalized| {
+            let text = normalized.get();
+            if text.is_empty() {
+                return Ok(vec![]);
+            }
+
+            let mut splits = Vec::new();
+            let char_positions: Vec<_> = text.char_indices().collect();
+            for chunk in char_positions.chunks(self.length) {
+                let start = chunk.first().map(|(i, _)| *i).unwrap_or(0);
+                let end = chunk.last().map(|(i, c)| i + c.len_utf8()).unwrap_or(text.len());
+                splits.push(normalized.slice(Range::Normalized(start..end))
+                    .ok_or("Failed to slice normalized text")?);
+            }
+
+            Ok(splits)
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{OffsetReferential, OffsetType, PreTokenizer};
+
+    #[test]
+    fn basic() {
+        let tests = vec![
+            (
+                "Hello world",
+                vec![("Hello", (0, 5)), (" worl", (5, 10)), ("d", (10, 11))],
+            ),
+            ("Short", vec![("Short", (0, 5))]),
+            ("", vec![]),
+        ];
+        let pretok = FixedLength { length: 5 };
+        for (s, res) in tests {
+            let mut pretokenized = PreTokenizedString::from(s);
+            pretok.pre_tokenize(&mut pretokenized).unwrap();
+            assert_eq!(
+                pretokenized
+                    .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                    .into_iter()
+                    .map(|(s, o, _)| (s, o))
+                    .collect::<Vec<_>>(),
+                res
+            );
+        }
+    }
+
+    #[test]
+    fn custom_length() {
+        let pretok = FixedLength { length: 3 };
+        let mut pretokenized = PreTokenizedString::from("Hello world");
+        pretok.pre_tokenize(&mut pretokenized).unwrap();
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![
+                ("Hel", (0, 3)),
+                ("lo ", (3, 6)),
+ ("wor", (6, 9)), + ("ld", (9, 11)), + ] + ); + } + + #[test] + fn utf8_characters() { + let pretok = FixedLength { length: 3 }; + let mut pretokenized = PreTokenizedString::from("Hello 👋 world"); + pretok.pre_tokenize(&mut pretokenized).unwrap(); + assert_eq!( + pretokenized + .get_splits(OffsetReferential::Original, OffsetType::Byte) + .into_iter() + .map(|(s, o, _)| (s, o)) + .collect::>(), + vec![ + ("Hel", (0, 3)), + ("lo ", (3, 6)), + ("👋 w", (6, 12)), + ("orl", (12, 15)), + ("d", (15, 16)), + ] + ); + } +} diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs index 6195d170b..e885108fb 100644 --- a/tokenizers/src/pre_tokenizers/mod.rs +++ b/tokenizers/src/pre_tokenizers/mod.rs @@ -2,6 +2,7 @@ pub mod bert; pub mod byte_level; pub mod delimiter; pub mod digits; +pub mod fixed_length; pub mod metaspace; pub mod punctuation; pub mod sequence; @@ -15,6 +16,7 @@ use crate::pre_tokenizers::bert::BertPreTokenizer; use crate::pre_tokenizers::byte_level::ByteLevel; use crate::pre_tokenizers::delimiter::CharDelimiterSplit; use crate::pre_tokenizers::digits::Digits; +use crate::pre_tokenizers::fixed_length::FixedLength; use crate::pre_tokenizers::metaspace::Metaspace; use crate::pre_tokenizers::punctuation::Punctuation; use crate::pre_tokenizers::sequence::Sequence; @@ -37,6 +39,7 @@ pub enum PreTokenizerWrapper { WhitespaceSplit(WhitespaceSplit), Digits(Digits), UnicodeScripts(UnicodeScripts), + FixedLength(FixedLength), } impl PreTokenizer for PreTokenizerWrapper { @@ -53,6 +56,7 @@ impl PreTokenizer for PreTokenizerWrapper { Self::WhitespaceSplit(wspt) => wspt.pre_tokenize(normalized), Self::Digits(wspt) => wspt.pre_tokenize(normalized), Self::UnicodeScripts(us) => us.pre_tokenize(normalized), + Self::FixedLength(fl) => fl.pre_tokenize(normalized), } } } @@ -82,6 +86,7 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper { WhitespaceSplit, Digits, UnicodeScripts, + FixedLength, } #[derive(Deserialize)] @@ -105,6 +110,7 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper { WhitespaceSplit(WhitespaceSplit), Digits(Digits), UnicodeScripts(UnicodeScripts), + FixedLength(FixedLength), } let helper = PreTokenizerHelper::deserialize(deserializer)?; @@ -152,6 +158,9 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper { EnumType::UnicodeScripts => PreTokenizerWrapper::UnicodeScripts( serde_json::from_value(values).map_err(serde::de::Error::custom)?, ), + EnumType::FixedLength => PreTokenizerWrapper::FixedLength( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), } } @@ -187,6 +196,9 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper { PreTokenizerUntagged::UnicodeScripts(unicode_scripts) => { PreTokenizerWrapper::UnicodeScripts(unicode_scripts) } + PreTokenizerUntagged::FixedLength(fixed_length) => { + PreTokenizerWrapper::FixedLength(fixed_length) + } } } }) @@ -204,6 +216,7 @@ impl_enum_from!(Metaspace, PreTokenizerWrapper, Metaspace); impl_enum_from!(WhitespaceSplit, PreTokenizerWrapper, WhitespaceSplit); impl_enum_from!(Digits, PreTokenizerWrapper, Digits); impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts); +impl_enum_from!(FixedLength, PreTokenizerWrapper, FixedLength); #[cfg(test)] mod tests {