Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Support bitset filter for Brute Force #560

Open
wants to merge 13 commits into
base: branch-25.02
Choose a base branch
from
Open
4 changes: 2 additions & 2 deletions cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ endfunction()
# To use a different RAFT locally, set the CMake variable
# CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${RAFT_VERSION}.00
FORK ${RAFT_FORK}
PINNED_TAG ${RAFT_PINNED_TAG}
FORK rhdong
rhdong marked this conversation as resolved.
Show resolved Hide resolved
PINNED_TAG rhdong/bitset-to-csr-dev
ENABLE_MNMG_DEPENDENCIES OFF
ENABLE_NVTX OFF
USE_RAFT_STATIC ${CUVS_USE_RAFT_STATIC}
Expand Down
34 changes: 29 additions & 5 deletions cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <cstdint>
#include <cuvs/distance/distance.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdspan.hpp>
Expand Down Expand Up @@ -456,8 +457,11 @@ inline constexpr bool is_vpq_dataset_v = is_vpq_dataset<DatasetT>::value;

namespace filtering {

enum class FilterType { None, Bitmap, Bitset };

struct base_filter {
virtual ~base_filter() = default;
virtual ~base_filter() = default;
virtual FilterType get_filter_type() const = 0;
};

/* A filter that filters nothing. This is the default behavior. */
Expand All @@ -475,6 +479,8 @@ struct none_sample_filter : public base_filter {
const uint32_t query_ix,
// the index of the current sample
const uint32_t sample_ix) const;

FilterType get_filter_type() const override { return FilterType::None; }
};

/**
Expand Down Expand Up @@ -513,15 +519,24 @@ struct ivf_to_sample_filter {
*/
template <typename bitmap_t, typename index_t>
struct bitmap_filter : public base_filter {
using view_t = cuvs::core::bitmap_view<bitmap_t, index_t>;

// View of the bitset to use as a filter
const cuvs::core::bitmap_view<bitmap_t, index_t> bitmap_view_;
const view_t bitmap_view_;

bitmap_filter(const cuvs::core::bitmap_view<bitmap_t, index_t> bitmap_for_filtering);
bitmap_filter(const view_t bitmap_for_filtering);
inline _RAFT_HOST_DEVICE bool operator()(
// query index
const uint32_t query_ix,
// the index of the current sample
const uint32_t sample_ix) const;

FilterType get_filter_type() const override { return FilterType::Bitmap; }

view_t view() const { return bitmap_view_; }

template <typename csr_matrix_t>
void to_csr(raft::resources const& handle, csr_matrix_t& csr);
};

/**
Expand All @@ -532,15 +547,24 @@ struct bitmap_filter : public base_filter {
*/
template <typename bitset_t, typename index_t>
struct bitset_filter : public base_filter {
using view_t = cuvs::core::bitset_view<bitset_t, index_t>;

// View of the bitset to use as a filter
const cuvs::core::bitset_view<bitset_t, index_t> bitset_view_;
const view_t bitset_view_;

bitset_filter(const cuvs::core::bitset_view<bitset_t, index_t> bitset_for_filtering);
bitset_filter(const view_t bitset_for_filtering);
inline _RAFT_HOST_DEVICE bool operator()(
// query index
const uint32_t query_ix,
// the index of the current sample
const uint32_t sample_ix) const;

FilterType get_filter_type() const override { return FilterType::Bitset; }

view_t view() const { return bitset_view_; }

template <typename csr_matrix_t>
void to_csr(raft::resources const& handle, csr_matrix_t& csr);
};

