Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implemented L1 distance & fixed L2 in binary quantization #21

Merged
merged 14 commits into from
Nov 27, 2023
147 changes: 144 additions & 3 deletions demos/benches/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@ use rand::Rng;
#[cfg(target_arch = "x86_64")]
use demos::metrics::utils_avx2::dot_avx;

#[cfg(target_arch = "x86_64")]
use demos::metrics::utils_avx2::l1_avx;

#[cfg(target_arch = "x86_64")]
use demos::metrics::utils_sse::dot_sse;

fn encode_bench(c: &mut Criterion) {
let mut group = c.benchmark_group("encode");
#[cfg(target_arch = "x86_64")]
use demos::metrics::utils_sse::l1_sse;

fn encode_dot_bench(c: &mut Criterion) {
let mut group = c.benchmark_group("encode dot");

let vectors_count = 100_000;
let vector_dim = 1024;
Expand Down Expand Up @@ -145,10 +151,145 @@ fn encode_bench(c: &mut Criterion) {
});
}

/// Benchmarks L1 (Manhattan) distance scoring over 100k random 1024-dim
/// vectors: the u8-quantized SIMD scorers (AVX / SSE / NEON) versus the raw
/// f32 SIMD kernels, in both sequential and random-access order.
fn encode_l1_bench(c: &mut Criterion) {
    // The original assigned every score into an unused `_s`, which the
    // optimizer is free to eliminate together with the scoring call itself —
    // black_box keeps the benchmarked work alive.
    use std::hint::black_box;

    let mut group = c.benchmark_group("encode l1");

    let vectors_count = 100_000;
    let vector_dim = 1024;
    let mut rng = rand::thread_rng();

    // All vectors stored contiguously: vector i occupies
    // list[i * vector_dim .. (i + 1) * vector_dim]. Preallocate to avoid
    // repeated grow-and-copy while filling 100M floats.
    let mut list: Vec<f32> = Vec::with_capacity(vectors_count * vector_dim);
    for _ in 0..vectors_count {
        list.extend((0..vector_dim).map(|_| rng.gen::<f32>()));
    }

    let i8_encoded = EncodedVectorsU8::encode(
        (0..vectors_count).map(|i| &list[i * vector_dim..(i + 1) * vector_dim]),
        Vec::<u8>::new(),
        &VectorParameters {
            dim: vector_dim,
            count: vectors_count,
            distance_type: DistanceType::L1,
            invert: true,
        },
        None,
        || false,
    )
    .unwrap();

    let query: Vec<f32> = (0..vector_dim).map(|_| rng.gen()).collect();
    let encoded_query = i8_encoded.encode_query(&query);

    #[cfg(target_arch = "x86_64")]
    group.bench_function("score all u8 avx", |b| {
        b.iter(|| {
            for i in 0..vectors_count as u32 {
                black_box(i8_encoded.score_point_avx(&encoded_query, i));
            }
        });
    });

    #[cfg(target_arch = "x86_64")]
    group.bench_function("score all u8 sse", |b| {
        b.iter(|| {
            for i in 0..vectors_count as u32 {
                black_box(i8_encoded.score_point_sse(&encoded_query, i));
            }
        });
    });

    #[cfg(target_arch = "aarch64")]
    group.bench_function("score all u8 neon", |b| {
        b.iter(|| {
            for i in 0..vectors_count as u32 {
                black_box(i8_encoded.score_point_neon(&encoded_query, i));
            }
        });
    });

    #[cfg(target_arch = "x86_64")]
    group.bench_function("score all avx", |b| {
        b.iter(|| unsafe {
            for i in 0..vectors_count {
                black_box(l1_avx(&query, &list[i * vector_dim..(i + 1) * vector_dim]));
            }
        });
    });

    #[cfg(target_arch = "x86_64")]
    group.bench_function("score all sse", |b| {
        b.iter(|| unsafe {
            for i in 0..vectors_count {
                black_box(l1_sse(&query, &list[i * vector_dim..(i + 1) * vector_dim]));
            }
        });
    });

    // Shuffled point ids to measure cache-unfriendly random access.
    let permutor = Permutor::new(vectors_count as u64);
    let permutation: Vec<u32> = permutor.map(|i| i as u32).collect();

    #[cfg(target_arch = "x86_64")]
    group.bench_function("score random access u8 avx", |b| {
        b.iter(|| {
            for &i in &permutation {
                black_box(i8_encoded.score_point_avx(&encoded_query, i));
            }
        });
    });

    #[cfg(target_arch = "x86_64")]
    group.bench_function("score random access u8 sse", |b| {
        b.iter(|| {
            for &i in &permutation {
                black_box(i8_encoded.score_point_sse(&encoded_query, i));
            }
        });
    });

    #[cfg(target_arch = "aarch64")]
    group.bench_function("score random access u8 neon", |b| {
        b.iter(|| {
            for &i in &permutation {
                black_box(i8_encoded.score_point_neon(&encoded_query, i));
            }
        });
    });

    #[cfg(target_arch = "x86_64")]
    group.bench_function("score random access avx", |b| {
        b.iter(|| unsafe {
            for &i in &permutation {
                let i = i as usize;
                black_box(l1_avx(&query, &list[i * vector_dim..(i + 1) * vector_dim]));
            }
        });
    });

    #[cfg(target_arch = "x86_64")]
    group.bench_function("score random access sse", |b| {
        b.iter(|| unsafe {
            for &i in &permutation {
                let i = i as usize;
                black_box(l1_sse(&query, &list[i * vector_dim..(i + 1) * vector_dim]));
            }
        });
    });
}

