Skip to content

Commit

Permalink
AsRef slice for encoding input
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanPleshkov committed Apr 17, 2024
1 parent 939fdb6 commit a74376d
Show file tree
Hide file tree
Showing 17 changed files with 86 additions and 62 deletions.
2 changes: 1 addition & 1 deletion demos/benches/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn binary_bench(c: &mut Criterion) {
}

let encoded = EncodedVectorsBin::encode(
vectors.iter().map(|v| v.as_slice()),
vectors.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down
2 changes: 1 addition & 1 deletion demos/src/ann_benchmark_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ impl AnnBenchmarkData {
}
}

fn print_timings(timings: &mut Vec<f64>) {
fn print_timings(timings: &mut [f64]) {
if timings.is_empty() {
return;
}
Expand Down
2 changes: 1 addition & 1 deletion demos/src/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn main() {
let query: Vec<f32> = (0..vector_dim).map(|_| rng.gen()).collect();

let encoded = EncodedVectorsU8::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down
3 changes: 2 additions & 1 deletion quantization/src/encoded_vectors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ impl DistanceType {
}

pub(crate) fn validate_vector_parameters<'a>(
data: impl Iterator<Item = &'a [f32]>,
data: impl Iterator<Item = impl AsRef<[f32]> + 'a> + Clone,
vector_parameters: &VectorParameters,
) -> Result<(), EncodingError> {
let mut count = 0;
for vector in data {
let vector = vector.as_ref();
if vector.len() != vector_parameters.dim {
return Err(EncodingError::ArgumentsError(format!(
"Vector length {} does not match vector parameters dim {}",
Expand Down
4 changes: 2 additions & 2 deletions quantization/src/encoded_vectors_binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct Metadata {

impl<TStorage: EncodedStorage> EncodedVectorsBin<TStorage> {
pub fn encode<'a>(
orig_data: impl Iterator<Item = &'a [f32]> + Clone,
orig_data: impl Iterator<Item = impl AsRef<[f32]> + 'a> + Clone,
mut storage_builder: impl EncodedStorageBuilder<TStorage>,
vector_parameters: &VectorParameters,
stop_condition: impl Fn() -> bool,
Expand All @@ -39,7 +39,7 @@ impl<TStorage: EncodedStorage> EncodedVectorsBin<TStorage> {
return Err(EncodingError::Stopped);
}

let encoded_vector = Self::_encode_vector(vector);
let encoded_vector = Self::_encode_vector(vector.as_ref());
let encoded_vector_slice = encoded_vector.encoded_vector.as_slice();
let bytes = transmute_to_u8_slice(encoded_vector_slice);
storage_builder.push_vector_data(bytes);
Expand Down
18 changes: 12 additions & 6 deletions quantization/src/encoded_vectors_pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl<TStorage: EncodedStorage> EncodedVectorsPQ<TStorage> {
/// * `max_threads` - Max allowed threads for kmeans and encoding process
/// * `stop_condition` - Function that returns `true` if encoding should be stopped
pub fn encode<'a>(
data: impl Iterator<Item = &'a [f32]> + Clone + Send,
data: impl Iterator<Item = impl AsRef<[f32]> + 'a> + Clone + Send,
mut storage_builder: impl EncodedStorageBuilder<TStorage> + Send,
vector_parameters: &VectorParameters,
chunk_size: usize,
Expand Down Expand Up @@ -134,7 +134,7 @@ impl<TStorage: EncodedStorage> EncodedVectorsPQ<TStorage> {
/// 'a is lifetime of vector in vector storage
/// 'b is lifetime of parent scope
fn encode_storage<'a: 'b, 'b>(
data: impl Iterator<Item = &'a [f32]> + Clone + Send + 'b,
data: impl Iterator<Item = impl AsRef<[f32]> + 'a> + Clone + Send + 'b,
storage_builder: &'b mut (impl EncodedStorageBuilder<TStorage> + Send),
vector_division: &'b [Range<usize>],
centroids: &'b [Vec<f32>],
Expand Down Expand Up @@ -167,7 +167,7 @@ impl<TStorage: EncodedStorage> EncodedVectorsPQ<TStorage> {
/// This function should be called inside `rayon::scope`
fn encode_storage_rayon<'a: 'b, 'b>(
scope: &rayon::Scope<'b>,
data: impl Iterator<Item = &'a [f32]> + Clone + Send + 'b,
data: impl Iterator<Item = impl AsRef<[f32]> + 'a> + Clone + Send + 'b,
storage_builder: &'b mut (impl EncodedStorageBuilder<TStorage> + Send),
vector_division: &'b [Range<usize>],
centroids: &'b [Vec<f32>],
Expand Down Expand Up @@ -199,7 +199,12 @@ impl<TStorage: EncodedStorage> EncodedVectorsPQ<TStorage> {
return;
}

Self::encode_vector(vector, vector_division, centroids, &mut encoded_vector);
Self::encode_vector(
vector.as_ref(),
vector_division,
centroids,
&mut encoded_vector,
);
// wait for permission from prev thread to use storage
let is_disconnected = condvar.wait();
// push encoded vector to storage
Expand Down Expand Up @@ -271,7 +276,7 @@ impl<TStorage: EncodedStorage> EncodedVectorsPQ<TStorage> {
/// * `max_kmeans_threads` - Max allowed threads for kmeans process
/// * `stop_condition` - Function that returns `true` if encoding should be stopped
fn find_centroids<'a>(
data: impl Iterator<Item = &'a [f32]> + Clone,
data: impl Iterator<Item = impl AsRef<[f32]> + 'a> + Clone,
vector_division: &[Range<usize>],
vector_parameters: &VectorParameters,
centroids_count: usize,
Expand All @@ -284,7 +289,7 @@ impl<TStorage: EncodedStorage> EncodedVectorsPQ<TStorage> {
// if there are not enough vectors, set centroids as point positions
if vector_parameters.count <= centroids_count {
for (i, vector_data) in data.into_iter().enumerate() {
result[i] = vector_data.to_vec();
result[i] = vector_data.as_ref().to_vec();
}
// fill empty centroids just with zeros
result[vector_parameters.count..centroids_count].fill(vec![0.0; vector_parameters.dim]);
Expand All @@ -307,6 +312,7 @@ impl<TStorage: EncodedStorage> EncodedVectorsPQ<TStorage> {
let mut data_subset = Vec::with_capacity(sample_size * range.len());
let mut selected_index: usize = 0;
for (vector_index, vector_data) in data.clone().enumerate() {
let vector_data = vector_data.as_ref();
if vector_index == selected_vectors[selected_index] {
data_subset.extend_from_slice(&vector_data[range.clone()]);
selected_index += 1;
Expand Down
8 changes: 5 additions & 3 deletions quantization/src/encoded_vectors_u8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct Metadata {

impl<TStorage: EncodedStorage> EncodedVectorsU8<TStorage> {
pub fn encode<'a>(
orig_data: impl Iterator<Item = &'a [f32]> + Clone,
orig_data: impl Iterator<Item = impl AsRef<[f32]> + 'a> + Clone,
mut storage_builder: impl EncodedStorageBuilder<TStorage>,
vector_parameters: &VectorParameters,
quantile: Option<f32>,
Expand Down Expand Up @@ -77,7 +77,7 @@ impl<TStorage: EncodedStorage> EncodedVectorsU8<TStorage> {

let mut encoded_vector = Vec::with_capacity(actual_dim + std::mem::size_of::<f32>());
encoded_vector.extend_from_slice(&f32::default().to_ne_bytes());
for &value in vector {
for &value in vector.as_ref() {
let encoded = Self::f32_to_u8(value, alpha, offset);
encoded_vector.push(encoded);
}
Expand Down Expand Up @@ -218,7 +218,9 @@ impl<TStorage: EncodedStorage> EncodedVectorsU8<TStorage> {
}
}

fn find_alpha_offset_size_dim<'a>(orig_data: impl Iterator<Item = &'a [f32]>) -> (f32, f32) {
fn find_alpha_offset_size_dim<'a>(
orig_data: impl Iterator<Item = impl AsRef<[f32]> + 'a> + Clone,
) -> (f32, f32) {
let (min, max) = find_min_max_from_iter(orig_data);
Self::alpha_offset_from_min_max(min, max)
}
Expand Down
10 changes: 6 additions & 4 deletions quantization/src/quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ use permutation_iterator::Permutor;

pub const QUANTILE_SAMPLE_SIZE: usize = 100_000;

pub(crate) fn find_min_max_from_iter<'a>(iter: impl Iterator<Item = &'a [f32]>) -> (f32, f32) {
pub(crate) fn find_min_max_from_iter<'a>(
iter: impl Iterator<Item = impl AsRef<[f32]> + 'a> + Clone,
) -> (f32, f32) {
iter.fold((f32::MAX, f32::MIN), |(mut min, mut max), vector| {
for &value in vector {
for &value in vector.as_ref() {
if value < min {
min = value;
}
Expand All @@ -17,7 +19,7 @@ pub(crate) fn find_min_max_from_iter<'a>(iter: impl Iterator<Item = &'a [f32]>)
}

pub(crate) fn find_quantile_interval<'a>(
vector_data: impl Iterator<Item = &'a [f32]>,
vector_data: impl Iterator<Item = impl AsRef<[f32]> + 'a> + Clone,
dim: usize,
count: usize,
quantile: f32,
Expand All @@ -35,7 +37,7 @@ pub(crate) fn find_quantile_interval<'a>(
let mut selected_index: usize = 0;
for (vector_index, vector_data) in vector_data.into_iter().enumerate() {
if vector_index == selected_vectors[selected_index] {
data_slice.extend_from_slice(vector_data);
data_slice.extend_from_slice(vector_data.as_ref());
selected_index += 1;
if selected_index == slice_size {
break;
Expand Down
13 changes: 13 additions & 0 deletions quantization/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ pub fn transmute_to_u8_slice<T>(v: &[T]) -> &[u8] {

pub fn transmute_from_u8_to_slice<T>(data: &[u8]) -> &[T] {
debug_assert_eq!(data.len() % size_of::<T>(), 0);

assert_eq!(
data.as_ptr().align_offset(mem::align_of::<T>()),
0,
"transmuting byte slice 0x{:p} into slice of {}: \
required alignment is {} bytes, \
byte slice misaligned by {} bytes",
data.as_ptr(),
std::any::type_name::<T>(),
mem::align_of::<T>(),
data.as_ptr().align_offset(mem::align_of::<T>()),
);

let len = data.len() / size_of::<T>();
let ptr = data.as_ptr() as *const T;
unsafe { std::slice::from_raw_parts(ptr, len) }
Expand Down
4 changes: 2 additions & 2 deletions quantization/tests/empty_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ mod tests {
let vector_data: Vec<Vec<f32>> = Default::default();

let encoded = EncodedVectorsU8::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&vector_parameters,
None,
Expand Down Expand Up @@ -62,7 +62,7 @@ mod tests {
let vector_data: Vec<Vec<f32>> = Default::default();

let encoded = EncodedVectorsPQ::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&vector_parameters,
2,
Expand Down
4 changes: 2 additions & 2 deletions quantization/tests/stop_condition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ mod tests {

assert!(
EncodedVectorsU8::encode(
(0..vector_parameters.count).map(|_| zero_vector.as_slice()),
(0..vector_parameters.count).map(|_| &zero_vector),
Vec::<u8>::new(),
&vector_parameters,
None,
Expand Down Expand Up @@ -73,7 +73,7 @@ mod tests {

assert!(
EncodedVectorsPQ::encode(
(0..vector_parameters.count).map(|_| zero_vector.as_slice()),
(0..vector_parameters.count).map(|_| &zero_vector),
Vec::<u8>::new(),
&vector_parameters,
2,
Expand Down
6 changes: 3 additions & 3 deletions quantization/tests/test_avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ mod tests {
let query: Vec<f32> = (0..vector_dim).map(|_| rng.gen()).collect();

let encoded = EncodedVectorsU8::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -64,7 +64,7 @@ mod tests {
let query: Vec<f32> = (0..vector_dim).map(|_| rng.gen()).collect();

let encoded = EncodedVectorsU8::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -101,7 +101,7 @@ mod tests {
let query: Vec<f32> = (0..vector_dim).map(|_| rng.gen_range(-1.0..=1.0)).collect();

let encoded = EncodedVectorsU8::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down
24 changes: 12 additions & 12 deletions quantization/tests/test_binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ mod tests {
}

let encoded = EncodedVectorsBin::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -74,7 +74,7 @@ mod tests {
}

let encoded = EncodedVectorsBin::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -110,7 +110,7 @@ mod tests {
}

let encoded = EncodedVectorsBin::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -143,7 +143,7 @@ mod tests {
}

let encoded = EncodedVectorsBin::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -175,7 +175,7 @@ mod tests {
}

let encoded = EncodedVectorsBin::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -226,7 +226,7 @@ mod tests {
}

let encoded = EncodedVectorsBin::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -277,7 +277,7 @@ mod tests {
}

let encoded = EncodedVectorsBin::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -325,7 +325,7 @@ mod tests {
}

let encoded = EncodedVectorsBin::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -373,7 +373,7 @@ mod tests {
}

let encoded = EncodedVectorsBin::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -424,7 +424,7 @@ mod tests {
}

let encoded = EncodedVectorsBin::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -475,7 +475,7 @@ mod tests {
}

let encoded = EncodedVectorsBin::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down Expand Up @@ -523,7 +523,7 @@ mod tests {
}

let encoded = EncodedVectorsBin::encode(
vector_data.iter().map(|v| v.as_slice()),
vector_data.iter(),
Vec::<u8>::new(),
&VectorParameters {
dim: vector_dim,
Expand Down
Loading

0 comments on commit a74376d

Please sign in to comment.