Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implemented L1 distance & fixed L2 in binary quantization #21

Merged
merged 14 commits
Nov 27, 2023
146 changes: 140 additions & 6 deletions demos/benches/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ use quantization::encoded_vectors_u8::EncodedVectorsU8;
use rand::Rng;

#[cfg(target_arch = "x86_64")]
use demos::metrics::utils_avx2::dot_avx;

use demos::metrics::utils_avx2::{dot_avx, l1_avx};
#[cfg(target_arch = "x86_64")]
use demos::metrics::utils_sse::dot_sse;
use demos::metrics::utils_sse::{dot_sse, l1_sse};

fn encode_bench(c: &mut Criterion) {
let mut group = c.benchmark_group("encode");
fn encode_dot_bench(c: &mut Criterion) {
let mut group = c.benchmark_group("encode dot");

let vectors_count = 100_000;
let vector_dim = 1024;
Expand Down Expand Up @@ -145,10 +144,145 @@ fn encode_bench(c: &mut Criterion) {
});
}

fn encode_l1_bench(c: &mut Criterion) {
let mut group = c.benchmark_group("encode l1");

let vectors_count = 100_000;
let vector_dim = 1024;
let mut rng = rand::thread_rng();
let mut list: Vec<f32> = Vec::new();
for _ in 0..vectors_count {
let vector: Vec<f32> = (0..vector_dim).map(|_| rng.gen()).collect();
list.extend_from_slice(&vector);
}

let i8_encoded = EncodedVectorsU8::encode(
(0..vectors_count).map(|i| &list[i * vector_dim..(i + 1) * vector_dim]),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
count: vectors_count,
distance_type: DistanceType::L1,
invert: true,
},
None,
|| false,
)
.unwrap();

let query: Vec<f32> = (0..vector_dim).map(|_| rng.gen()).collect();
let encoded_query = i8_encoded.encode_query(&query);

#[cfg(target_arch = "x86_64")]
group.bench_function("score all u8 avx", |b| {
b.iter(|| {
let mut _s = 0.0;
for i in 0..vectors_count as u32 {
_s = i8_encoded.score_point_avx(&encoded_query, i);
}
});
});

#[cfg(target_arch = "x86_64")]
group.bench_function("score all u8 sse", |b| {
b.iter(|| {
let mut _s = 0.0;
for i in 0..vectors_count as u32 {
_s = i8_encoded.score_point_sse(&encoded_query, i);
}
});
});

#[cfg(target_arch = "aarch64")]
group.bench_function("score all u8 neon", |b| {
b.iter(|| {
let mut _s = 0.0;
for i in 0..vectors_count as u32 {
_s = i8_encoded.score_point_neon(&encoded_query, i);
}
});
});

#[cfg(target_arch = "x86_64")]
group.bench_function("score all avx", |b| {
b.iter(|| unsafe {
let mut _s = 0.0;
for i in 0..vectors_count {
_s = l1_avx(&query, &list[i * vector_dim..(i + 1) * vector_dim]);
}
});
});

#[cfg(target_arch = "x86_64")]
group.bench_function("score all sse", |b| {
b.iter(|| unsafe {
let mut _s = 0.0;
for i in 0..vectors_count {
_s = l1_sse(&query, &list[i * vector_dim..(i + 1) * vector_dim]);
}
});
});

let permutor = Permutor::new(vectors_count as u64);
let permutation: Vec<u32> = permutor.map(|i| i as u32).collect();

#[cfg(target_arch = "x86_64")]
group.bench_function("score random access u8 avx", |b| {
b.iter(|| {
let mut _s = 0.0;
for &i in &permutation {
_s = i8_encoded.score_point_avx(&encoded_query, i);
}
});
});

#[cfg(target_arch = "x86_64")]
group.bench_function("score random access u8 sse", |b| {
let mut _s = 0.0;
b.iter(|| {
for &i in &permutation {
_s = i8_encoded.score_point_sse(&encoded_query, i);
}
});
});

