fix: normalizers deserialization and other refactoring
McPatate committed Jan 9, 2025
1 parent 488a570 commit 1cbb741
Showing 5 changed files with 91 additions and 49 deletions.
35 changes: 16 additions & 19 deletions bindings/python/pyproject.toml
@@ -1,9 +1,10 @@
[project]
name = 'tokenizers'
requires-python = '>=3.7'
version = '0.21.0'
authors = [
{name = 'Nicolas Patry', email = '[email protected]'},
{name = 'Anthony Moi', email = '[email protected]'}
{ name = 'Nicolas Patry', email = '[email protected]' },
{ name = 'Anthony Moi', email = '[email protected]' },
]
classifiers = [
"Development Status :: 5 - Production/Stable",
@@ -21,11 +22,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
keywords = ["NLP", "tokenizer", "BPE", "transformer", "deep learning"]
dynamic = [
'description',
'license',
'readme',
]
dynamic = ['description', 'license', 'readme']
dependencies = ["huggingface_hub>=0.16.4,<1.0"]

[project.urls]
@@ -57,16 +54,16 @@ target-version = ['py35']
line-length = 119
target-version = "py311"
lint.ignore = [
# a == None in tests vs is None.
"E711",
# a == False in tests vs is False.
"E712",
# try.. import except.. pattern without using the lib.
"F401",
# Raw type equality is required in asserts
"E721",
# Import order
"E402",
# Fixtures unused import
"F811",
# a == None in tests vs is None.
"E711",
# a == False in tests vs is False.
"E712",
# try.. import except.. pattern without using the lib.
"F401",
# Raw type equality is required in asserts
"E721",
# Import order
"E402",
# Fixtures unused import
"F811",
]
35 changes: 27 additions & 8 deletions bindings/python/src/normalizers.rs
@@ -41,7 +41,7 @@ impl PyNormalizedStringMut<'_> {
/// This class is not supposed to be instantiated directly. Instead, any implementation of a
/// Normalizer will return an instance of this class when instantiated.
#[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)]
#[derive(Clone, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(transparent)]
pub struct PyNormalizer {
pub(crate) normalizer: PyNormalizerTypeWrapper,
@@ -383,15 +383,14 @@ impl PySequence {
fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
match &self_.as_ref().normalizer {
PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) {
Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item)))
Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(item.clone()))
.get_as_subtype(py),
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index not found",
)),
},
PyNormalizerTypeWrapper::Single(inner) => {
PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(inner)))
.get_as_subtype(py)
PyNormalizer::new(PyNormalizerTypeWrapper::Single(inner.clone())).get_as_subtype(py)
}
}
}
@@ -600,13 +599,23 @@ impl Serialize for PyNormalizerWrapper {
}
}

#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
#[derive(Debug, Clone)]
pub(crate) enum PyNormalizerTypeWrapper {
Sequence(Vec<Arc<RwLock<PyNormalizerWrapper>>>),
Single(Arc<RwLock<PyNormalizerWrapper>>),
}

impl<'de> Deserialize<'de> for PyNormalizerTypeWrapper {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let wrapper = NormalizerWrapper::deserialize(deserializer)?;
let py_wrapper: PyNormalizerWrapper = wrapper.into();
Ok(py_wrapper.into())
}
}

impl Serialize for PyNormalizerTypeWrapper {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
@@ -638,7 +647,17 @@ where
I: Into<PyNormalizerWrapper>,
{
fn from(norm: I) -> Self {
PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm.into())))
let norm = norm.into();
match norm {
PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(seq)) => {
PyNormalizerTypeWrapper::Sequence(
seq.into_iter()
.map(|e| Arc::new(RwLock::new(PyNormalizerWrapper::Wrapped(e.clone()))))
.collect(),
)
}
_ => PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm))),
}
}
}

@@ -761,7 +780,7 @@ mod test {
match normalizer.normalizer {
PyNormalizerTypeWrapper::Single(inner) => match &*inner.as_ref().read().unwrap() {
PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(sequence)) => {
let normalizers = sequence.get_normalizers();
let normalizers = sequence.as_ref();
assert_eq!(normalizers.len(), 1);
match normalizers[0] {
NormalizerWrapper::NFKC(_) => {}
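
The functional change in this file is the handling of Sequence normalizers: deserialization now goes through NormalizerWrapper, and the From impl unrolls a wrapped Sequence into PyNormalizerTypeWrapper::Sequence instead of storing it as a single opaque normalizer. Below is a minimal, self-contained sketch of that unrolling step; Inner and TypeWrapper are simplified stand-ins for NormalizerWrapper and PyNormalizerTypeWrapper, not the real bindings types.

use std::sync::{Arc, RwLock};

// Stand-in for `NormalizerWrapper`: either a leaf normalizer or a sequence of them.
enum Inner {
    Lowercase,
    Sequence(Vec<Inner>),
}

// Stand-in for `PyNormalizerTypeWrapper`.
enum TypeWrapper {
    Single(Arc<RwLock<Inner>>),
    Sequence(Vec<Arc<RwLock<Inner>>>),
}

impl From<Inner> for TypeWrapper {
    fn from(inner: Inner) -> Self {
        match inner {
            // A sequence is unrolled so each element gets its own lock and can be
            // shared individually (e.g. when __getitem__ hands one back to Python).
            Inner::Sequence(seq) => TypeWrapper::Sequence(
                seq.into_iter().map(|e| Arc::new(RwLock::new(e))).collect(),
            ),
            other => TypeWrapper::Single(Arc::new(RwLock::new(other))),
        }
    }
}

fn main() {
    let wrapper: TypeWrapper = Inner::Sequence(vec![Inner::Lowercase]).into();
    match wrapper {
        TypeWrapper::Sequence(items) => assert_eq!(items.len(), 1),
        TypeWrapper::Single(_) => unreachable!("a sequence should have been unrolled"),
    }
}
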
47 changes: 30 additions & 17 deletions bindings/python/src/processors.rs
@@ -2,12 +2,12 @@ use std::convert::TryInto;
use std::sync::Arc;
use std::sync::RwLock;

use crate::encoding::PyEncoding;
use crate::error::ToPyResult;
use pyo3::exceptions;
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use pyo3::types::*;
use std::ops::DerefMut;
use crate::encoding::PyEncoding;
use crate::error::ToPyResult;
use serde::{Deserialize, Serialize};
use tk::processors::bert::BertProcessing;
use tk::processors::byte_level::ByteLevel;
@@ -40,7 +40,7 @@ where
{
fn from(processor: I) -> Self {
PyPostProcessor {
processor: Arc::new(RwLock::new(processor.into())), // Wrap the PostProcessorWrapper in Arc<RwLock<>>
processor: Arc::new(RwLock::new(processor.into())), // Wrap the PostProcessorWrapper in Arc<RwLock<>>
}
}
}
@@ -76,7 +76,9 @@ impl PostProcessor for PyPostProcessor {
encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> tk::Result<Vec<Encoding>> {
self.processor.read().unwrap()
self.processor
.read()
.unwrap()
.process_encodings(encodings, add_special_tokens)
}
}
@@ -474,7 +476,7 @@ impl PyTemplateProcessing {
}

#[getter]
fn get_single(self_: PyRef<Self>) -> String{
fn get_single(self_: PyRef<Self>) -> String {
getter!(self_, Template, get_single())
}

Expand All @@ -484,7 +486,7 @@ impl PyTemplateProcessing {
let super_ = self_.as_ref();
let mut wrapper = super_.processor.write().unwrap();
if let PostProcessorWrapper::Template(ref mut post) = *wrapper {
post.set_single(template.into());
post.set_single(template);
};
}
}
Expand All @@ -496,30 +498,35 @@ impl PyTemplateProcessing {
/// The processors that need to be chained
#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name = "Sequence")]
pub struct PySequence {}

#[pymethods]
impl PySequence {
#[new]
#[pyo3(signature = (processors_py), text_signature = "(self, processors)")]
fn new(processors_py: &Bound<'_, PyList>) -> (Self, PyPostProcessor) {
fn new(processors_py: &Bound<'_, PyList>) -> PyResult<(Self, PyPostProcessor)> {
let mut processors: Vec<PostProcessorWrapper> = Vec::with_capacity(processors_py.len());
for n in processors_py.iter() {
let processor: PyRef<PyPostProcessor> = n.extract().unwrap();
let processor = processor.processor.write().unwrap();
let processor: PyRef<PyPostProcessor> = n.extract()?;
let processor = processor
.processor
.write()
.map_err(|_| PyException::new_err("rwlock mutex is poisoned"))?;
processors.push(processor.clone());
}
let sequence_processor = Sequence::new(processors);
(
Ok((
PySequence {},
PyPostProcessor::new(Arc::new(RwLock::new(PostProcessorWrapper::Sequence(sequence_processor)))),
)
PyPostProcessor::new(Arc::new(RwLock::new(PostProcessorWrapper::Sequence(
sequence_processor,
)))),
))
}

fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
PyTuple::new_bound(py, [PyList::empty_bound(py)])
}

fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {

let super_ = self_.as_ref();
let mut wrapper = super_.processor.write().unwrap();
// if let PostProcessorWrapper::Sequence(ref mut post) = *wrapper {
Expand All @@ -533,7 +540,9 @@ impl PySequence {

match *wrapper {
PostProcessorWrapper::Sequence(ref mut inner) => match inner.get_mut(index) {
Some(item) => PyPostProcessor::new(Arc::new(RwLock::new(item.to_owned()))).get_as_subtype(py),
Some(item) => {
PyPostProcessor::new(Arc::new(RwLock::new(item.to_owned()))).get_as_subtype(py)
}
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index not found",
)),
Expand All @@ -544,7 +553,11 @@ impl PySequence {
}
}

fn __setitem__(self_: PyRefMut<'_, Self>, py: Python<'_>, index: usize, value: PyRef<'_, PyPostProcessor>) -> PyResult<()> {
fn __setitem__(
self_: PyRefMut<'_, Self>,
index: usize,
value: PyRef<'_, PyPostProcessor>,
) -> PyResult<()> {
let super_ = self_.as_ref();
let mut wrapper = super_.processor.write().unwrap();
let value = value.processor.read().unwrap().clone();
Expand All @@ -561,7 +574,7 @@ impl PySequence {
"Index out of bounds",
))
}
},
}
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"This processor is not a Sequence, it does not support __setitem__",
)),
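
Beyond the formatting cleanup, PySequence::new now returns PyResult and reports a poisoned RwLock as a Python exception rather than panicking through .unwrap(). A minimal sketch of that pattern, assuming only the pyo3 crate; snapshot is a hypothetical helper used purely for illustration:

use std::sync::{Arc, RwLock};

use pyo3::exceptions::PyException;
use pyo3::PyResult;

// Hypothetical helper: a poisoned lock becomes a Python exception propagated
// through `PyResult` instead of a panic inside the extension module.
fn snapshot(shared: &Arc<RwLock<Vec<String>>>) -> PyResult<Vec<String>> {
    let guard = shared
        .read()
        .map_err(|_| PyException::new_err("rwlock mutex is poisoned"))?;
    Ok((*guard).clone())
}

fn main() -> PyResult<()> {
    let shared = Arc::new(RwLock::new(vec!["bert".to_string()]));
    assert_eq!(snapshot(&shared)?.len(), 1);
    Ok(())
}
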
6 changes: 3 additions & 3 deletions tokenizers/src/normalizers/mod.rs
@@ -43,14 +43,14 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[derive(Debug, Deserialize)]
pub struct Tagged {
#[serde(rename = "type")]
variant: EnumType,
#[serde(flatten)]
rest: serde_json::Value,
}
#[derive(Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
pub enum EnumType {
Bert,
Strip,
@@ -168,7 +168,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
NormalizerUntagged::NFD(bpe) => NormalizerWrapper::NFD(bpe),
NormalizerUntagged::NFKC(bpe) => NormalizerWrapper::NFKC(bpe),
NormalizerUntagged::NFKD(bpe) => NormalizerWrapper::NFKD(bpe),
NormalizerUntagged::Sequence(bpe) => NormalizerWrapper::Sequence(bpe),
NormalizerUntagged::Sequence(seq) => NormalizerWrapper::Sequence(seq),
NormalizerUntagged::Lowercase(bpe) => NormalizerWrapper::Lowercase(bpe),
NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe),
NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe),
17 changes: 15 additions & 2 deletions tokenizers/src/normalizers/utils.rs
@@ -16,16 +16,29 @@ impl Sequence {
pub fn new(normalizers: Vec<NormalizerWrapper>) -> Self {
Self { normalizers }
}
}

pub fn get_normalizers(&self) -> &[NormalizerWrapper] {
impl AsRef<[NormalizerWrapper]> for Sequence {
fn as_ref(&self) -> &[NormalizerWrapper] {
&self.normalizers
}
}

pub fn get_normalizers_mut(&mut self) -> &mut [NormalizerWrapper] {
impl AsMut<[NormalizerWrapper]> for Sequence {
fn as_mut(&mut self) -> &mut [NormalizerWrapper] {
&mut self.normalizers
}
}

impl IntoIterator for Sequence {
type Item = NormalizerWrapper;
type IntoIter = std::vec::IntoIter<Self::Item>;

fn into_iter(self) -> Self::IntoIter {
self.normalizers.into_iter()
}
}

impl Normalizer for Sequence {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
for normalizer in &self.normalizers {
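
With these impls, callers reach the inner normalizers through the standard AsRef/AsMut/IntoIterator traits in place of the removed get_normalizers/get_normalizers_mut accessors. A short usage sketch, assuming the tokenizers crate at or after this change as a dependency (mirroring the NFKC check in the updated bindings test above):

use tokenizers::normalizers::{NormalizerWrapper, Sequence, NFKC};

fn main() {
    let seq = Sequence::new(vec![NormalizerWrapper::NFKC(NFKC)]);

    // Borrow the inner normalizers as a slice (replaces `get_normalizers()`).
    let normalizers: &[NormalizerWrapper] = seq.as_ref();
    assert_eq!(normalizers.len(), 1);

    // Consume the sequence and iterate over the owned normalizers.
    for normalizer in seq {
        match normalizer {
            NormalizerWrapper::NFKC(_) => {}
            _ => panic!("unexpected normalizer"),
        }
    }
}
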
