Skip to content

Commit

Permalink
plrust api
Browse files Browse the repository at this point in the history
  • Loading branch information
Ngalstyan4 committed Nov 15, 2024
1 parent c07ddc4 commit 5978a66
Showing 1 changed file with 149 additions and 0 deletions.
149 changes: 149 additions & 0 deletions lantern_extras/src/bm25_api.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,156 @@
use pgrx::extension_sql_file;
use pgrx::prelude::*;
use serde::Deserialize;
use std::collections::HashMap;

use crate::bm25_agg::calculate_bm25;

extension_sql_file!("./bm25_api.sql", requires = [Bloom]);

// CREATE OR REPLACE FUNCTION bm25_word_score(
// doc_ids bigint[],
// fqs integer[],
// doc_lens integer[],
// corpus_size bigint,
// avg_doc_len real,
// term text DEFAULT NULL,
// BM25_k1 real DEFAULT 1.2,
// BM25_b real DEFAULT 0.75
// ) RETURNS TABLE (doc_id bigint, bm25 real)
// STRICT LANGUAGE plrust AS $$
/// Calculate BM25 score for a given term and return the results as a table
/// The function takes per-term statistics and global statistics (corpus_size, avg_doc_len) and calculates BM25 score for each document
/// The functurns a table
#[pg_extern(immutable, parallel_safe)]
fn bm25_word_score(
doc_ids: pgrx::Array<i64>,
fqs: pgrx::Array<i32>,
doc_lens: pgrx::Array<i32>,
corpus_size: i64,
avg_doc_len: f32,
#[allow(unused_variables)] term: Option<String>,
bm25_k1: default!(f32, 1.2),
bm25_b: default!(f32, 0.75),
) -> pgrx::iter::TableIterator<'static, (name!(doc_id, Option<i64>), name!(bm25, Option<f32>))> {
#[allow(non_snake_case)]
let BM25_k1 = bm25_k1;
#[allow(non_snake_case)]
let BM25_b = bm25_b;

use std::collections::HashMap;

let term_freq = doc_ids.len() as f32;
// Calculate BM25 score and return results using iterators
let results = doc_ids
.iter_deny_null()
.zip(fqs.iter_deny_null())
.zip(doc_lens.iter_deny_null())
.map(|((doc_id, fq), doc_len)| {
let doc_len = doc_len as f32;
let fq = fq as f32;
let bm25 = calculate_bm25(
doc_len,
fq,
term_freq,
corpus_size as u64,
avg_doc_len,
BM25_k1,
BM25_b,
);

(doc_id, bm25)
})
.fold(HashMap::<i64, f32>::new(), |mut acc, (doc_id, bm25)| {
acc.entry(doc_id)
.and_modify(|e| {
*e += bm25;
panic!("invariant volation: per-term doc_ids must be unique")
})
.or_insert(bm25);
acc
})
.into_iter()
.map(|(doc_id, bm25)| (Some(doc_id), Some(bm25)))
.collect::<Vec<_>>();
TableIterator::new(results.into_iter())
}

// CREATE OR REPLACE FUNCTION bm25_score(
// input_json JSON,
// limit_count integer,
// corpus_size bigint,
// avg_doc_len real,
// bm25_k1 real DEFAULT 1.2,
// bm25_b real DEFAULT 0.75
// ) RETURNS TABLE (doc_id bigint, bm25 real)
// STRICT LANGUAGE plrust AS $$
// [dependencies]
// serde = { version = "1.0", features = ["derive"] }
// serde_json = "1.0"
// [code]

/// Per-term statistics: one element of the JSON array that `bm25_json_agg` deserializes.
///
/// `doc_ids`, `fqs` and `doc_lens` are walked in lockstep by `bm25_json_agg`, so entry `i` of
/// each list describes the same document.
#[derive(Deserialize, Debug)]
struct InputData {
/// Ids of the documents covered by this entry; the list length is used as the term's
/// document frequency.
doc_ids: Vec<i64>,
/// Per-document value used as the term frequency in the BM25 formula.
fqs: Vec<i32>,
/// Per-document length; divided by `avg_doc_len` in the BM25 formula.
doc_lens: Vec<i32>,
/// Presumably the term these statistics belong to (TODO confirm against the SQL caller).
/// The field has no serde default, so it must be present in the JSON.
term: String,
}