#[cfg(target_arch = "aarch64")]
group.bench_function("score random access u8 neon", |b| {
let mut _s = 0.0;
b.iter(|| {
for &i in &permutation {
_s = i8_encoded.score_point_neon(&encoded_query, i);
}
});
});

#[cfg(target_arch = "x86_64")]
group.bench_function("score random access avx", |b| {
b.iter(|| unsafe {
let mut _s = 0.0;
for &i in &permutation {
let i = i as usize;
_s = l1_avx(&query, &list[i * vector_dim..(i + 1) * vector_dim]);
}
});
});

#[cfg(target_arch = "x86_64")]
group.bench_function("score random access sse", |b| {
let mut _s = 0.0;
b.iter(|| unsafe {
for &i in &permutation {
let i = i as usize;
_s = l1_sse(&query, &list[i * vector_dim..(i + 1) * vector_dim]);
}
});
});
}

// Register the benchmark entry points. Fix: a stale `targets = encode_bench`
// line (the function was renamed to `encode_dot_bench`) was left alongside
// the new targets line, which does not compile; only the current targets
// remain. `sample_size(10)` keeps runs short for these large workloads.
criterion_group! {
    name = benches;
    config = Criterion::default().sample_size(10);
    targets = encode_dot_bench, encode_l1_bench
}

criterion_main!(benches);
1 change: 0 additions & 1 deletion demos/src/ann_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use quantization::{EncodedVectorsU8, VectorParameters};

#[cfg(target_arch = "x86_64")]
use crate::metrics::utils_avx2::dot_avx;

#[cfg(target_arch = "x86_64")]
use demos::metrics::utils_sse::dot_sse;

Expand Down
46 changes: 46 additions & 0 deletions demos/src/metrics/utils_avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,49 @@ pub unsafe fn dot_avx(v1: &[f32], v2: &[f32]) -> f32 {
}
result
}

/// Negated L1 (Manhattan) distance between `v1` and `v2` using AVX2:
/// returns `-Σ |v1[i] - v2[i]|`.
///
/// The absolute value is taken by clearing each lane's sign bit with an
/// `andnot` mask; four independent accumulators process 32 floats per
/// iteration to hide floating-point add latency.
///
/// The result is negated (see the final expression) — NOTE(review):
/// presumably so larger values mean closer vectors, matching a
/// higher-is-better score convention; confirm against callers.
///
/// # Safety
///
/// - The CPU must support the `avx2` and `fma` target features.
/// - `v2` must contain at least `v1.len()` elements: all reads are driven
///   by `v1.len()`, and `v2` is read unchecked through raw pointers.
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
#[allow(clippy::missing_safety_doc, dead_code)]
pub unsafe fn l1_avx(v1: &[f32], v2: &[f32]) -> f32 {
    let mask: __m256 = _mm256_set1_ps(-0.0f32); // 1 << 31 used to clear sign bit to mimic abs

    let n = v1.len();
    // Largest multiple of 32 not exceeding `n`: the unrolled loop consumes
    // 32 floats (4 × 8-lane registers) per iteration.
    let m = n - (n % 32);
    let mut ptr1: *const f32 = v1.as_ptr();
    let mut ptr2: *const f32 = v2.as_ptr();
    // Four independent partial sums break the add dependency chain.
    let mut sum256_1: __m256 = _mm256_setzero_ps();
    let mut sum256_2: __m256 = _mm256_setzero_ps();
    let mut sum256_3: __m256 = _mm256_setzero_ps();
    let mut sum256_4: __m256 = _mm256_setzero_ps();
    let mut i: usize = 0;
    while i < m {
        // |a - b| per lane: subtract, then clear the sign bit via andnot.
        let sub256_1: __m256 = _mm256_sub_ps(_mm256_loadu_ps(ptr1), _mm256_loadu_ps(ptr2));
        sum256_1 = _mm256_add_ps(_mm256_andnot_ps(mask, sub256_1), sum256_1);

        let sub256_2: __m256 =
            _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8)));
        sum256_2 = _mm256_add_ps(_mm256_andnot_ps(mask, sub256_2), sum256_2);

        let sub256_3: __m256 =
            _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16)));
        sum256_3 = _mm256_add_ps(_mm256_andnot_ps(mask, sub256_3), sum256_3);

        let sub256_4: __m256 =
            _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24)));
        sum256_4 = _mm256_add_ps(_mm256_andnot_ps(mask, sub256_4), sum256_4);

        ptr1 = ptr1.add(32);
        ptr2 = ptr2.add(32);
        i += 32;
    }

    // Horizontally reduce the four vector accumulators, then add the scalar
    // tail of `n - m` (< 32) remaining elements; `ptr1`/`ptr2` already point
    // past the vectorized prefix.
    let mut result = hsum256_ps_avx(sum256_1)
        + hsum256_ps_avx(sum256_2)
        + hsum256_ps_avx(sum256_3)
        + hsum256_ps_avx(sum256_4);
    for i in 0..n - m {
        result += (*ptr1.add(i) - *ptr2.add(i)).abs();
    }
    // Negated on purpose — see the doc comment above.
    -result
}
42 changes: 42 additions & 0 deletions demos/src/metrics/utils_sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,45 @@ pub unsafe fn dot_sse(v1: &[f32], v2: &[f32]) -> f32 {
}
result
}

