Skip to content

Commit

Permalink
Merge pull request #21 from kaancfidan/manhattan-distance
Browse files Browse the repository at this point in the history
implemented L1 distance & fixed L2 in binary quantization
  • Loading branch information
timvisee authored Nov 27, 2023
2 parents ff306d0 + 9fa1a78 commit 939fdb6
Show file tree
Hide file tree
Showing 17 changed files with 1,246 additions and 96 deletions.
146 changes: 140 additions & 6 deletions demos/benches/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ use quantization::encoded_vectors_u8::EncodedVectorsU8;
use rand::Rng;

#[cfg(target_arch = "x86_64")]
use demos::metrics::utils_avx2::dot_avx;

use demos::metrics::utils_avx2::{dot_avx, l1_avx};
#[cfg(target_arch = "x86_64")]
use demos::metrics::utils_sse::dot_sse;
use demos::metrics::utils_sse::{dot_sse, l1_sse};

fn encode_bench(c: &mut Criterion) {
let mut group = c.benchmark_group("encode");
fn encode_dot_bench(c: &mut Criterion) {
let mut group = c.benchmark_group("encode dot");

let vectors_count = 100_000;
let vector_dim = 1024;
Expand Down Expand Up @@ -145,10 +144,145 @@ fn encode_bench(c: &mut Criterion) {
});
}

fn encode_l1_bench(c: &mut Criterion) {
let mut group = c.benchmark_group("encode l1");

let vectors_count = 100_000;
let vector_dim = 1024;
let mut rng = rand::thread_rng();
let mut list: Vec<f32> = Vec::new();
for _ in 0..vectors_count {
let vector: Vec<f32> = (0..vector_dim).map(|_| rng.gen()).collect();
list.extend_from_slice(&vector);
}

let i8_encoded = EncodedVectorsU8::encode(
(0..vectors_count).map(|i| &list[i * vector_dim..(i + 1) * vector_dim]),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
count: vectors_count,
distance_type: DistanceType::L1,
invert: true,
},
None,
|| false,
)
.unwrap();

let query: Vec<f32> = (0..vector_dim).map(|_| rng.gen()).collect();
let encoded_query = i8_encoded.encode_query(&query);

#[cfg(target_arch = "x86_64")]
group.bench_function("score all u8 avx", |b| {
b.iter(|| {
let mut _s = 0.0;
for i in 0..vectors_count as u32 {
_s = i8_encoded.score_point_avx(&encoded_query, i);
}
});
});

#[cfg(target_arch = "x86_64")]
group.bench_function("score all u8 sse", |b| {
b.iter(|| {
let mut _s = 0.0;
for i in 0..vectors_count as u32 {
_s = i8_encoded.score_point_sse(&encoded_query, i);
}
});
});

#[cfg(target_arch = "aarch64")]
group.bench_function("score all u8 neon", |b| {
b.iter(|| {
let mut _s = 0.0;
for i in 0..vectors_count as u32 {
_s = i8_encoded.score_point_neon(&encoded_query, i);
}
});
});

#[cfg(target_arch = "x86_64")]
group.bench_function("score all avx", |b| {
b.iter(|| unsafe {
let mut _s = 0.0;
for i in 0..vectors_count {
_s = l1_avx(&query, &list[i * vector_dim..(i + 1) * vector_dim]);
}
});
});

#[cfg(target_arch = "x86_64")]
group.bench_function("score all sse", |b| {
b.iter(|| unsafe {
let mut _s = 0.0;
for i in 0..vectors_count {
_s = l1_sse(&query, &list[i * vector_dim..(i + 1) * vector_dim]);
}
});
});

let permutor = Permutor::new(vectors_count as u64);
let permutation: Vec<u32> = permutor.map(|i| i as u32).collect();

#[cfg(target_arch = "x86_64")]
group.bench_function("score random access u8 avx", |b| {
b.iter(|| {
let mut _s = 0.0;
for &i in &permutation {
_s = i8_encoded.score_point_avx(&encoded_query, i);
}
});
});

#[cfg(target_arch = "x86_64")]
group.bench_function("score random access u8 sse", |b| {
let mut _s = 0.0;
b.iter(|| {
for &i in &permutation {
_s = i8_encoded.score_point_sse(&encoded_query, i);
}
});
});

#[cfg(target_arch = "aarch64")]
group.bench_function("score random access u8 neon", |b| {
let mut _s = 0.0;
b.iter(|| {
for &i in &permutation {
_s = i8_encoded.score_point_neon(&encoded_query, i);
}
});
});