/// Score documents against several terms at once and return them ranked by total BM25 score.
///
/// `input_json` must be a JSON array with one object per term, each carrying the fields of
/// `InputData` (`doc_ids`, `fqs`, `doc_lens`, `term`). A document's score is the sum of its
/// per-term BM25 scores.
///
/// # Arguments
/// * `input_json` - per-term statistics, as described above.
/// * `limit_count` - maximum number of rows to return; values `<= 0` disable the limit.
/// * `corpus_size` - corpus-wide document count used in the IDF term.
/// * `avg_doc_len` - average document length used for length normalisation.
/// * `bm25_k1`, `bm25_b` - BM25 tuning parameters (defaults `1.2` and `0.75`).
///
/// # Returns
/// `(doc_id, bm25)` rows ordered by descending score. Ties are broken by ascending `doc_id`
/// and NaN scores are placed last, so the output (and therefore the effect of `limit_count`)
/// is deterministic.
///
/// # Panics
/// Raises an error (a Rust panic inside the extension) when `input_json` does not match the
/// expected shape, or when a term's `doc_ids`, `fqs` and `doc_lens` differ in length.
#[pg_extern(immutable, parallel_safe)]
fn bm25_json_agg(
    input_json: pgrx::Json,
    limit_count: i32,
    corpus_size: i64,
    avg_doc_len: f32,
    bm25_k1: default!(f32, 1.2),
    bm25_b: default!(f32, 0.75),
) -> pgrx::iter::TableIterator<'static, (name!(doc_id, Option<i64>), name!(bm25, Option<f32>))> {
    let input: Vec<InputData> =
        serde_json::from_value(input_json.0).expect("Failed to parse input JSON");

    // Accumulated score per document across all terms.
    let mut scores: HashMap<i64, f32> = HashMap::new();

    for term_stats in &input {
        // `zip` stops at the shortest list, which would silently drop documents while
        // `doc_freq` below still counted them. Reject inconsistent input instead.
        let doc_count = term_stats.doc_ids.len();
        if term_stats.fqs.len() != doc_count || term_stats.doc_lens.len() != doc_count {
            panic!(
                "term {:?}: doc_ids, fqs and doc_lens must have the same length (got {}, {} and {})",
                term_stats.term,
                doc_count,
                term_stats.fqs.len(),
                term_stats.doc_lens.len()
            );
        }

        // doc_freq: number of docs containing the word
        let doc_freq = doc_count as f32;
        // NOTE(review): the formula is inlined here, whereas `bm25_word_score` delegates to
        // `calculate_bm25`; confirm that the two stay in agreement.
        let idf: f32 = ((corpus_size as f32 - doc_freq + 0.5) / (doc_freq + 0.5)).ln();

        let per_doc = term_stats
            .doc_ids
            .iter()
            .zip(&term_stats.fqs)
            .zip(&term_stats.doc_lens);

        for ((&doc_id, &fq), &doc_len) in per_doc {
            let fq = fq as f32;
            let length_norm = 1.0 - bm25_b + bm25_b * (doc_len as f32 / avg_doc_len);
            let bm25: f32 = idf * ((fq * (bm25_k1 + 1.0)) / (fq + bm25_k1 * length_norm));
            scores
                .entry(doc_id)
                .and_modify(|total| *total += bm25)
                .or_insert(bm25);
        }
    }

    // Rank before wrapping the values in `Option`, so the comparator needs no unwraps.
    let mut ranked: Vec<(i64, f32)> = scores.into_iter().collect();
    ranked.sort_unstable_by(|a, b| {
        b.1.partial_cmp(&a.1)
            // `partial_cmp` yields `None` only when a NaN is involved (e.g. `avg_doc_len == 0`
            // or `doc_freq > corpus_size`); order NaN after every real score instead of
            // panicking.
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
            // HashMap iteration order is arbitrary; break ties by id for a stable result.
            .then_with(|| a.0.cmp(&b.0))
    });

    if limit_count > 0 {
        // The cast is lossless: `limit_count` is strictly positive here.
        ranked.truncate(limit_count as usize);
    }

    let results: Vec<_> = ranked
        .into_iter()
        .map(|(doc_id, bm25)| (Some(doc_id), Some(bm25)))
        .collect();

    TableIterator::new(results.into_iter())
}

#[cfg(any(test, feature = "pg_test"))]
#[pgrx::pg_schema]
pub mod tests {
Expand Down

0 comments on commit 5978a66

Please sign in to comment.