Skip to content

Commit

Permalink
refactor: simplify integration
Browse files Browse the repository at this point in the history
  • Loading branch information
joshua-mo-143 committed Feb 7, 2025
1 parent 7616b6d commit 4e9acd9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 33 deletions.
3 changes: 1 addition & 2 deletions rig-surrealdb/examples/migrations.surql
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
-- define table & fields
DEFINE TABLE documents SCHEMAFULL;
DEFINE field docid on table documents type string;
DEFINE field document on table documents type string;
DEFINE field document on table documents type object;
DEFINE field embedding on table documents type array<float>;
DEFINE field embedded_text on table documents type string;

Expand Down
26 changes: 8 additions & 18 deletions rig-surrealdb/examples/vector_search_surreal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct WordDefinition {
word: String,
#[serde(skip)] // we don't want to serialize this field; we use it only to create embeddings
#[embed]
definitions: Vec<String>,
definition: String,
}

impl std::fmt::Display for WordDefinition {
Expand Down Expand Up @@ -49,26 +49,15 @@ async fn main() -> Result<(), anyhow::Error> {
let words = vec![
WordDefinition {
word: "flurbo".to_string(),
definitions: vec![
"1. *flurbo* (name): A flurbo is a green alien that lives on cold planets.".to_string(),
"2. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
]
definition: "1. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string()
},
WordDefinition {

word: "glarb-glarb".to_string(),
definitions: vec![
"1. *glarb-glarb* (noun): A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
"2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
]
definition: "1. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string()
},
WordDefinition {

word: "linglingdong".to_string(),
definitions: vec![
"1. *linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
"2. *linglingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string()
]
definition: "1. *linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
}];

let documents = EmbeddingsBuilder::new(model.clone())
Expand All @@ -80,6 +69,7 @@ async fn main() -> Result<(), anyhow::Error> {

// init vector store
let vector_store = SurrealVectorStore::with_defaults(model, surreal);

vector_store.insert_documents(documents).await?;

// query vector
Expand All @@ -91,9 +81,9 @@ async fn main() -> Result<(), anyhow::Error> {
for (distance, _id, doc) in results.iter() {
println!("Result distance {} for word: {}", distance, doc);

// expected output (even if we have 2 entries on glarb-glarb the index only gives closest match)
// Result distance 0.2988549857990437 for word: glarb-glarb
//Result distance 0.7072261746390949 for word: linglingdong
// expected output
// Result distance 0.693218142100547 for word: glarb-glarb
// Result distance 0.2529120980283861 for word: linglingdong
}

Ok(())
Expand Down
24 changes: 11 additions & 13 deletions rig-surrealdb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ use rig::{
Embed, OneOrMany,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use surrealdb::{Connection, Surreal};
use uuid::Uuid;
use surrealdb::{sql::Thing, Connection, Surreal};

pub use surrealdb::engine::remote::ws::{Ws, Wss};

Expand Down Expand Up @@ -41,22 +40,21 @@ impl Display for SurrealDistanceFunction {

#[derive(Debug, Deserialize)]
struct SearchResult {
docid: String,
id: Thing,
document: String,
distance: f64,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct CreateRecord {
docid: String,
document: String,
embedded_text: String,
embedding: Vec<f64>,
}

#[derive(Debug, Deserialize)]
pub struct SearchResultOnlyId {
docid: String,
id: Thing,
distance: f64,
}

Expand All @@ -65,7 +63,7 @@ impl SearchResult {
let document: T =
serde_json::from_str(&self.document).map_err(VectorStoreError::JsonError)?;

Ok((self.distance, self.docid.to_string(), document))
Ok((self.distance, self.id.id.to_string(), document))
}
}

Expand All @@ -84,6 +82,10 @@ impl<Model: EmbeddingModel, C: Connection> SurrealVectorStore<Model, C> {
}
}

/// Returns a shared reference to the `Surreal<C>` client stored in this vector store.
///
/// The store keeps ownership of the client; the returned borrow lives no longer
/// than the `&self` borrow it was obtained from.
pub fn inner_client(&self) -> &Surreal<C> {
&self.surreal
}

/// Builds a store from `model` and `surreal` by delegating to `Self::new`.
///
/// `None` is passed as the third argument and `SurrealDistanceFunction::Cosine`
/// as the fourth.
pub fn with_defaults(model: Model, surreal: Surreal<C>) -> Self {
// NOTE(review): `new` is not visible here; the `None` presumably selects a
// default documents table name — confirm against `new`'s signature.
Self::new(model, surreal, None, SurrealDistanceFunction::Cosine)
}
Expand All @@ -106,10 +108,8 @@ impl<Model: EmbeddingModel, C: Connection> SurrealVectorStore<Model, C> {
} = self;
format!(
"
SELECT docid {document} {embedded_text}, {distance_function}($vec, embedding) as distance \
FROM {documents_table} \
GROUP BY docid \
ORDER BY distance desc \
SELECT id {document} {embedded_text}, {distance_function}($vec, embedding) as distance \
from {documents_table} order by distance desc \
LIMIT $limit",
)
}
Expand All @@ -121,14 +121,12 @@ impl<Model: EmbeddingModel, C: Connection> SurrealVectorStore<Model, C> {
for (document, embeddings) in documents {
let json_document: serde_json::Value = serde_json::to_value(&document).unwrap();
let json_document_as_string = serde_json::to_string(&json_document).unwrap();
let docid = Uuid::new_v4().to_string();

for embedding in embeddings {
let embedded_text = embedding.document;
let embedding: Vec<f64> = embedding.vec;

let record = CreateRecord {
docid: docid.clone(),
document: json_document_as_string.clone(),
embedded_text,
embedding,
Expand Down Expand Up @@ -204,7 +202,7 @@ impl<Model: EmbeddingModel, C: Connection> VectorStoreIndex for SurrealVectorSto
.take::<Vec<SearchResultOnlyId>>(0)
.unwrap()
.into_iter()
.map(|row| (row.distance, row.docid.to_string()))
.map(|row| (row.distance, row.id.id.to_string()))
.collect();

Ok(rows)
Expand Down

0 comments on commit 4e9acd9

Please sign in to comment.