criterion_group! {
    name = benches;
    config = Criterion::default().sample_size(10);
    // The old `encode_bench` target was split into the dot- and l1-specific
    // benchmarks; the stale `targets = encode_bench` line left over from the
    // diff would not compile and has been removed.
    targets = encode_dot_bench, encode_l1_bench
}

criterion_main!(benches);
46 changes: 46 additions & 0 deletions demos/src/metrics/utils_avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,49 @@ pub unsafe fn dot_avx(v1: &[f32], v2: &[f32]) -> f32 {
}
result
}

/// Negated L1 (Manhattan) distance between two `f32` slices using AVX2.
///
/// Processes 32 floats per iteration with four independent accumulators to
/// break the floating-point add dependency chain; the tail (fewer than 32
/// elements) is handled by a scalar loop.
///
/// # Safety
/// - The running CPU must support the `avx2` and `fma` target features.
/// - `v2` must contain at least `v1.len()` elements; both slices are read
///   up to `v1.len()`.
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
#[allow(clippy::missing_safety_doc, dead_code)]
pub unsafe fn l1_avx(v1: &[f32], v2: &[f32]) -> f32 {
    // -0.0f32 is the sign bit alone (1 << 31); `andnot` with this mask
    // clears the sign bit of each lane, i.e. a branch-free abs().
    let mask: __m256 = _mm256_set1_ps(-0.0f32);

    let n = v1.len();
    // m = largest multiple of 32 <= n; elements [m, n) go to the scalar tail.
    let m = n - (n % 32);
    let mut ptr1: *const f32 = v1.as_ptr();
    let mut ptr2: *const f32 = v2.as_ptr();
    let mut sum256_1: __m256 = _mm256_setzero_ps();
    let mut sum256_2: __m256 = _mm256_setzero_ps();
    let mut sum256_3: __m256 = _mm256_setzero_ps();
    let mut sum256_4: __m256 = _mm256_setzero_ps();
    let mut i: usize = 0;
    while i < m {
        // Four 8-lane chunks per iteration, each into its own accumulator.
        let sub256_1: __m256 = _mm256_sub_ps(_mm256_loadu_ps(ptr1), _mm256_loadu_ps(ptr2));
        sum256_1 = _mm256_add_ps(_mm256_andnot_ps(mask, sub256_1), sum256_1);

        let sub256_2: __m256 =
            _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8)));
        sum256_2 = _mm256_add_ps(_mm256_andnot_ps(mask, sub256_2), sum256_2);

        let sub256_3: __m256 =
            _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16)));
        sum256_3 = _mm256_add_ps(_mm256_andnot_ps(mask, sub256_3), sum256_3);

        let sub256_4: __m256 =
            _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24)));
        sum256_4 = _mm256_add_ps(_mm256_andnot_ps(mask, sub256_4), sum256_4);

        ptr1 = ptr1.add(32);
        ptr2 = ptr2.add(32);
        i += 32;
    }

    // Horizontally reduce the four vector accumulators to a scalar.
    let mut result = hsum256_ps_avx(sum256_1)
        + hsum256_ps_avx(sum256_2)
        + hsum256_ps_avx(sum256_3)
        + hsum256_ps_avx(sum256_4);
    // Scalar tail: ptr1/ptr2 already point at element m.
    for i in 0..n - m {
        result += (*ptr1.add(i) - *ptr2.add(i)).abs();
    }
    // NOTE(review): the distance is negated — presumably so a larger score
    // means a closer vector, matching the dot-product scoring convention;
    // confirm against the callers.
    -result
}
42 changes: 42 additions & 0 deletions demos/src/metrics/utils_sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,45 @@ pub unsafe fn dot_sse(v1: &[f32], v2: &[f32]) -> f32 {
}
result
}