/// Negated L1 (Manhattan) distance between `v1` and `v2` using SSE:
/// returns `-Σ |v1[i] - v2[i]|`.
///
/// Same scheme as the AVX2 variant: sign-bit clearing via `andnot` for the
/// absolute value, and four independent accumulators processing 16 floats
/// (4 × 4-lane registers) per iteration.
///
/// The result is negated — NOTE(review): presumably to match a
/// higher-is-better score convention; confirm against callers.
///
/// # Safety
///
/// - The CPU must support the `sse4.1` target feature.
/// - `v2` must contain at least `v1.len()` elements: all reads are driven
///   by `v1.len()`, and `v2` is read unchecked through raw pointers.
#[target_feature(enable = "sse4.1")]
#[allow(clippy::missing_safety_doc, dead_code)]
pub unsafe fn l1_sse(v1: &[f32], v2: &[f32]) -> f32 {
    let mask: __m128 = _mm_set1_ps(-0.0f32); // 1 << 31 used to clear sign bit to mimic abs

    let n = v1.len();
    // Largest multiple of 16 not exceeding `n`; the unrolled loop consumes
    // 16 floats per iteration.
    let m = n - (n % 16);
    let mut ptr1: *const f32 = v1.as_ptr();
    let mut ptr2: *const f32 = v2.as_ptr();
    // Four independent partial sums break the add dependency chain.
    let mut sum128_1: __m128 = _mm_setzero_ps();
    let mut sum128_2: __m128 = _mm_setzero_ps();
    let mut sum128_3: __m128 = _mm_setzero_ps();
    let mut sum128_4: __m128 = _mm_setzero_ps();
    let mut i: usize = 0;
    while i < m {
        // |a - b| per lane: subtract, then clear the sign bit via andnot.
        let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2));
        sum128_1 = _mm_add_ps(_mm_andnot_ps(mask, sub128_1), sum128_1);

        let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4)));
        sum128_2 = _mm_add_ps(_mm_andnot_ps(mask, sub128_2), sum128_2);

        let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8)));
        sum128_3 = _mm_add_ps(_mm_andnot_ps(mask, sub128_3), sum128_3);

        let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12)));
        sum128_4 = _mm_add_ps(_mm_andnot_ps(mask, sub128_4), sum128_4);

        ptr1 = ptr1.add(16);
        ptr2 = ptr2.add(16);
        i += 16;
    }

    // Horizontally reduce the four vector accumulators, then add the scalar
    // tail of `n - m` (< 16) remaining elements; `ptr1`/`ptr2` already point
    // past the vectorized prefix.
    let mut result = hsum128_ps_sse(sum128_1)
        + hsum128_ps_sse(sum128_2)
        + hsum128_ps_sse(sum128_3)
        + hsum128_ps_sse(sum128_4);
    for i in 0..n - m {
        result += (*ptr1.add(i) - *ptr2.add(i)).abs();
    }
    // Negated on purpose — see the doc comment above.
    -result
}
71 changes: 71 additions & 0 deletions quantization/cpp/avx2.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <stdlib.h>
#include <stdint.h>
#include <immintrin.h>