/**
Expand Down
123 changes: 74 additions & 49 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,13 @@

#include <cstdint>
#include <iostream>
#include <optional>
#include <set>
#include <variant>

namespace cuvs::neighbors::detail {

using namespace cuvs::neighbors::filtering;
/**
* Calculates brute force knn, using a fixed memory budget
* by tiling over both the rows and columns of pairwise_distances
Expand All @@ -81,8 +85,9 @@ void tiled_brute_force_knn(const raft::resources& handle,
size_t max_col_tile_size = 0,
const DistanceT* precomputed_index_norms = nullptr,
const DistanceT* precomputed_search_norms = nullptr,
const uint32_t* filter_bitmap = nullptr,
DistanceEpilogue distance_epilogue = raft::identity_op())
const uint32_t* filter_bits = nullptr,
DistanceEpilogue distance_epilogue = raft::identity_op(),
FilterType filter_type = FilterType::Bitmap)
{
// Figure out the number of rows/cols to tile for
size_t tile_rows = 0;
Expand Down Expand Up @@ -244,21 +249,23 @@ void tiled_brute_force_knn(const raft::resources& handle,
}
}

if (filter_bitmap != nullptr) {
auto distances_ptr = temp_distances.data();
auto count = thrust::make_counting_iterator<IndexType>(0);
DistanceT masked_distance = select_min ? std::numeric_limits<DistanceT>::infinity()
: std::numeric_limits<DistanceT>::lowest();
auto distances_ptr = temp_distances.data();
auto count = thrust::make_counting_iterator<IndexType>(0);
DistanceT masked_distance = select_min ? std::numeric_limits<DistanceT>::infinity()
: std::numeric_limits<DistanceT>::lowest();

if (filter_bits != nullptr) {
size_t n_cols = filter_type == FilterType::Bitmap ? n : 0;
thrust::for_each(raft::resource::get_thrust_policy(handle),
count,
count + current_query_size * current_centroid_size,
[=] __device__(IndexType idx) {
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
IndexType g_idx = row * n + col;
IndexType g_idx = row * n_cols + col;
IndexType item_idx = (g_idx) >> 5;
uint32_t bit_idx = (g_idx)&31;
uint32_t filter = filter_bitmap[item_idx];
uint32_t filter = filter_bits[item_idx];
if ((filter & (uint32_t(1) << bit_idx)) == 0) {
distances_ptr[idx] = masked_distance;
}
Expand Down Expand Up @@ -574,12 +581,12 @@ void brute_force_search(
query_norms ? query_norms->data_handle() : nullptr);
}

template <typename T, typename IdxT, typename BitmapT, typename DistanceT = float>
template <typename T, typename IdxT, typename BitsT, typename DistanceT = float>
void brute_force_search_filtered(
raft::resources const& res,
const cuvs::neighbors::brute_force::index<T, DistanceT>& idx,
raft::device_matrix_view<const T, IdxT, raft::row_major> queries,
cuvs::core::bitmap_view<const BitmapT, IdxT> filter,
const base_filter* filter,
raft::device_matrix_view<IdxT, IdxT, raft::row_major> neighbors,
raft::device_matrix_view<DistanceT, IdxT, raft::row_major> distances,
std::optional<raft::device_vector_view<const DistanceT, IdxT>> query_norms = std::nullopt)
Expand All @@ -600,29 +607,40 @@ void brute_force_search_filtered(
metric == cuvs::distance::DistanceType::CosineExpanded),
"Index must has norms when using Euclidean, IP, and Cosine!");

IdxT n_queries = queries.extent(0);
IdxT n_dataset = idx.dataset().extent(0);
IdxT dim = idx.dataset().extent(1);
IdxT k = neighbors.extent(1);
IdxT n_queries = queries.extent(0);
IdxT n_dataset = idx.dataset().extent(0);
IdxT dim = idx.dataset().extent(1);
IdxT k = neighbors.extent(1);
FilterType filter_type = filter->get_filter_type();

auto stream = raft::resource::get_cuda_stream(res);

// calc nnz
IdxT nnz_h = 0;
rmm::device_scalar<IdxT> nnz(0, stream);
auto nnz_view = raft::make_device_scalar_view<IdxT>(nnz.data());
auto filter_view =
raft::make_device_vector_view<const BitmapT, IdxT>(filter.data(), filter.n_elements());
IdxT size_h = n_queries * n_dataset;
auto size_view = raft::make_host_scalar_view<const IdxT, IdxT>(&size_h);

raft::popc(res, filter_view, size_view, nnz_view);
raft::copy(&nnz_h, nnz.data(), 1, stream);
std::optional<std::variant<const cuvs::core::bitmap_view<BitsT, IdxT>,
const cuvs::core::bitset_view<BitsT, IdxT>>>
filter_view;

IdxT nnz_h = 0;
float sparsity = 0.0f;

const BitsT* filter_data = nullptr;

if (filter_type == FilterType::Bitmap) {
auto actual_filter = dynamic_cast<const bitmap_filter<BitsT, int64_t>*>(filter);
filter_view.emplace(actual_filter->view());
nnz_h = actual_filter->view().count(res);
sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset);
} else if (filter_type == FilterType::Bitset) {
auto actual_filter = dynamic_cast<const bitset_filter<BitsT, int64_t>*>(filter);
filter_view.emplace(actual_filter->view());
nnz_h = n_queries * actual_filter->view().count(res);
sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset);
} else {
RAFT_FAIL("Unsupported sample filter type");
}

raft::resource::sync_stream(res, stream);
float sparsity = (1.0f * nnz_h / (1.0f * n_queries * n_dataset));
std::visit([&](const auto& actual_view) { filter_data = actual_view.data(); }, *filter_view);

if (sparsity > 0.01f) {
if (sparsity < 0.9f) {
raft::resources stream_pool_handle(res);
raft::resource::set_cuda_stream(stream_pool_handle, stream);
auto idx_norm = idx.has_norms() ? const_cast<DistanceT*>(idx.norms().data_handle()) : nullptr;
Expand All @@ -642,12 +660,12 @@ void brute_force_search_filtered(
0,
idx_norm,
nullptr,
filter.data());
filter_data,
raft::identity_op(),
filter_type);
} else {
auto csr = raft::make_device_csr_matrix<DistanceT, IdxT>(res, n_queries, n_dataset, nnz_h);

// fill csr
raft::sparse::convert::bitmap_to_csr(res, filter, csr);
std::visit([&](const auto& actual_view) { actual_view.to_csr(res, csr); }, *filter_view);

// create filter csr view
auto compressed_csr_view = csr.structure_view();
Expand All @@ -663,7 +681,11 @@ void brute_force_search_filtered(
auto csr_view = raft::make_device_csr_matrix_view<DistanceT, IdxT, IdxT, IdxT>(
csr.get_elements().data(), compressed_csr_view);

raft::sparse::linalg::masked_matmul(res, queries, dataset_view, filter, csr_view);
std::visit(
[&](const auto& actual_view) {
raft::sparse::linalg::masked_matmul(res, queries, dataset_view, actual_view, csr_view);
},
*filter_view);

// post process
std::optional<raft::device_vector<DistanceT, IdxT>> query_norms_;
Expand Down Expand Up @@ -724,29 +746,32 @@ void search(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, LayoutT> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<DistT, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter_ref)
const base_filter& sample_filter_ref)
{
try {
auto& sample_filter =
dynamic_cast<const cuvs::neighbors::filtering::none_sample_filter&>(sample_filter_ref);
auto& sample_filter = dynamic_cast<const none_sample_filter&>(sample_filter_ref);
return brute_force_search<T, int64_t, DistT>(res, idx, queries, neighbors, distances);
} catch (const std::bad_cast&) {
}
if constexpr (std::is_same_v<LayoutT, raft::col_major>) {
RAFT_FAIL("filtered search isn't available with col_major queries yet");
} else {
try {
auto& sample_filter =
dynamic_cast<const bitmap_filter<uint32_t, int64_t>&>(sample_filter_ref);
return brute_force_search_filtered<T, int64_t, uint32_t, DistT>(
res, idx, queries, &sample_filter, neighbors, distances);
} catch (const std::bad_cast&) {
}

try {
auto& sample_filter =
dynamic_cast<const cuvs::neighbors::filtering::bitmap_filter<const uint32_t, int64_t>&>(
sample_filter_ref);
if constexpr (std::is_same_v<LayoutT, raft::col_major>) {
RAFT_FAIL("filtered search isn't available with col_major queries yet");
} else {
cuvs::core::bitmap_view<const uint32_t, int64_t> sample_filter_view =
sample_filter.bitmap_view_;
try {
auto& sample_filter =
dynamic_cast<const bitset_filter<uint32_t, int64_t>&>(sample_filter_ref);
return brute_force_search_filtered<T, int64_t, uint32_t, DistT>(
res, idx, queries, sample_filter_view, neighbors, distances);
res, idx, queries, &sample_filter, neighbors, distances);
} catch (const std::bad_cast&) {
RAFT_FAIL("Unsupported sample filter type");
}
} catch (const std::bad_cast&) {
RAFT_FAIL("Unsupported sample filter type");
}
}

Expand Down
16 changes: 16 additions & 0 deletions cpp/src/neighbors/sample_filter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <raft/core/bitmap.cuh>
#include <raft/core/bitset.cuh>
#include <raft/core/detail/macros.hpp>
#include <raft/sparse/convert/csr.cuh>

#include <cstddef>
#include <cstdint>
Expand Down Expand Up @@ -108,6 +109,13 @@ inline _RAFT_HOST_DEVICE bool bitset_filter<bitset_t, index_t>::operator()(
return bitset_view_.test(sample_ix);
}

template <typename bitset_t, typename index_t>
template <typename csr_matrix_t>
void bitset_filter<bitset_t, index_t>::to_csr(raft::resources const& handle, csr_matrix_t& csr)
{
raft::sparse::convert::bitset_to_csr(handle, bitset_view_, csr);
}

template <typename bitmap_t, typename index_t>
bitmap_filter<bitmap_t, index_t>::bitmap_filter(
const cuvs::core::bitmap_view<bitmap_t, index_t> bitmap_for_filtering)
Expand All @@ -124,4 +132,12 @@ inline _RAFT_HOST_DEVICE bool bitmap_filter<bitmap_t, index_t>::operator()(
{
return bitmap_view_.test(query_ix, sample_ix);
}

template <typename bitmap_t, typename index_t>
template <typename csr_matrix_t>
void bitmap_filter<bitmap_t, index_t>::to_csr(raft::resources const& handle, csr_matrix_t& csr)
{
raft::sparse::convert::bitmap_to_csr(handle, bitmap_view_, csr);
}

} // namespace cuvs::neighbors::filtering
Loading
Loading