/// Negated L1 (Manhattan) distance between two `f32` slices using SSE.
///
/// Processes 16 floats per iteration with four independent accumulators to
/// break the floating-point add dependency chain; the tail (fewer than 16
/// elements) is handled by a scalar loop.
///
/// # Safety
/// - The running CPU must support the `sse4.1` target feature.
/// - `v2` must contain at least `v1.len()` elements; both slices are read
///   up to `v1.len()`.
#[target_feature(enable = "sse4.1")]
#[allow(clippy::missing_safety_doc, dead_code)]
pub unsafe fn l1_sse(v1: &[f32], v2: &[f32]) -> f32 {
    // -0.0f32 is the sign bit alone (1 << 31); `andnot` with this mask
    // clears the sign bit of each lane, i.e. a branch-free abs().
    let mask: __m128 = _mm_set1_ps(-0.0f32);

    let n = v1.len();
    // m = largest multiple of 16 <= n; elements [m, n) go to the scalar tail.
    let m = n - (n % 16);
    let mut ptr1: *const f32 = v1.as_ptr();
    let mut ptr2: *const f32 = v2.as_ptr();
    let mut sum128_1: __m128 = _mm_setzero_ps();
    let mut sum128_2: __m128 = _mm_setzero_ps();
    let mut sum128_3: __m128 = _mm_setzero_ps();
    let mut sum128_4: __m128 = _mm_setzero_ps();
    let mut i: usize = 0;
    while i < m {
        // Four 4-lane chunks per iteration, each into its own accumulator.
        let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2));
        sum128_1 = _mm_add_ps(_mm_andnot_ps(mask, sub128_1), sum128_1);

        let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4)));
        sum128_2 = _mm_add_ps(_mm_andnot_ps(mask, sub128_2), sum128_2);

        let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8)));
        sum128_3 = _mm_add_ps(_mm_andnot_ps(mask, sub128_3), sum128_3);

        let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12)));
        sum128_4 = _mm_add_ps(_mm_andnot_ps(mask, sub128_4), sum128_4);

        ptr1 = ptr1.add(16);
        ptr2 = ptr2.add(16);
        i += 16;
    }

    // Horizontally reduce the four vector accumulators to a scalar.
    let mut result = hsum128_ps_sse(sum128_1)
        + hsum128_ps_sse(sum128_2)
        + hsum128_ps_sse(sum128_3)
        + hsum128_ps_sse(sum128_4);
    // Scalar tail: ptr1/ptr2 already point at element m.
    for i in 0..n - m {
        result += (*ptr1.add(i) - *ptr2.add(i)).abs();
    }
    // NOTE(review): the distance is negated — presumably so a larger score
    // means a closer vector, matching the dot-product scoring convention;
    // confirm against the callers.
    -result
}
54 changes: 54 additions & 0 deletions quantization/cpp/avx2.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <stdlib.h>
#include <stdint.h>
#include <immintrin.h>