#[cfg(target_arch = "x86_64")]
group.bench_function("score random access avx", |b| {
b.iter(|| unsafe {
let mut _s = 0.0;
for &i in &permutation {
let i = i as usize;
_s = l1_avx(&query, &list[i * vector_dim..(i + 1) * vector_dim]);
}
});
});

#[cfg(target_arch = "x86_64")]
group.bench_function("score random access sse", |b| {
let mut _s = 0.0;
b.iter(|| unsafe {
for &i in &permutation {
let i = i as usize;
_s = l1_sse(&query, &list[i * vector_dim..(i + 1) * vector_dim]);
}
});
});
}

criterion_group! {
name = benches;
config = Criterion::default().sample_size(10);
targets = encode_bench
targets = encode_dot_bench, encode_l1_bench
}

criterion_main!(benches);
1 change: 0 additions & 1 deletion demos/src/ann_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use quantization::{EncodedVectorsU8, VectorParameters};

#[cfg(target_arch = "x86_64")]
use crate::metrics::utils_avx2::dot_avx;

#[cfg(target_arch = "x86_64")]
use demos::metrics::utils_sse::dot_sse;

Expand Down
46 changes: 46 additions & 0 deletions demos/src/metrics/utils_avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,49 @@ pub unsafe fn dot_avx(v1: &[f32], v2: &[f32]) -> f32 {
}
result
}

#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
#[allow(clippy::missing_safety_doc, dead_code)]
pub unsafe fn l1_avx(v1: &[f32], v2: &[f32]) -> f32 {
let mask: __m256 = _mm256_set1_ps(-0.0f32); // 1 << 31 used to clear sign bit to mimic abs

let n = v1.len();
let m = n - (n % 32);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum256_1: __m256 = _mm256_setzero_ps();
let mut sum256_2: __m256 = _mm256_setzero_ps();
let mut sum256_3: __m256 = _mm256_setzero_ps();
let mut sum256_4: __m256 = _mm256_setzero_ps();
let mut i: usize = 0;
while i < m {
let sub256_1: __m256 = _mm256_sub_ps(_mm256_loadu_ps(ptr1), _mm256_loadu_ps(ptr2));
sum256_1 = _mm256_add_ps(_mm256_andnot_ps(mask, sub256_1), sum256_1);

let sub256_2: __m256 =
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8)));
sum256_2 = _mm256_add_ps(_mm256_andnot_ps(mask, sub256_2), sum256_2);

let sub256_3: __m256 =
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16)));
sum256_3 = _mm256_add_ps(_mm256_andnot_ps(mask, sub256_3), sum256_3);

let sub256_4: __m256 =
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24)));
sum256_4 = _mm256_add_ps(_mm256_andnot_ps(mask, sub256_4), sum256_4);

ptr1 = ptr1.add(32);
ptr2 = ptr2.add(32);
i += 32;
}

let mut result = hsum256_ps_avx(sum256_1)
+ hsum256_ps_avx(sum256_2)
+ hsum256_ps_avx(sum256_3)
+ hsum256_ps_avx(sum256_4);
for i in 0..n - m {
result += (*ptr1.add(i) - *ptr2.add(i)).abs();
}
-result
}
42 changes: 42 additions & 0 deletions demos/src/metrics/utils_sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,45 @@ pub unsafe fn dot_sse(v1: &[f32], v2: &[f32]) -> f32 {
}
result
}

#[target_feature(enable = "sse4.1")]
#[allow(clippy::missing_safety_doc, dead_code)]
pub unsafe fn l1_sse(v1: &[f32], v2: &[f32]) -> f32 {
let mask: __m128 = _mm_set1_ps(-0.0f32); // 1 << 31 used to clear sign bit to mimic abs

let n = v1.len();
let m = n - (n % 16);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum128_1: __m128 = _mm_setzero_ps();
let mut sum128_2: __m128 = _mm_setzero_ps();
let mut sum128_3: __m128 = _mm_setzero_ps();
let mut sum128_4: __m128 = _mm_setzero_ps();
let mut i: usize = 0;
while i < m {
let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2));
sum128_1 = _mm_add_ps(_mm_andnot_ps(mask, sub128_1), sum128_1);

let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4)));
sum128_2 = _mm_add_ps(_mm_andnot_ps(mask, sub128_2), sum128_2);

let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8)));
sum128_3 = _mm_add_ps(_mm_andnot_ps(mask, sub128_3), sum128_3);

let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12)));
sum128_4 = _mm_add_ps(_mm_andnot_ps(mask, sub128_4), sum128_4);

ptr1 = ptr1.add(16);
ptr2 = ptr2.add(16);
i += 16;
}