Expand All @@ -12,6 +13,15 @@
R = _mm_cvtss_f32(x32); \
}

#define HSUM256_EPI32(X, R) \
int R = 0; \
{ \
__m128i x128 = _mm_add_epi32(_mm256_extractf128_si256(X, 1), _mm256_castsi256_si128(X)); \
__m128i x64 = _mm_add_epi32(x128, _mm_srli_si128(x128, 8)); \
__m128i x32 = _mm_add_epi32(x64, _mm_srli_si128(x64, 4)); \
R = _mm_cvtsi128_si32(x32); \
}

EXPORT float impl_score_dot_avx(
const uint8_t* query_ptr,
const uint8_t* vector_ptr,
Expand All @@ -34,6 +44,8 @@ EXPORT float impl_score_dot_avx(
mul1 = _mm256_add_epi32(mul1, s_low);
mul1 = _mm256_add_epi32(mul1, s_high);
}

// the vector sizes are assumed to be multiples of 16, check if one last 16-element part remaining
if (dim % 32 != 0) {
__m128i v_short = _mm_loadu_si128((const __m128i*)v_ptr);
__m128i q_short = _mm_loadu_si128((const __m128i*)q_ptr);
Expand All @@ -49,3 +61,62 @@ EXPORT float impl_score_dot_avx(
HSUM256_PS(mul_ps, mul_scalar);
return mul_scalar;
}

// L1 (Manhattan) distance between two u8-quantized vectors:
// sum(|query[i] - vector[i]|) over `dim` elements.
//
// Fix: the previous version accumulated 8-bit absolute differences into
// 16-bit lanes with _mm256_add_epi16, which silently overflows once a lane's
// running total exceeds 65535 (possible from dim >= ~4128); its tail also
// mixed _mm256_cvtepu16_epi32 32-bit values into 16-bit adds. This version
// uses _mm256_sad_epu8, which computes the same per-byte absolute
// differences but folds each 8-byte group into a 64-bit lane, so the
// accumulator cannot overflow for any realistic `dim`.
//
// The vector sizes are assumed to be multiples of 16: the main loop handles
// 32-byte chunks and at most one trailing 16-byte chunk remains.
EXPORT float impl_score_l1_avx(
    const uint8_t* query_ptr,
    const uint8_t* vector_ptr,
    uint32_t dim
) {
    const __m256i* v_ptr = (const __m256i*)vector_ptr;
    const __m256i* q_ptr = (const __m256i*)query_ptr;

    uint32_t m = dim - (dim % 32);
    // Four 64-bit partial sums of absolute differences.
    __m256i sum256 = _mm256_setzero_si256();

    for (uint32_t i = 0; i < m; i += 32) {
        __m256i v = _mm256_loadu_si256(v_ptr);
        __m256i q = _mm256_loadu_si256(q_ptr);
        v_ptr++;
        q_ptr++;

        // Sum of absolute differences of 32 byte pairs, widened into four
        // 64-bit lanes — no intermediate narrow accumulator to overflow.
        sum256 = _mm256_add_epi64(sum256, _mm256_sad_epu8(v, q));
    }

    // Fold the 256-bit accumulator into two 64-bit lanes.
    __m128i sum128 = _mm_add_epi64(
        _mm256_castsi256_si128(sum256),
        _mm256_extractf128_si256(sum256, 1));

    // Handle the final 16-element chunk, if any (dim % 32 == 16).
    if (m < dim) {
        __m128i v_short = _mm_loadu_si128((const __m128i*)v_ptr);
        __m128i q_short = _mm_loadu_si128((const __m128i*)q_ptr);
        sum128 = _mm_add_epi64(sum128, _mm_sad_epu8(v_short, q_short));
    }

    uint64_t sum = (uint64_t)_mm_extract_epi64(sum128, 0)
                 + (uint64_t)_mm_extract_epi64(sum128, 1);
    return (float)sum;
}
Loading
Loading