Fixed Length Pre-Tokenizer #1713

Open · wants to merge 1 commit into main
@@ -6,6 +6,7 @@
ByteLevel = pre_tokenizers.ByteLevel
CharDelimiterSplit = pre_tokenizers.CharDelimiterSplit
Digits = pre_tokenizers.Digits
FixedLength = pre_tokenizers.FixedLength
Metaspace = pre_tokenizers.Metaspace
Punctuation = pre_tokenizers.Punctuation
Sequence = pre_tokenizers.Sequence
51 changes: 51 additions & 0 deletions bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
@@ -258,6 +258,57 @@ class Digits(PreTokenizer):
"""
pass

class FixedLength(PreTokenizer):
"""
    This pre-tokenizer splits the text into fixed-length chunks

Args:
length (:obj:`int`, `optional`, defaults to :obj:`5`):
The length of the chunks to split the text into.

Strings are split on the character level rather than the byte level to avoid
splitting unicode characters consisting of multiple bytes.
"""
def __init__(self, length=5):
pass

def pre_tokenize(self, pretok):
"""
Pre-tokenize a :class:`~tokenizers.PyPreTokenizedString` in-place

        This method allows you to modify a :class:`~tokenizers.PreTokenizedString` to
keep track of the pre-tokenization, and leverage the capabilities of the
:class:`~tokenizers.PreTokenizedString`. If you just want to see the result of
the pre-tokenization of a raw string, you can use
:meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize_str`

Args:
            pretok (:class:`~tokenizers.PreTokenizedString`):
The pre-tokenized string on which to apply this
:class:`~tokenizers.pre_tokenizers.PreTokenizer`
"""
pass

def pre_tokenize_str(self, sequence):
"""
        Pre-tokenize the given string

This method provides a way to visualize the effect of a
:class:`~tokenizers.pre_tokenizers.PreTokenizer` but it does not keep track of the
alignment, nor does it provide all the capabilities of the
:class:`~tokenizers.PreTokenizedString`. If you need some of these, you can use
:meth:`~tokenizers.pre_tokenizers.PreTokenizer.pre_tokenize`

Args:
sequence (:obj:`str`):
                A string to pre-tokenize

Returns:
:obj:`List[Tuple[str, Offsets]]`:
                A list of tuples with the pre-tokenized parts and their offsets
"""
pass

class Metaspace(PreTokenizer):
"""
Metaspace pre-tokenizer
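As a usage illustration (not part of this diff), here is a minimal Python sketch of the pre-tokenizer documented in the FixedLength stub above; the expected output mirrors the basic test in tokenizers/src/pre_tokenizers/fixed_length.rs further down.

from tokenizers.pre_tokenizers import FixedLength

# Chunks are cut every `length` characters (default 5), on character
# boundaries, so multi-byte Unicode characters are never split.
pretok = FixedLength()
print(pretok.pre_tokenize_str("Hello world"))
# [('Hello', (0, 5)), (' worl', (5, 10)), ('d', (10, 11))]
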
37 changes: 37 additions & 0 deletions bindings/python/src/pre_tokenizers.rs
@@ -11,6 +11,7 @@ use tk::pre_tokenizers::bert::BertPreTokenizer;
use tk::pre_tokenizers::byte_level::ByteLevel;
use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
use tk::pre_tokenizers::digits::Digits;
use tk::pre_tokenizers::fixed_length::FixedLength;
use tk::pre_tokenizers::metaspace::{Metaspace, PrependScheme};
use tk::pre_tokenizers::punctuation::Punctuation;
use tk::pre_tokenizers::split::Split;
@@ -113,6 +114,12 @@ impl PyPreTokenizer {
.into_any()
.into()
}
PreTokenizerWrapper::FixedLength(_) => {
Py::new(py, (PyFixedLength {}, base))?
.into_pyobject(py)?
.into_any()
.into()
}
},
}
}
@@ -627,6 +634,35 @@ impl PyDigits {
}
}

/// This pre-tokenizer splits the text into fixed-length chunks
///
/// Args:
/// length (:obj:`int`, `optional`, defaults to :obj:`5`):
/// The length of the chunks to split the text into.
///
/// Strings are split on the character level rather than the byte level to avoid
/// splitting unicode characters consisting of multiple bytes.
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "FixedLength")]
pub struct PyFixedLength {}
#[pymethods]
impl PyFixedLength {
#[getter]
fn get_length(self_: PyRef<Self>) -> usize {
getter!(self_, FixedLength, length)
}

#[setter]
fn set_length(self_: PyRef<Self>, length: usize) {
setter!(self_, FixedLength, length, length);
}

#[new]
#[pyo3(signature = (length = 5), text_signature = "(self, length=5)")]
fn new(length: usize) -> (Self, PyPreTokenizer) {
(PyFixedLength {}, FixedLength::new(length).into())
}
}

/// This pre-tokenizer splits on characters that belong to different language family
/// It roughly follows https://github.com/google/sentencepiece/blob/master/data/Scripts.txt
/// Actually Hiragana and Katakana are fused with Han, and 0x30FC is Han too.
@@ -792,6 +828,7 @@ pub fn pre_tokenizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PySequence>()?;
m.add_class::<PyDigits>()?;
m.add_class::<PyUnicodeScripts>()?;
m.add_class::<PyFixedLength>()?;
Ok(())
}

16 changes: 16 additions & 0 deletions bindings/python/tests/bindings/test_pre_tokenizers.py
@@ -8,6 +8,7 @@
ByteLevel,
CharDelimiterSplit,
Digits,
FixedLength,
Metaspace,
PreTokenizer,
Punctuation,
@@ -195,6 +196,21 @@ def test_can_modify(self):
assert pretok.individual_digits == True


class TestFixedLength:
    def test_instantiate(self):
        assert FixedLength() is not None
        assert isinstance(FixedLength(), PreTokenizer)
        assert isinstance(FixedLength(), FixedLength)
        assert isinstance(pickle.loads(pickle.dumps(FixedLength())), FixedLength)

    def test_can_modify(self):
        pretok = FixedLength(length=5)
        assert pretok.length == 5

        pretok.length = 10
        assert pretok.length == 10

Collaborator comment:

We'd also want to make sure that it does its job as a pre-tokenizer! So test with the same string, checking that it splits into chunks of 5 and then of 10.


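A hypothetical sketch of such a test (the method name is made up; the length-5 expectations match the basic test in fixed_length.rs below, and the length-10 split follows from the same 11-character string):

def test_splits_into_chunks(self):
    pretok = FixedLength(length=5)
    assert pretok.pre_tokenize_str("Hello world") == [
        ("Hello", (0, 5)),
        (" worl", (5, 10)),
        ("d", (10, 11)),
    ]

    # Re-chunk the same string after changing the length.
    pretok.length = 10
    assert pretok.pre_tokenize_str("Hello world") == [
        ("Hello worl", (0, 10)),
        ("d", (10, 11)),
    ]
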
class TestUnicodeScripts:
def test_instantiate(self):
assert UnicodeScripts() is not None
116 changes: 116 additions & 0 deletions tokenizers/src/pre_tokenizers/fixed_length.rs
@@ -0,0 +1,116 @@
use crate::normalizer::Range;
use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result};
use serde::{Deserialize, Serialize};

use crate::utils::macro_rules_attribute;

#[derive(Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct FixedLength {
#[serde(default = "default_length")]
pub length: usize,
}

impl FixedLength {
pub fn new(length: usize) -> Self {
Self { length }
}
}

fn default_length() -> usize {
5
}

impl PreTokenizer for FixedLength {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
pretokenized.split(|_, normalized| {
let text = normalized.get();
if text.is_empty() {
return Ok(vec![]);
}

let mut splits = Vec::new();
let char_positions: Vec<_> = text.char_indices().collect();
for chunk in char_positions.chunks(self.length) {
let start = chunk.first().map(|(i, _)| *i).unwrap_or(0);
let end = chunk.last().map(|(i, c)| i + c.len_utf8()).unwrap_or(text.len());
                splits.push(
                    normalized
                        .slice(Range::Normalized(start..end))
                        .ok_or("Failed to slice normalized text")?,
                );
}

Ok(splits)
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{OffsetReferential, OffsetType, PreTokenizer};

#[test]
fn basic() {
let tests = vec![
(
"Hello world",
vec![("Hello", (0, 5)), (" worl", (5, 10)), ("d", (10, 11))],
),
("Short", vec![("Short", (0, 5))]),
("", vec![]),
];
let pretok = FixedLength { length: 5 };
for (s, res) in tests {
let mut pretokenized = PreTokenizedString::from(s);
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
res
);
}
}

#[test]
fn custom_length() {
let pretok = FixedLength { length: 3 };
let mut pretokenized = PreTokenizedString::from("Hello world");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("Hel", (0, 3)),
("lo ", (3, 6)),
("wor", (6, 9)),
("ld", (9, 11)),
]
);
}

#[test]
fn utf8_characters() {
let pretok = FixedLength { length: 3 };
let mut pretokenized = PreTokenizedString::from("Hello 👋 world");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("Hel", (0, 3)),
("lo ", (3, 6)),
("👋 w", (6, 12)),
("orl", (12, 15)),
("d", (15, 16)),
]
);
}
}
13 changes: 13 additions & 0 deletions tokenizers/src/pre_tokenizers/mod.rs
@@ -2,6 +2,7 @@ pub mod bert;
pub mod byte_level;
pub mod delimiter;
pub mod digits;
pub mod fixed_length;
pub mod metaspace;
pub mod punctuation;
pub mod sequence;
@@ -15,6 +16,7 @@ use crate::pre_tokenizers::bert::BertPreTokenizer;
use crate::pre_tokenizers::byte_level::ByteLevel;
use crate::pre_tokenizers::delimiter::CharDelimiterSplit;
use crate::pre_tokenizers::digits::Digits;
use crate::pre_tokenizers::fixed_length::FixedLength;
use crate::pre_tokenizers::metaspace::Metaspace;
use crate::pre_tokenizers::punctuation::Punctuation;
use crate::pre_tokenizers::sequence::Sequence;
@@ -37,6 +39,7 @@ pub enum PreTokenizerWrapper {
WhitespaceSplit(WhitespaceSplit),
Digits(Digits),
UnicodeScripts(UnicodeScripts),
FixedLength(FixedLength),
}

impl PreTokenizer for PreTokenizerWrapper {
@@ -53,6 +56,7 @@ impl PreTokenizer for PreTokenizerWrapper {
Self::WhitespaceSplit(wspt) => wspt.pre_tokenize(normalized),
Self::Digits(wspt) => wspt.pre_tokenize(normalized),
Self::UnicodeScripts(us) => us.pre_tokenize(normalized),
Self::FixedLength(fl) => fl.pre_tokenize(normalized),
}
}
}
@@ -82,6 +86,7 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper {
WhitespaceSplit,
Digits,
UnicodeScripts,
FixedLength,
}

#[derive(Deserialize)]
@@ -105,6 +110,7 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper {
WhitespaceSplit(WhitespaceSplit),
Digits(Digits),
UnicodeScripts(UnicodeScripts),
FixedLength(FixedLength),
}

let helper = PreTokenizerHelper::deserialize(deserializer)?;
@@ -152,6 +158,9 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper {
EnumType::UnicodeScripts => PreTokenizerWrapper::UnicodeScripts(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::FixedLength => PreTokenizerWrapper::FixedLength(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
}
}

@@ -187,6 +196,9 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper {
PreTokenizerUntagged::UnicodeScripts(unicode_scripts) => {
PreTokenizerWrapper::UnicodeScripts(unicode_scripts)
}
PreTokenizerUntagged::FixedLength(fixed_length) => {
PreTokenizerWrapper::FixedLength(fixed_length)
}
}
}
})
@@ -204,6 +216,7 @@ impl_enum_from!(Metaspace, PreTokenizerWrapper, Metaspace);
impl_enum_from!(WhitespaceSplit, PreTokenizerWrapper, WhitespaceSplit);
impl_enum_from!(Digits, PreTokenizerWrapper, Digits);
impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts);
impl_enum_from!(FixedLength, PreTokenizerWrapper, FixedLength);

#[cfg(test)]
mod tests {
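For illustration, a hedged Python sketch of what the serde wiring above implies for a serialized tokenizer config; the Tokenizer/BPE scaffolding is only there to produce a config, and the expected fragment follows from the FixedLength struct and its length field in fixed_length.rs.

import json

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import FixedLength

tok = Tokenizer(BPE())
tok.pre_tokenizer = FixedLength(length=5)

# The pre_tokenizer section of the serialized config should carry the
# variant name and the length field registered above.
config = json.loads(tok.to_str())
print(config["pre_tokenizer"])
# Expected: {'type': 'FixedLength', 'length': 5}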