From 5978a66fde06ed967d252c4061ac9a4334a99259 Mon Sep 17 00:00:00 2001 From: Narek Galstyan Date: Thu, 14 Nov 2024 20:19:26 -0800 Subject: [PATCH] plrust api --- lantern_extras/src/bm25_api.rs | 149 +++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) diff --git a/lantern_extras/src/bm25_api.rs b/lantern_extras/src/bm25_api.rs index 5abf65f5..628096aa 100644 --- a/lantern_extras/src/bm25_api.rs +++ b/lantern_extras/src/bm25_api.rs @@ -1,7 +1,156 @@ use pgrx::extension_sql_file; +use pgrx::prelude::*; +use serde::Deserialize; +use std::collections::HashMap; + +use crate::bm25_agg::calculate_bm25; extension_sql_file!("./bm25_api.sql", requires = [Bloom]); +// CREATE OR REPLACE FUNCTION bm25_word_score( +// doc_ids bigint[], +// fqs integer[], +// doc_lens integer[], +// corpus_size bigint, +// avg_doc_len real, +// term text DEFAULT NULL, +// BM25_k1 real DEFAULT 1.2, +// BM25_b real DEFAULT 0.75 +// ) RETURNS TABLE (doc_id bigint, bm25 real) +// STRICT LANGUAGE plrust AS $$ +/// Calculate BM25 score for a given term and return the results as a table +/// The function takes per-term statistics and global statistics (corpus_size, avg_doc_len) and calculates BM25 score for each document +/// The functurns a table +#[pg_extern(immutable, parallel_safe)] +fn bm25_word_score( + doc_ids: pgrx::Array, + fqs: pgrx::Array, + doc_lens: pgrx::Array, + corpus_size: i64, + avg_doc_len: f32, + #[allow(unused_variables)] term: Option, + bm25_k1: default!(f32, 1.2), + bm25_b: default!(f32, 0.75), +) -> pgrx::iter::TableIterator<'static, (name!(doc_id, Option), name!(bm25, Option))> { + #[allow(non_snake_case)] + let BM25_k1 = bm25_k1; + #[allow(non_snake_case)] + let BM25_b = bm25_b; + + use std::collections::HashMap; + + let term_freq = doc_ids.len() as f32; + // Calculate BM25 score and return results using iterators + let results = doc_ids + .iter_deny_null() + .zip(fqs.iter_deny_null()) + .zip(doc_lens.iter_deny_null()) + .map(|((doc_id, fq), doc_len)| { + let doc_len = doc_len as f32; + let fq = fq as f32; + let bm25 = calculate_bm25( + doc_len, + fq, + term_freq, + corpus_size as u64, + avg_doc_len, + BM25_k1, + BM25_b, + ); + + (doc_id, bm25) + }) + .fold(HashMap::::new(), |mut acc, (doc_id, bm25)| { + acc.entry(doc_id) + .and_modify(|e| { + *e += bm25; + panic!("invariant volation: per-term doc_ids must be unique") + }) + .or_insert(bm25); + acc + }) + .into_iter() + .map(|(doc_id, bm25)| (Some(doc_id), Some(bm25))) + .collect::>(); + TableIterator::new(results.into_iter()) +} + +// CREATE OR REPLACE FUNCTION bm25_score( +// input_json JSON, +// limit_count integer, +// corpus_size bigint, +// avg_doc_len real, +// bm25_k1 real DEFAULT 1.2, +// bm25_b real DEFAULT 0.75 +// ) RETURNS TABLE (doc_id bigint, bm25 real) +// STRICT LANGUAGE plrust AS $$ +// [dependencies] +// serde = { version = "1.0", features = ["derive"] } +// serde_json = "1.0" +// [code] + +#[derive(Deserialize, Debug)] +struct InputData { + doc_ids: Vec, + fqs: Vec, + doc_lens: Vec, + term: String, +} + +#[pg_extern(immutable, parallel_safe)] +fn bm25_json_agg( + input_json: pgrx::Json, + limit_count: i32, + corpus_size: i64, + avg_doc_len: f32, + bm25_k1: default!(f32, 1.2), + bm25_b: default!(f32, 0.75), +) -> pgrx::iter::TableIterator<'static, (name!(doc_id, Option), name!(bm25, Option))> { + let BM25_k1 = bm25_k1; + let BM25_b = bm25_b; + + let input: Vec = + serde_json::from_value(input_json.0).expect("Failed to parse input JSON"); + + // Calculate BM25 score and return results using iterators + let mut results: HashMap = HashMap::new(); + + for input_data in input.iter() { + // doc_freq: number of docs containing the word + let doc_freq = input_data.doc_ids.len() as f32; + let idf: f32 = ((corpus_size as f32 - doc_freq + 0.5) / (doc_freq + 0.5)).ln(); // IDF calculation + + for (&doc_id, (&fq, &doc_len)) in input_data + .doc_ids + .iter() + .zip(input_data.fqs.iter().zip(input_data.doc_lens.iter())) + { + let doc_len = doc_len as f32; + let fq = fq as f32; + let bm25: f32 = idf + * ((fq * (BM25_k1 + 1.0)) + / (fq + BM25_k1 * (1.0 - BM25_b + BM25_b * (doc_len / avg_doc_len)))); + results + .entry(doc_id) + .and_modify(|e| *e += bm25) + .or_insert(bm25); + } + } + + let mut results: Vec<_> = results + .into_iter() + .map(|(doc_id, bm25)| (Some(doc_id), Some(bm25))) + .collect(); + + results.sort_unstable_by(|a, b| b.1.unwrap().partial_cmp(&a.1.unwrap()).unwrap()); + + if limit_count > 0 { + results.truncate(limit_count as usize); + } + + TableIterator::new(results.into_iter()) +} + #[cfg(any(test, feature = "pg_test"))] #[pgrx::pg_schema] pub mod tests {