Merge pull request #6 from kelvich/optional_python
Do not build python bindings by default
zurawiki authored Mar 15, 2023
2 parents ad8ab4b + b149033 commit d184b69
Showing 2 changed files with 71 additions and 2 deletions.
5 changes: 4 additions & 1 deletion tiktoken-rs/Cargo.toml
@@ -19,5 +19,8 @@ bstr = "1.2.0"
 fancy-regex = "0.11.0"
 lazy_static = "1.4.0"
 parking_lot = "0.12.1"
-pyo3 = "0.18.1"
+pyo3 = { version = "0.18.1", optional = true }
 rustc-hash = "1.1.0"
+
+[features]
+python = ["dep:pyo3"] # build python bindings
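With pyo3 moved behind the new optional python feature, a default cargo build of tiktoken-rs no longer compiles the Python bindings at all. Anyone who still needs them has to opt in explicitly, for example by building with cargo build --features python or by adding features = ["python"] to the tiktoken-rs dependency entry in a downstream Cargo.toml.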
68 changes: 67 additions & 1 deletion tiktoken-rs/src/vendor_tiktoken.rs
@@ -7,11 +7,16 @@ use std::thread;
 use anyhow::anyhow;
 use anyhow::Result;
 use fancy_regex::Regex;
+use rustc_hash::FxHashMap as HashMap;
+
+#[cfg(feature = "python")]
 use pyo3::exceptions;
+#[cfg(feature = "python")]
 use pyo3::prelude::*;
+#[cfg(feature = "python")]
 use pyo3::types::{PyBytes, PyList, PyTuple};
+#[cfg(feature = "python")]
 use pyo3::PyResult;
-use rustc_hash::FxHashMap as HashMap;
 
 fn _byte_pair_merge<T>(
     piece: &[u8],
@@ -180,6 +185,7 @@ fn hash_current_thread() -> usize {
 }
 
 const MAX_NUM_THREADS: usize = 128;
+#[cfg(feature = "python")]
 #[pyclass]
 #[derive(Clone)]
 pub struct CoreBPE {
@@ -192,6 +198,19 @@ pub struct CoreBPE {
     sorted_token_bytes: Vec<Vec<u8>>,
 }
 
+#[cfg(not(feature = "python"))]
+#[derive(Clone)]
+#[allow(dead_code)] // sorted_token_bytes is used but is read only in python version
+pub struct CoreBPE {
+    encoder: HashMap<Vec<u8>, usize>,
+    special_tokens_encoder: HashMap<String, usize>,
+    decoder: HashMap<usize, Vec<u8>>,
+    special_tokens_decoder: HashMap<usize, Vec<u8>>,
+    regex_tls: Vec<Regex>,
+    special_regex_tls: Vec<Regex>,
+    sorted_token_bytes: Vec<Vec<u8>>,
+}
+
 impl CoreBPE {
     fn _get_tl_regex(&self) -> &Regex {
         // See performance notes above for what this is about
@@ -450,6 +469,51 @@ impl CoreBPE {
     // Encoding
     // ====================
 
+    // This function is a copy of the corresponding constructor in the Python API,
+    // but it returns Rust results and errors instead of Python ones.
+    #[cfg(not(feature = "python"))]
+    pub fn new(
+        encoder: HashMap<Vec<u8>, usize>,
+        special_tokens_encoder: HashMap<String, usize>,
+        pattern: &str,
+    ) -> Result<Self> {
+        let regex = Regex::new(pattern).map_err(|e| anyhow!(e.to_string()))?;
+
+        let special_regex = {
+            let _parts = special_tokens_encoder
+                .keys()
+                .map(|s| fancy_regex::escape(s))
+                .collect::<Vec<_>>();
+            Regex::new(&_parts.join("|")).map_err(|e| anyhow!(e.to_string()))?
+        };
+
+        let decoder: HashMap<usize, Vec<u8>> =
+            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
+
+        assert!(encoder.len() == decoder.len());
+
+        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
+            .iter()
+            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
+            .collect();
+
+        // Clone because I don't know how to tell Rust I'm not going to change the map
+        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
+        sorted_token_bytes.sort();
+
+        Ok(CoreBPE {
+            encoder,
+            special_tokens_encoder,
+            decoder,
+            special_tokens_decoder,
+            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
+            special_regex_tls: (0..MAX_NUM_THREADS)
+                .map(|_| special_regex.clone())
+                .collect(),
+            sorted_token_bytes,
+        })
+    }
+
     pub fn encode_ordinary(&self, text: &str) -> Vec<usize> {
         self._encode_ordinary_native(text)
     }
@@ -479,6 +543,7 @@ impl CoreBPE {
     }
 }
 
+#[cfg(feature = "python")]
 #[pymethods]
 impl CoreBPE {
     #[new]
@@ -593,6 +658,7 @@ impl CoreBPE {
     }
 }
 
+#[cfg(feature = "python")]
 #[pymodule]
 fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<CoreBPE>()?;
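The #[cfg(not(feature = "python"))] constructor introduced above gives Rust callers a way to build a CoreBPE without pyo3, reporting failures through anyhow::Result rather than a Python exception. As a rough sketch of how a dependent crate might call it in a default (no python feature) build — the tiktoken_rs::CoreBPE re-export path, toy vocabulary, regex pattern, and printed output are illustrative assumptions, not part of this commit:

use rustc_hash::FxHashMap as HashMap;
use tiktoken_rs::CoreBPE; // assumes the crate re-exports CoreBPE at its root

fn main() -> anyhow::Result<()> {
    // Toy byte-level vocabulary: every byte the pattern can emit needs a rank,
    // plus one two-byte token so a BPE merge can actually happen.
    let mut encoder: HashMap<Vec<u8>, usize> = HashMap::default();
    for (rank, byte) in (b'a'..=b'z').enumerate() {
        encoder.insert(vec![byte], rank);
    }
    encoder.insert(b"ab".to_vec(), 26);

    let mut special_tokens_encoder: HashMap<String, usize> = HashMap::default();
    special_tokens_encoder.insert("<|endoftext|>".to_string(), 27);

    // Deliberately simple split pattern; real vocabularies ship far more involved ones.
    let bpe = CoreBPE::new(encoder, special_tokens_encoder, r"[a-z]+")?;

    let tokens = bpe.encode_ordinary("abc");
    println!("{:?}", tokens); // e.g. [26, 2]: "ab" merges, then "c" stands alone
    Ok(())
}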
