From fab3353abf482e0c4b15b9b2fe118c22b7e70aa2 Mon Sep 17 00:00:00 2001
From: DengFeng <1440416491@qq.com>
Date: Thu, 23 May 2024 17:39:56 +0800
Subject: [PATCH] Optimize sorting using Quick-Merge sort

---
 libspu/kernel/hal/permute.cc | 228 ++++++++++++++++++++++++++++++++++-
 1 file changed, 226 insertions(+), 2 deletions(-)

diff --git a/libspu/kernel/hal/permute.cc b/libspu/kernel/hal/permute.cc
index 657b81fc..836e7f1e 100644
--- a/libspu/kernel/hal/permute.cc
+++ b/libspu/kernel/hal/permute.cc
@@ -211,6 +211,227 @@ void HandleSmallArray(SPUContext *ctx, const CompFn &comparator_body,
   }
 }
 
+// Runs one quicksort round over every pending interval at once.
+//
+// The element at `lo` is the pivot of the closed interval [lo, hi]. The
+// pivot-vs-element pairs of all intervals are joined so that
+// `comparator_body` is evaluated, and its result revealed, once per call.
+// Rows of every operand in `arr` are then swapped in place and `intervals`
+// is replaced by the sub-intervals that still hold two or more elements.
+//
+// NOTE(review): only arr[0] and arr[1] are fed to the comparator, so `arr`
+// must hold at least two operands.
+void Partition(SPUContext *ctx, const CompFn &comparator_body,
+               absl::Span<spu::Value> arr,
+               std::vector<std::pair<int64_t, int64_t>> &intervals) {
+  if (intervals.empty()) {
+    return;
+  }
+
+  // One piece per interval and column; joined below into single operands.
+  std::vector<spu::Value> pivot_keys;
+  std::vector<spu::Value> other_keys;
+  std::vector<spu::Value> pivot_ties;
+  std::vector<spu::Value> other_ties;
+  for (const auto &[lo, hi] : intervals) {
+    const int64_t count = hi - lo;  // elements compared with the pivot
+    pivot_keys.push_back(
+        broadcast_to(ctx, slice_scalar_at(ctx, arr[0], {lo}), {count}));
+    other_keys.push_back(slice(ctx, arr[0], {lo + 1}, {hi + 1}));
+    pivot_ties.push_back(
+        broadcast_to(ctx, slice_scalar_at(ctx, arr[1], {lo}), {count}));
+    other_ties.push_back(slice(ctx, arr[1], {lo + 1}, {hi + 1}));
+  }
+
+  // A single concatenate per column instead of one per interval.
+  const auto join = [ctx](const std::vector<spu::Value> &pieces) {
+    return pieces.size() == 1 ? pieces.front() : concatenate(ctx, pieces, 0);
+  };
+  const std::vector<spu::Value> operands = {
+      join(pivot_keys), join(other_keys), join(pivot_ties), join(other_ties)};
+  const auto predicate = comparator_body(operands);
+  const auto revealed = dump_public_as<bool>(ctx, hal::reveal(ctx, predicate));
+
+  Index lhs_indices;
+  Index rhs_indices;
+  Index pivot_indices;
+  Index mid_indices;
+  std::vector<std::pair<int64_t, int64_t>> next;
+  int64_t base = 0;  // start of the current interval's bits in `revealed`
+
+  for (const auto &[lo, hi] : intervals) {
+    int64_t left = lo + 1;
+    int64_t right = hi;
+    // Element i of this interval owns the bit revealed[i + shift].
+    const int64_t shift = base - (lo + 1);
+
+    // Two-pointer partition: a false bit stays left, a true bit stays right.
+    for (;;) {
+      while (right >= left && !revealed[left + shift]) {
+        ++left;
+      }
+      while (right >= left && revealed[right + shift]) {
+        --right;
+      }
+      if (right < left) {
+        break;
+      }
+      lhs_indices.emplace_back(left++);
+      rhs_indices.emplace_back(right--);
+    }
+    base += hi - lo;
+
+    // `right` ends on the last slot of the left part, where the pivot goes.
+    pivot_indices.emplace_back(lo);
+    mid_indices.emplace_back(right);
+
+    // A range of fewer than two elements needs no further round.
+    if (right - lo > 1) {
+      next.emplace_back(lo, right - 1);
+    }
+    if (hi - right > 1) {
+      next.emplace_back(right + 1, hi);
+    }
+  }
+
+  Swap(arr, lhs_indices, rhs_indices);
+  // Move every pivot into its final slot.
+  Swap(arr, pivot_indices, mid_indices);
+  intervals.swap(next);
+}
+
+// Finishes the sort by running an odd-even merge network inside every
+// interval.
+//
+// Compare-exchange pairs that belong to the same network stage are batched
+// across all intervals, so each stage costs one comparator call and one
+// reveal. A pair is swapped when the revealed predicate bit is false.
+void mergesort(SPUContext *ctx, const CompFn &comparator_body,
+               absl::Span<spu::Value> arr,
+               std::vector<std::pair<int64_t, int64_t>> &intervals) {
+  // Entry `s` holds the pairs of stage `s`. The vectors grow on demand: a
+  // depth precomputed from a truncated floating-point log2 can be too small
+  // for a non-power-of-two length and would be indexed out of bounds.
+  std::vector<Index> lhs_stages;
+  std::vector<Index> rhs_stages;
+
+  for (const auto &[lo, hi] : intervals) {
+    const int64_t n = hi - lo + 1;
+    size_t stage = 0;
+    for (int64_t max_gap_in_stage = 1; max_gap_in_stage < n;
+         max_gap_in_stage += max_gap_in_stage) {
+      const int64_t range = max_gap_in_stage + max_gap_in_stage;
+      for (int64_t step = max_gap_in_stage; step > 0; step /= 2) {
+        if (stage == lhs_stages.size()) {
+          lhs_stages.emplace_back();
+          rhs_stages.emplace_back();
+        }
+        for (int64_t j = step % max_gap_in_stage; j + step < n;
+             j += step + step) {
+          for (int64_t i = 0; i < step; i++) {
+            const int64_t lhs_idx = i + j;
+            const int64_t rhs_idx = i + j + step;
+            if (rhs_idx >= n) {
+              break;
+            }
+            // Only pair elements that fall into the same merge block.
+            if (lhs_idx / range == rhs_idx / range) {
+              lhs_stages[stage].emplace_back(lhs_idx + lo);
+              rhs_stages[stage].emplace_back(rhs_idx + lo);
+            }
+          }
+        }
+        ++stage;
+      }
+    }
+  }
+
+  for (size_t s = 0; s < lhs_stages.size(); ++s) {
+    const Index &lhs = lhs_stages[s];
+    const Index &rhs = rhs_stages[s];
+    if (lhs.empty()) {
+      continue;  // no pair in this stage, skip the comparator call
+    }
+
+    std::vector<spu::Value> values;
+    values.reserve(2 * arr.size());
+    for (const auto &operand : arr) {
+      values.emplace_back(operand.data().linear_gather(lhs), operand.dtype());
+      values.emplace_back(operand.data().linear_gather(rhs), operand.dtype());
+    }
+    const auto predicate = comparator_body(values);
+    const auto revealed =
+        dump_public_as<bool>(ctx, hal::reveal(ctx, predicate));
+
+    Index swap_lhs;
+    Index swap_rhs;
+    for (size_t k = 0; k < revealed.size(); k++) {
+      if (!revealed[k]) {
+        swap_lhs.emplace_back(lhs[k]);
+        swap_rhs.emplace_back(rhs[k]);
+      }
+    }
+    Swap(arr, swap_lhs, swap_rhs);
+  }
+}
+
+// Sorts `inputs` in two phases: at most kMaxQuickRounds batched quicksort
+// rounds (Partition), then an odd-even merge network (mergesort) over the
+// intervals that are still pending. Non-secret inputs are converted with
+// _2s first; all work happens on clones of the inputs.
+std::vector<spu::Value> QuickMergesort(SPUContext *ctx,
+                                       const CompFn &comparator_body,
+                                       absl::Span<spu::Value const> inputs) {
+  std::vector<spu::Value> ret;
+  ret.reserve(inputs.size());
+  for (auto const &input : inputs) {
+    if (!input.isSecret()) {
+      // we can not linear_scatter a secret value to a public operand
+      ret.emplace_back(_2s(ctx, input.clone()).setDtype(input.dtype()));
+    } else {
+      ret.emplace_back(input.clone());
+    }
+  }
+
+  const int64_t n = inputs.front().numel();
+  std::vector<std::pair<int64_t, int64_t>> intervals;
+  if (n > 1) {
+    // Zero or one element is already sorted; [0, -1] is not a valid range.
+    intervals.emplace_back(0, n - 1);
+  }
+
+  // Each Partition call costs one reveal, so the quicksort phase is capped.
+  constexpr int64_t kMaxQuickRounds = 10;
+  for (int64_t round = 0; round < kMaxQuickRounds && !intervals.empty();
+       ++round) {
+    Partition(ctx, comparator_body, absl::MakeSpan(ret), intervals);
+  }
+
+  if (!intervals.empty()) {
+    mergesort(ctx, comparator_body, absl::MakeSpan(ret), intervals);
+  }
+  return ret;
+}
+
+// Builds the operands handed to QuickMergesort: the first input permuted
+// with a fresh _rand_perm_s permutation, followed by a secret random F64
+// column of the same shape.
+//
+// NOTE(review): every input after the first is dropped and the result
+// always has two entries - confirm callers pass a single operand only.
+std::vector<spu::Value> PrepareSort(SPUContext *ctx,
+                                    absl::Span<spu::Value const> inputs) {
+  const auto &key = inputs.front();
+  const auto rand_perm = _rand_perm_s(ctx, key.shape());
+
+  std::vector<spu::Value> inp;
+  inp.reserve(2);
+  inp.push_back(_perm_ss(ctx, key, rand_perm).setDtype(key.dtype()));
+  inp.push_back(hal::random(ctx, Visibility::VIS_SECRET, DataType::DT_F64,
+                            key.shape()));
+  return inp;
+}
+
 void TwoWayPartition(SPUContext *ctx, const CompFn &comparator_body,
                      absl::Span<spu::Value> arr, int64_t lo, int64_t hi,
                      const TopKConfig &config,
@@ -1023,7 +1244,9 @@ std::vector<spu::Value> sort1d(SPUContext *ctx,
     SPU_ENFORCE(!is_stable,
                 "Stable sort is unsupported if comparator return is secret.");
 
-    ret = internal::odd_even_merge_sort(ctx, cmp, inputs);
+    // Replaces the former internal::odd_even_merge_sort(ctx, cmp, inputs).
+    auto inp = internal::PrepareSort(ctx, inputs);
+    ret = internal::QuickMergesort(ctx, cmp, absl::MakeSpan(inp));
   } else {
     SPU_THROW("Should not reach here");
   }
@@ -1048,7 +1271,8 @@ std::vector<spu::Value> simple_sort1d(SPUContext *ctx,
   SPU_ENFORCE(num_keys > 0 && num_keys <= static_cast<int64_t>(inputs.size()),
               "num_keys {} is not valid", num_keys);
 
-  bool fallback = false;
+  // NOTE(review): forced to true by this patch (previously false).
+  bool fallback = true;
   // if all keys are public, fallback to public sort
   if (std::all_of(inputs.begin(), inputs.begin() + num_keys,
                   [](const spu::Value &v) { return v.isPublic(); })) {