Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: cleanup python backend #294

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

78 changes: 73 additions & 5 deletions backends/grpc-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
/// Single shard Client
use crate::pb::embedding::v1::embedding_service_client::EmbeddingServiceClient;
use crate::pb::embedding::v1::*;
use crate::Result;
use crate::{ClientError, Result};
use grpc_metadata::InjectTelemetryContext;
/// Single shard Client
use tokio::runtime::Runtime;
use tonic::transport::{Channel, Uri};
use tracing::instrument;

/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
pub struct AsyncClient {
stub: EmbeddingServiceClient<Channel>,
}

impl Client {
impl AsyncClient {
/// Returns a client connected to the given url
pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?;
Expand All @@ -23,7 +24,8 @@ impl Client {
}

/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
pub async fn connect_uds(path: &str) -> Result<Self> {
let path = path.to_owned();
let channel = Channel::from_shared("http://[::]:50051".to_string())
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
Expand Down Expand Up @@ -65,3 +67,69 @@ impl Client {
Ok(response.embeddings)
}
}

/// Blocking wrapper around [`AsyncClient`].
///
/// Owns a private Tokio runtime so callers can use the gRPC client from
/// synchronous code; each public method drives the async call to completion
/// with `block_on` (see the `impl Client` block).
#[derive(Debug)]
pub struct Client {
    // Underlying async gRPC client; cloned for each call in `health`/`embed`.
    async_client: AsyncClient,
    // Current-thread runtime used to `block_on` the async client's futures.
    runtime: Runtime,
}

impl Client {
/// Returns a client connected to the given url
pub fn connect(uri: Uri) -> Result<Self> {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|err| {
ClientError::Connection(format!("Could not start Tokio runtime: {err}"))
})?;

let async_client = runtime.block_on(AsyncClient::connect(uri))?;

Ok(Self {
async_client,
runtime,
})
}

/// Returns a client connected to the given unix socket
pub fn connect_uds(path: &str) -> Result<Self> {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|err| {
ClientError::Connection(format!("Could not start Tokio runtime: {err}"))
})?;

let async_client = runtime.block_on(AsyncClient::connect_uds(path))?;

Ok(Self {
async_client,
runtime,
})
}

/// Get backend health
#[instrument(skip(self))]
pub fn health(&self) -> Result<HealthResponse> {
self.runtime.block_on(self.async_client.clone().health())
}

#[instrument(skip_all)]
pub fn embed(
&self,
input_ids: Vec<u32>,
token_type_ids: Vec<u32>,
position_ids: Vec<u32>,
cu_seq_lengths: Vec<u32>,
max_length: u32,
) -> Result<Vec<Embedding>> {
self.runtime.block_on(self.async_client.clone().embed(
input_ids,
token_type_ids,
position_ids,
cu_seq_lengths,
max_length,
))
}
}
4 changes: 1 addition & 3 deletions backends/grpc-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ impl From<Status> for ClientError {

impl From<transport::Error> for ClientError {
fn from(err: transport::Error) -> Self {
let err = Self::Connection(err.to_string());
tracing::error!("{err}");
err
Self::Connection(err.to_string())
}
}

Expand Down
1 change: 0 additions & 1 deletion backends/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ serde = { version = "^1.0", features = ["derive"] }
serde_json = "^1.0"
text-embeddings-backend-core = { path = "../core" }
thiserror = "^1.0"
tokio = { version = "^1.25", features = ["sync"] }
tracing = "^0.1"
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:

output = self.model(**kwargs)
embedding = output[0][:, 0]
cpu_results = embedding.view(-1).tolist()
cpu_results = embedding.reshape(-1).tolist()

return [
Embedding(
Expand Down
44 changes: 9 additions & 35 deletions backends/python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
mod logging;
mod management;

use backend_grpc_client::Client;
use nohash_hasher::BuildNoHashHasher;
use std::collections::HashMap;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
};
use tokio::runtime::Runtime;

pub struct PythonBackend {
_backend_process: management::BackendProcess,
tokio_runtime: Runtime,
backend_client: Client,
backend_process: management::BackendProcess,
}

impl PythonBackend {
Expand All @@ -38,39 +34,16 @@ impl PythonBackend {
}
};

let backend_process = management::BackendProcess::new(
model_path,
dtype,
&uds_path,
otlp_endpoint,
otlp_service_name,
)?;
let tokio_runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?;
let backend_process =
management::BackendProcess::new(model_path, dtype, uds_path, otlp_endpoint, otlp_service_name)?;

let backend_client = tokio_runtime
.block_on(Client::connect_uds(uds_path))
.map_err(|err| {
BackendError::Start(format!("Could not connect to backend process: {err}"))
})?;

Ok(Self {
_backend_process: backend_process,
tokio_runtime,
backend_client,
})
Ok(Self { backend_process })
}
}

impl Backend for PythonBackend {
fn health(&self) -> Result<(), BackendError> {
if self
.tokio_runtime
.block_on(self.backend_client.clone().health())
.is_err()
{
if self.backend_process.client.health().is_err() {
return Err(BackendError::Unhealthy);
}
Ok(())
Expand All @@ -89,14 +62,15 @@ impl Backend for PythonBackend {
let batch_size = batch.len();

let results = self
.tokio_runtime
.block_on(self.backend_client.clone().embed(
.backend_process
.client
.embed(
batch.input_ids,
batch.token_type_ids,
batch.position_ids,
batch.cumulative_seq_lengths,
batch.max_length,
))
)
.map_err(|err| BackendError::Inference(err.to_string()))?;
let pooled_embeddings: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();

Expand Down
32 changes: 19 additions & 13 deletions backends/python/src/management.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::logging::log_lines;
use backend_grpc_client::Client;
use std::ffi::OsString;
use std::io::{BufRead, BufReader};
use std::os::unix::process::{CommandExt, ExitStatusExt};
Expand All @@ -13,18 +14,19 @@ use text_embeddings_backend_core::BackendError;
#[derive(Debug)]
pub(crate) struct BackendProcess {
inner: Child,
pub client: Client,
}

impl BackendProcess {
pub(crate) fn new(
model_path: String,
dtype: String,
uds_path: &str,
uds_path: String,
otlp_endpoint: Option<String>,
otlp_service_name: String,
) -> Result<Self, BackendError> {
// Get UDS path
let uds = Path::new(uds_path);
let uds = Path::new(&uds_path);

// Clean previous runs
if uds.exists() {
Expand Down Expand Up @@ -87,7 +89,7 @@ impl BackendProcess {
let start_time = Instant::now();
let mut wait_time = Instant::now();

loop {
let client = loop {
// Process exited
if let Some(exit_status) = p.try_wait().unwrap() {
// We read stderr in another thread as it seems that lines() can block in some cases
Expand All @@ -114,18 +116,22 @@ impl BackendProcess {
));
}

// Shard is ready
if uds.exists() {
tracing::info!("Python backend ready in {:?}", start_time.elapsed());
break;
} else if wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for Python backend to be ready...");
wait_time = Instant::now();
}
match Client::connect_uds(&uds_path) {
Ok(client) => {
tracing::info!("Python backend ready in {:?}", start_time.elapsed());
break client;
}
Err(_) if wait_time.elapsed() > Duration::from_secs(10) => {
tracing::info!("Waiting for Python backend to be ready...");
wait_time = Instant::now();
}
_ => {}
};

sleep(Duration::from_millis(5));
}
};

Ok(Self { inner: p })
Ok(Self { inner: p, client })
}
}

Expand Down
Loading