Skip to content

Commit

Permalink
add grpc implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene committed Jun 28, 2024
1 parent ab7100a commit 0e2511b
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 43 deletions.
2 changes: 1 addition & 1 deletion core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl Infer {
inputs: I,
add_special_tokens: bool,
prompt_name: Option<String>,
) -> Result<RawEncoding, TextEmbeddingsError> {
) -> Result<(Option<String>, RawEncoding), TextEmbeddingsError> {
self.tokenization
.tokenize(inputs.into(), add_special_tokens, prompt_name)
.await
Expand Down
83 changes: 56 additions & 27 deletions core/src/tokenization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl Tokenization {
tokenizer: Tokenizer,
max_input_length: usize,
position_offset: usize,
default_prompt_name: Option<String>,
default_prompt: Option<String>,
prompts: Option<HashMap<String, String>>,
) -> Self {
tracing::info!("Starting {workers} tokenization workers");
Expand All @@ -32,15 +32,15 @@ impl Tokenization {
for _ in 0..workers {
let tokenizer_clone = tokenizer.clone();
let receiver_clone = receiver.clone();
let default_prompt_name_clone = default_prompt_name.clone();
let default_prompt_clone = default_prompt.clone();
let prompts_clone = prompts.clone();
// Spawn worker
std::thread::spawn(move || {
tokenizer_worker(
tokenizer_clone,
max_input_length,
position_offset,
default_prompt_name_clone,
default_prompt_clone,
prompts_clone,
receiver_clone,
)
Expand Down Expand Up @@ -92,7 +92,7 @@ impl Tokenization {
inputs: EncodingInput,
add_special_tokens: bool,
prompt_name: Option<String>,
) -> Result<RawEncoding, TextEmbeddingsError> {
) -> Result<(Option<String>, RawEncoding), TextEmbeddingsError> {
// Check if inputs is empty
if inputs.is_empty() {
return Err(TextEmbeddingsError::Validation(
Expand Down Expand Up @@ -158,7 +158,7 @@ fn tokenizer_worker(
mut tokenizer: Tokenizer,
max_input_length: usize,
position_offset: usize,
default_prompt_name: Option<String>,
default_prompt: Option<String>,
prompts: Option<HashMap<String, String>>,
receiver: async_channel::Receiver<TokenizerRequest>,
) {
Expand All @@ -174,9 +174,12 @@ fn tokenizer_worker(
parent_span,
) => {
parent_span.in_scope(|| {
let prompt_name = prompt_name.or(default_prompt_name.clone());

if !response_tx.is_closed() {
let default_prompt_clone = match prompt_name {
None => default_prompt.clone(),
Some(_) => None,
};

// It's possible that the user dropped its request resulting in a send error.
// We just discard the error
let _ = response_tx.send(encode_input(
Expand All @@ -185,6 +188,7 @@ fn tokenizer_worker(
truncation_direction,
max_input_length,
position_offset,
default_prompt_clone,
prompt_name,
prompts.as_ref(),
&mut tokenizer,
Expand All @@ -199,16 +203,20 @@ fn tokenizer_worker(
response_tx,
parent_span,
) => {
let prompt_name = prompt_name.or(default_prompt_name.clone());

parent_span.in_scope(|| {
if !response_tx.is_closed() {
let default_prompt_clone = match prompt_name {
None => default_prompt.clone(),
Some(_) => None,
};

// It's possible that the user dropped its request resulting in a send error.
// We just discard the error
let _ = response_tx.send(tokenize_input(
inputs,
add_special_tokens,
None,
default_prompt_clone,
prompt_name,
prompts.as_ref(),
&mut tokenizer,
Expand Down Expand Up @@ -240,14 +248,11 @@ fn decode_ids(
.decode(&ids, skip_special_tokens)?)
}

fn tokenize_input(
inputs: EncodingInput,
add_special_tokens: bool,
truncate_params: Option<TruncationParams>,
fn prepare_pre_prompt(
default_prompt: Option<String>,
prompt_name: Option<String>,
prompts: Option<&HashMap<String, String>>,
tokenizer: &mut Tokenizer,
) -> Result<RawEncoding, TextEmbeddingsError> {
) -> Result<Option<String>, TextEmbeddingsError> {
let pre_prompt = if let Some(prompt_name) = prompt_name.as_ref() {
match prompts {
None => {
Expand All @@ -259,8 +264,21 @@ fn tokenize_input(
Some(prompts) => prompts.get(prompt_name).cloned(),
}
} else {
None
default_prompt
};
Ok(pre_prompt)
}

fn tokenize_input(
inputs: EncodingInput,
add_special_tokens: bool,
truncate_params: Option<TruncationParams>,
default_prompt: Option<String>,
prompt_name: Option<String>,
prompts: Option<&HashMap<String, String>>,
tokenizer: &mut Tokenizer,
) -> Result<(Option<String>, RawEncoding), TextEmbeddingsError> {
let pre_prompt = prepare_pre_prompt(default_prompt, prompt_name, prompts)?;

let encoding = match inputs {
// encode input
Expand All @@ -272,9 +290,11 @@ fn tokenize_input(
s
};

tokenizer
let encoding = tokenizer
.with_truncation(truncate_params)?
.encode::<String>(s, add_special_tokens)?
.encode::<&str>(&s, add_special_tokens)?;

(Some(s), encoding)
}
EncodingInput::Dual(s1, s2) => {
if pre_prompt.is_some() {
Expand All @@ -283,25 +303,32 @@ fn tokenize_input(
));
}

tokenizer
.with_truncation(truncate_params)?
.encode::<(String, String)>((s1, s2), add_special_tokens)?
(
None,
tokenizer
.with_truncation(truncate_params)?
.encode::<(String, String)>((s1, s2), add_special_tokens)?,
)
}
// input is encoded -> convert to tokenizers Encoding
EncodingInput::Ids(ids) => {
if let Some(mut pre_prompt) = pre_prompt {
let text = tokenizer.decode(&ids, true)?;
pre_prompt.push_str(&text);

tokenizer
let encoding = tokenizer
.with_truncation(truncate_params)?
.encode::<String>(pre_prompt, false)?
.encode::<&str>(&pre_prompt, true)?;

(Some(pre_prompt), encoding)
} else {
let text = tokenizer.decode(&ids, false)?;

tokenizer
let encoding = tokenizer
.with_truncation(truncate_params)?
.encode::<String>(text, false)?
.encode::<&str>(&text, false)?;

(Some(text), encoding)
}
}
};
Expand All @@ -316,6 +343,7 @@ fn encode_input(
truncation_direction: TruncationDirection,
max_input_length: usize,
position_offset: usize,
default_prompt: Option<String>,
prompt_name: Option<String>,
prompts: Option<&HashMap<String, String>>,
tokenizer: &mut Tokenizer,
Expand All @@ -328,10 +356,11 @@ fn encode_input(
stride: 0,
});

let encoding = tokenize_input(
let (_, encoding) = tokenize_input(
inputs,
true,
truncate_params,
default_prompt,
prompt_name,
prompts,
tokenizer,
Expand Down Expand Up @@ -402,7 +431,7 @@ enum TokenizerRequest {
EncodingInput,
bool,
Option<String>,
oneshot::Sender<Result<RawEncoding, TextEmbeddingsError>>,
oneshot::Sender<Result<(Option<String>, RawEncoding), TextEmbeddingsError>>,
Span,
),
Decode(
Expand Down
4 changes: 4 additions & 0 deletions proto/tei.proto
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ message EmbedRequest {
bool truncate = 2;
bool normalize = 3;
TruncationDirection truncation_direction = 4;
optional string prompt_name = 5;
}

message EmbedResponse {
Expand All @@ -90,6 +91,7 @@ message EmbedSparseRequest {
string inputs = 1;
bool truncate = 2;
TruncationDirection truncation_direction = 3;
optional string prompt_name = 4;
}

message SparseValue {
Expand All @@ -106,6 +108,7 @@ message EmbedAllRequest {
string inputs = 1;
bool truncate = 2;
TruncationDirection truncation_direction = 3;
optional string prompt_name = 4;
}

message TokenEmbedding {
Expand Down Expand Up @@ -175,6 +178,7 @@ message RerankResponse {
message EncodeRequest {
string inputs = 1;
bool add_special_tokens = 2;
optional string prompt_name = 3;
}

message SimpleToken {
Expand Down
13 changes: 11 additions & 2 deletions router/src/grpc/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ impl TextEmbeddingsService {
request.inputs,
request.truncate,
truncation_direction,
request.prompt_name,
request.normalize,
permit,
)
Expand Down Expand Up @@ -142,6 +143,7 @@ impl TextEmbeddingsService {
request.inputs,
request.truncate,
truncation_direction,
request.prompt_name,
permit,
)
.await
Expand Down Expand Up @@ -207,6 +209,7 @@ impl TextEmbeddingsService {
request.inputs,
request.truncate,
truncation_direction,
request.prompt_name,
permit,
)
.await
Expand Down Expand Up @@ -326,11 +329,17 @@ impl TextEmbeddingsService {
#[instrument(skip_all)]
async fn tokenize_inner(&self, request: EncodeRequest) -> Result<EncodeResponse, Status> {
let inputs = request.inputs;
let encoding = self
let (encoded_inputs, encoding) = self
.infer
.tokenize(inputs.clone(), request.add_special_tokens)
.tokenize(
inputs.clone(),
request.add_special_tokens,
request.prompt_name,
)
.await
.map_err(ErrorResponse::from)?;
let inputs = encoded_inputs.unwrap_or(inputs);

let tokens: Vec<SimpleToken> = encoding
.get_ids()
.iter()
Expand Down
4 changes: 3 additions & 1 deletion router/src/http/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1195,10 +1195,12 @@ async fn tokenize(
add_special_tokens: bool,
prompt_name: Option<String>,
infer: Infer| async move {
let encoding = infer
let (encoded_input, encoding) = infer
.tokenize(input.clone(), add_special_tokens, prompt_name)
.await
.map_err(ErrorResponse::from)?;
let input = encoded_input.unwrap_or(input);

let tokens: Vec<SimpleToken> = encoding
.get_ids()
.iter()
Expand Down
15 changes: 9 additions & 6 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ pub async fn run(
max_batch_requests: Option<usize>,
max_client_batch_size: usize,
auto_truncate: bool,
default_prompt: Option<String>,
default_prompt_name: Option<String>,
hf_api_token: Option<String>,
hostname: Option<String>,
Expand Down Expand Up @@ -184,26 +185,28 @@ pub async fn run(
.context("Failed to parse `config_sentence_transformers.json`")?,
);
}
let prompts = new_st_config.map(|c| c.prompts);
if let Some(default_prompt_name) = default_prompt_name.as_ref() {
let prompts = new_st_config.and_then(|c| c.prompts);
let default_prompt = if let Some(default_prompt_name) = default_prompt_name.as_ref() {
match &prompts {
None => {
anyhow::bail!(format!("`default-prompt-name` is set to `{default_prompt_name}` but no prompts were found in the Sentence Transformers configuration"));
}
Some(prompts) if !prompts.contains_key(default_prompt_name) => {
anyhow::bail!(format!("`default-prompt-name` is set to `{default_prompt_name}` but it was not found in the Sentence Transformers prompts. Available prompts: {:?}", prompts.keys()));
}
_ => (),
Some(prompts) => prompts.get(default_prompt_name).cloned(),
}
}
} else {
default_prompt
};

// Tokenization logic
let tokenization = Tokenization::new(
tokenization_workers,
tokenizer,
max_input_length,
position_offset,
default_prompt_name,
default_prompt,
prompts,
);

Expand Down Expand Up @@ -420,7 +423,7 @@ pub struct STConfig {

#[derive(Debug, Deserialize)]
pub struct NewSTConfig {
pub prompts: HashMap<String, String>,
pub prompts: Option<HashMap<String, String>>,
}

#[derive(Clone, Debug, Serialize)]
Expand Down
Loading

0 comments on commit 0e2511b

Please sign in to comment.