let mut result = hsum128_ps_sse(sum128_1)
+ hsum128_ps_sse(sum128_2)
+ hsum128_ps_sse(sum128_3)
+ hsum128_ps_sse(sum128_4);
for i in 0..n - m {
result += (*ptr1.add(i) - *ptr2.add(i)).abs();
}
-result
}
71 changes: 71 additions & 0 deletions quantization/cpp/avx2.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <stdlib.h>
#include <stdint.h>
#include <immintrin.h>

Expand All @@ -12,6 +13,15 @@
R = _mm_cvtss_f32(x32); \
}

#define HSUM256_EPI32(X, R) \
int R = 0; \
{ \
__m128i x128 = _mm_add_epi32(_mm256_extractf128_si256(X, 1), _mm256_castsi256_si128(X)); \
__m128i x64 = _mm_add_epi32(x128, _mm_srli_si128(x128, 8)); \
__m128i x32 = _mm_add_epi32(x64, _mm_srli_si128(x64, 4)); \
R = _mm_cvtsi128_si32(x32); \
}

EXPORT float impl_score_dot_avx(
const uint8_t* query_ptr,
const uint8_t* vector_ptr,
Expand All @@ -34,6 +44,8 @@ EXPORT float impl_score_dot_avx(
mul1 = _mm256_add_epi32(mul1, s_low);
mul1 = _mm256_add_epi32(mul1, s_high);
}

// the vector sizes are assumed to be multiples of 16, check if one last 16-element part remaining
if (dim % 32 != 0) {
__m128i v_short = _mm_loadu_si128((const __m128i*)v_ptr);
__m128i q_short = _mm_loadu_si128((const __m128i*)q_ptr);
Expand All @@ -49,3 +61,62 @@ EXPORT float impl_score_dot_avx(
HSUM256_PS(mul_ps, mul_scalar);
return mul_scalar;
}

EXPORT float impl_score_l1_avx(
const uint8_t* query_ptr,
const uint8_t* vector_ptr,
uint32_t dim
) {
const __m256i* v_ptr = (const __m256i*)vector_ptr;
const __m256i* q_ptr = (const __m256i*)query_ptr;

uint32_t m = dim - (dim % 32);
__m256i sum256 = _mm256_setzero_si256();

for (uint32_t i = 0; i < m; i += 32) {
__m256i v = _mm256_loadu_si256(v_ptr);
__m256i q = _mm256_loadu_si256(q_ptr);
v_ptr++;
q_ptr++;

// Compute the difference in both directions and take the maximum for abs
__m256i diff1 = _mm256_subs_epu8(v, q);
__m256i diff2 = _mm256_subs_epu8(q, v);

__m256i abs_diff = _mm256_max_epu8(diff1, diff2);

__m256i abs_diff16_lo = _mm256_unpacklo_epi8(abs_diff, _mm256_setzero_si256());
__m256i abs_diff16_hi = _mm256_unpackhi_epi8(abs_diff, _mm256_setzero_si256());

sum256 = _mm256_add_epi16(sum256, abs_diff16_lo);
sum256 = _mm256_add_epi16(sum256, abs_diff16_hi);
}

// the vector sizes are assumed to be multiples of 16, check if one last 16-element part remaining
if (m < dim) {
__m128i v_short = _mm_loadu_si128((const __m128i * ) v_ptr);
__m128i q_short = _mm_loadu_si128((const __m128i * ) q_ptr);

__m128i diff1 = _mm_subs_epu8(v_short, q_short);
__m128i diff2 = _mm_subs_epu8(q_short, v_short);

__m128i abs_diff = _mm_max_epu8(diff1, diff2);

__m128i abs_diff16_lo_128 = _mm_unpacklo_epi8(abs_diff, _mm_setzero_si128());
__m128i abs_diff16_hi_128 = _mm_unpackhi_epi8(abs_diff, _mm_setzero_si128());

__m256i abs_diff16_lo = _mm256_cvtepu16_epi32(abs_diff16_lo_128);
__m256i abs_diff16_hi = _mm256_cvtepu16_epi32(abs_diff16_hi_128);

sum256 = _mm256_add_epi16(sum256, abs_diff16_lo);
sum256 = _mm256_add_epi16(sum256, abs_diff16_hi);
}

__m256i sum_epi32 = _mm256_add_epi32(
_mm256_unpacklo_epi16(sum256, _mm256_setzero_si256()),
_mm256_unpackhi_epi16(sum256, _mm256_setzero_si256()));

HSUM256_EPI32(sum_epi32, sum);

return (float) sum;
}
Loading

0 comments on commit 939fdb6

Please sign in to comment.