diff --git a/rig-surrealdb/examples/migrations.surql b/rig-surrealdb/examples/migrations.surql index 1f3d5190..8651f013 100644 --- a/rig-surrealdb/examples/migrations.surql +++ b/rig-surrealdb/examples/migrations.surql @@ -1,7 +1,6 @@ -- define table & fields DEFINE TABLE documents SCHEMAFULL; -DEFINE field docid on table documents type string; -DEFINE field document on table documents type string; +DEFINE field document on table documents type object; DEFINE field embedding on table documents type array; DEFINE field embedded_text on table documents type string; diff --git a/rig-surrealdb/examples/vector_search_surreal.rs b/rig-surrealdb/examples/vector_search_surreal.rs index 120c2c12..5300d869 100644 --- a/rig-surrealdb/examples/vector_search_surreal.rs +++ b/rig-surrealdb/examples/vector_search_surreal.rs @@ -11,7 +11,7 @@ struct WordDefinition { word: String, #[serde(skip)] // we don't want to serialize this field, we use only to create embeddings #[embed] - definitions: Vec, + definition: String, } impl std::fmt::Display for WordDefinition { @@ -49,26 +49,15 @@ async fn main() -> Result<(), anyhow::Error> { let words = vec![ WordDefinition { word: "flurbo".to_string(), - definitions: vec![ - "1. *flurbo* (name): A flurbo is a green alien that lives on cold planets.".to_string(), - "2. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string() - ] + definition: "1. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string() }, WordDefinition { - word: "glarb-glarb".to_string(), - definitions: vec![ - "1. *glarb-glarb* (noun): A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - "2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() - ] + definition: "1. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() }, WordDefinition { - word: "linglingdong".to_string(), - definitions: vec![ - "1. *linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(), - "2. *linglingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() - ] + definition: "1. *linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(), }]; let documents = EmbeddingsBuilder::new(model.clone()) @@ -80,6 +69,7 @@ async fn main() -> Result<(), anyhow::Error> { // init vector store let vector_store = SurrealVectorStore::with_defaults(model, surreal); + vector_store.insert_documents(documents).await?; // query vector @@ -91,9 +81,9 @@ async fn main() -> Result<(), anyhow::Error> { for (distance, _id, doc) in results.iter() { println!("Result distance {} for word: {}", distance, doc); - // expected output (even if we have 2 entries on glarb-glarb the index only gives closest match) - // Result distance 0.2988549857990437 for word: glarb-glarb - //Result distance 0.7072261746390949 for word: linglingdong + // expected output + // Result distance 0.693218142100547 for word: glarb-glarb + // Result distance 0.2529120980283861 for word: linglingdong } Ok(()) diff --git a/rig-surrealdb/src/lib.rs b/rig-surrealdb/src/lib.rs index ae017f1f..dfa007ca 100644 --- a/rig-surrealdb/src/lib.rs +++ b/rig-surrealdb/src/lib.rs @@ -6,8 +6,7 @@ use rig::{ Embed, OneOrMany, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use surrealdb::{Connection, Surreal}; -use uuid::Uuid; +use surrealdb::{sql::Thing, Connection, Surreal}; pub use surrealdb::engine::remote::ws::{Ws, Wss}; @@ -41,14 +40,13 @@ impl Display for SurrealDistanceFunction { #[derive(Debug, Deserialize)] struct SearchResult { - docid: String, + id: Thing, document: String, distance: f64, } #[derive(Debug, Serialize, Deserialize)] pub struct CreateRecord { - docid: String, document: String, embedded_text: String, embedding: Vec, @@ -56,7 +54,7 @@ pub struct CreateRecord { #[derive(Debug, Deserialize)] pub struct SearchResultOnlyId { - docid: String, + id: Thing, distance: f64, } @@ -65,7 +63,7 @@ impl SearchResult { let document: T = serde_json::from_str(&self.document).map_err(VectorStoreError::JsonError)?; - Ok((self.distance, self.docid.to_string(), document)) + Ok((self.distance, self.id.id.to_string(), document)) } } @@ -84,6 +82,10 @@ impl SurrealVectorStore { } } + pub fn inner_client(&self) -> &Surreal { + &self.surreal + } + pub fn with_defaults(model: Model, surreal: Surreal) -> Self { Self::new(model, surreal, None, SurrealDistanceFunction::Cosine) } @@ -106,10 +108,8 @@ impl SurrealVectorStore { } = self; format!( " - SELECT docid {document} {embedded_text}, {distance_function}($vec, embedding) as distance \ - FROM {documents_table} \ - GROUP BY docid \ - ORDER BY distance desc \ + SELECT id {document} {embedded_text}, {distance_function}($vec, embedding) as distance \ + from {documents_table} order by distance desc \ LIMIT $limit", ) } @@ -121,14 +121,12 @@ impl SurrealVectorStore { for (document, embeddings) in documents { let json_document: serde_json::Value = serde_json::to_value(&document).unwrap(); let json_document_as_string = serde_json::to_string(&json_document).unwrap(); - let docid = Uuid::new_v4().to_string(); for embedding in embeddings { let embedded_text = embedding.document; let embedding: Vec = embedding.vec; let record = CreateRecord { - docid: docid.clone(), document: json_document_as_string.clone(), embedded_text, embedding, @@ -204,7 +202,7 @@ impl VectorStoreIndex for SurrealVectorSto .take::>(0) .unwrap() .into_iter() - .map(|row| (row.distance, row.docid.to_string())) + .map(|row| (row.distance, row.id.id.to_string())) .collect(); Ok(rows)