Expand All @@ -12,6 +13,15 @@
R = _mm_cvtss_f32(x32); \
}

/* Horizontal sum of the eight signed 32-bit lanes of the AVX2 register X:
 * folds high/low 128-bit halves, then 64-bit, then 32-bit. Declares a new
 * `int R` in the enclosing scope (the declaration is deliberately outside
 * the braces) and leaves the total in it. Integer counterpart of the
 * float HSUM256_PS macro above. */
#define HSUM256_EPI32(X, R) \
int R = 0; \
{ \
__m128i x128 = _mm_add_epi32(_mm256_extractf128_si256(X, 1), _mm256_castsi256_si128(X)); \
__m128i x64 = _mm_add_epi32(x128, _mm_srli_si128(x128, 8)); \
__m128i x32 = _mm_add_epi32(x64, _mm_srli_si128(x64, 4)); \
R = _mm_cvtsi128_si32(x32); \
}

EXPORT float impl_score_dot_avx(
const uint8_t* query_ptr,
const uint8_t* vector_ptr,
Expand Down Expand Up @@ -49,3 +59,47 @@ EXPORT float impl_score_dot_avx(
HSUM256_PS(mul_ps, mul_scalar);
return mul_scalar;
}

/*
 * L1 (Manhattan) distance between two u8-quantized vectors using AVX2.
 *
 * Processes 32 bytes per iteration. There is no unsigned-byte absolute
 * difference intrinsic, so |a - b| is built from two saturating
 * subtractions: subs_epu8 clamps at 0, meaning one direction yields the
 * true difference and the other yields 0, and max_epu8 picks the winner.
 *
 * NOTE(review): the accumulator uses 16-bit lanes and each lane receives up
 * to 2 * 255 per iteration, so with worst-case diffs a lane can overflow
 * once dim exceeds roughly 4096. Fine for the current dimensions; consider
 * _mm256_sad_epu8 (which widens to 64-bit) if larger dims are quantized.
 */
EXPORT float impl_score_l1_avx(
    const uint8_t* query_ptr,
    const uint8_t* vector_ptr,
    uint32_t dim
) {
    const __m256i* v_ptr = (const __m256i*)vector_ptr;
    const __m256i* q_ptr = (const __m256i*)query_ptr;

    /* m = largest multiple of 32 <= dim; the remainder is summed scalar. */
    uint32_t m = dim - (dim % 32);
    __m256i sum256 = _mm256_setzero_si256();

    for (uint32_t i = 0; i < m; i += 32) {
        __m256i vec1 = _mm256_loadu_si256(v_ptr);
        __m256i vec2 = _mm256_loadu_si256(q_ptr);
        v_ptr++;
        q_ptr++;

        /* Compute the difference in both directions; saturation zeroes the
         * wrong direction, so the max is the absolute difference. */
        __m256i diff1 = _mm256_subs_epu8(vec1, vec2);
        __m256i diff2 = _mm256_subs_epu8(vec2, vec1);

        __m256i abs_diff = _mm256_max_epu8(diff1, diff2);

        /* Widen u8 -> u16 and accumulate; interleaved lane order does not
         * matter because only the total is needed. */
        __m256i abs_diff16_lo = _mm256_unpacklo_epi8(abs_diff, _mm256_setzero_si256());
        __m256i abs_diff16_hi = _mm256_unpackhi_epi8(abs_diff, _mm256_setzero_si256());

        sum256 = _mm256_add_epi16(sum256, abs_diff16_lo);
        sum256 = _mm256_add_epi16(sum256, abs_diff16_hi);
    }

    /* Widen u16 -> u32 before the horizontal reduction. */
    __m256i sum_epi32 = _mm256_add_epi32(
        _mm256_unpacklo_epi16(sum256, _mm256_setzero_si256()),
        _mm256_unpackhi_epi16(sum256, _mm256_setzero_si256()));

    /* Declares `int sum` holding the horizontal total. */
    HSUM256_EPI32(sum_epi32, sum);

    /* Sum the remaining dim - m elements scalar. */
    for (uint32_t i = m; i < dim; ++i) {
        sum += abs(query_ptr[i] - vector_ptr[i]);
    }

    return (float) sum;
}
45 changes: 45 additions & 0 deletions quantization/cpp/neon.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <stdlib.h>
#include <arm_neon.h>

#include "export_macro.h"
Expand Down Expand Up @@ -63,3 +64,47 @@ EXPORT uint64_t impl_xor_popcnt_neon(

return (uint64_t)vaddvq_u32(result);
}

/*
 * L1 (Manhattan) distance between two u8-quantized vectors using NEON.
 *
 * vabdq_u8 yields |a - b| per byte directly; bytes are widened to u16 and
 * accumulated 16 per iteration, with the remainder summed scalar.
 *
 * NOTE(review): each 16-bit accumulator lane receives up to 255 per
 * iteration, so with worst-case diffs a lane can overflow once dim exceeds
 * roughly 4096. Fine for the current dimensions; vpadalq_u16/vabal would
 * widen earlier if larger dims are ever quantized.
 */
EXPORT float impl_score_l1_neon(
    const uint8_t * query_ptr,
    const uint8_t * vector_ptr,
    uint32_t dim
) {
    const uint8_t* v_ptr = vector_ptr;
    const uint8_t* q_ptr = query_ptr;

    /* m = largest multiple of 16 <= dim; the remainder is summed scalar. */
    uint32_t m = dim - (dim % 16);
    uint16x8_t sum16_low = vdupq_n_u16(0);
    uint16x8_t sum16_high = vdupq_n_u16(0);

    for (uint32_t i = 0; i < m; i += 16) {
        uint8x16_t vec1 = vld1q_u8(v_ptr);
        uint8x16_t vec2 = vld1q_u8(q_ptr);

        /* Absolute difference per byte, widened u8 -> u16 per half. */
        uint8x16_t abs_diff = vabdq_u8(vec1, vec2);
        uint16x8_t abs_diff16_low = vmovl_u8(vget_low_u8(abs_diff));
        uint16x8_t abs_diff16_high = vmovl_u8(vget_high_u8(abs_diff));

        sum16_low = vaddq_u16(sum16_low, abs_diff16_low);
        sum16_high = vaddq_u16(sum16_high, abs_diff16_high);

        v_ptr += 16;
        q_ptr += 16;
    }

    /* Horizontal reduction: pairwise-widen u16 -> u32, add the two
     * accumulators, fold the quad to a pair, then the pair to one lane. */
    uint32x4_t sum32 = vaddq_u32(vpaddlq_u16(sum16_low), vpaddlq_u16(sum16_high));
    uint32x2_t sum_pair = vadd_u32(vget_low_u32(sum32), vget_high_u32(sum32));
    uint32_t sum = vget_lane_u32(vpadd_u32(sum_pair, sum_pair), 0);

    /* Sum the remaining dim - m elements scalar. */
    for (uint32_t i = m; i < dim; ++i) {
        sum += abs(query_ptr[i] - vector_ptr[i]);
    }

    return (float) sum;
}
Loading
Loading