Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Optimize sorting using Quick-Merge sort
Browse files Browse the repository at this point in the history
loveDengFeng committed May 23, 2024
1 parent ac80021 commit fab3353
Showing 1 changed file with 212 additions and 2 deletions.
214 changes: 212 additions & 2 deletions libspu/kernel/hal/permute.cc
Original file line number Diff line number Diff line change
@@ -211,6 +211,213 @@ void HandleSmallArray(SPUContext *ctx, const CompFn &comparator_body,
}
}

// Runs one round of quicksort partitioning over all pending intervals at once.
//
// Every entry of `intervals` is a closed index range [lo, hi] into the 1-D
// operands in `arr`. The element at `lo` acts as the pivot of its range. The
// pivot/other pairs of all ranges are concatenated so that `comparator_body`
// is invoked exactly once per round, and its result is revealed (made public)
// to drive the swaps.
//
// The comparator receives, for every operand k of `arr` in order:
//   values[2k]     = pivot of operand k, broadcast over its range
//   values[2k + 1] = the remaining elements (lo + 1 .. hi) of operand k
//
// On return `arr` has been partitioned in place and `intervals` holds the
// sub-ranges on either side of each pivot that still need processing.
// Ranges with fewer than two elements are already in place and are dropped
// without being compared, which also avoids zero-length slices/broadcasts.
void Partition(SPUContext *ctx, const CompFn &comparator_body,
               absl::Span<spu::Value> arr,
               std::vector<std::pair<int64_t, int64_t>> &intervals) {
  // Keep only the ranges that actually contain something to compare.
  std::vector<std::pair<int64_t, int64_t>> active;
  active.reserve(intervals.size());
  for (const auto &interval : intervals) {
    if (interval.first < interval.second) {
      active.push_back(interval);
    }
  }
  intervals.clear();
  if (active.empty() || arr.empty()) return;

  // Build the comparator operands. Each operand is assembled from per-range
  // pieces and concatenated once, instead of re-concatenating a growing
  // value on every range.
  const size_t num_operands = arr.size();
  std::vector<Value> values;
  values.reserve(2 * num_operands);
  for (size_t k = 0; k < num_operands; ++k) {
    std::vector<Value> pivot_parts;
    std::vector<Value> other_parts;
    pivot_parts.reserve(active.size());
    other_parts.reserve(active.size());

    for (const auto &interval : active) {
      const int64_t lo = interval.first;
      const int64_t hi = interval.second;
      pivot_parts.push_back(
          broadcast_to(ctx, slice_scalar_at(ctx, arr[k], {lo}), {hi - lo}));
      other_parts.push_back(slice(ctx, arr[k], {lo + 1}, {hi + 1}));
    }

    // A single piece needs no concatenation.
    values.push_back(pivot_parts.size() == 1
                         ? pivot_parts.front()
                         : concatenate(ctx, pivot_parts, 0));
    values.push_back(other_parts.size() == 1
                         ? other_parts.front()
                         : concatenate(ctx, other_parts, 0));
  }

  auto predicate = comparator_body(values);
  auto revealed = dump_public_as<bool>(ctx, hal::reveal(ctx, predicate));

  Index lhs_indices;
  Index rhs_indices;
  Index pivot_indices;
  Index mid_indices;
  // (lo, final pivot position, hi) for every processed range.
  std::vector<std::tuple<int64_t, int64_t, int64_t>> splits;
  splits.reserve(active.size());

  // Number of flags consumed by the ranges handled so far.
  int64_t consumed = 0;
  for (const auto &interval : active) {
    const int64_t lo = interval.first;
    const int64_t hi = interval.second;

    int64_t left = lo + 1;
    int64_t right = hi;
    // The flag for array position p of this range is revealed[p + shift].
    const int64_t shift = consumed - (lo + 1);

    // Two-pointer partition: a set flag sends the element to the right of
    // the pivot, a cleared flag keeps it on the left.
    for (;;) {
      while (right >= left && !revealed[left + shift]) {
        ++left;
      }
      while (right >= left && revealed[right + shift]) {
        --right;
      }
      if (right < left) {
        break;
      }

      lhs_indices.emplace_back(left);
      rhs_indices.emplace_back(right);

      ++left;
      --right;
    }
    consumed += hi - lo;

    // `right` is now the last position of the left-hand side; the pivot
    // moves there once the pairwise swaps have been applied.
    pivot_indices.emplace_back(lo);
    mid_indices.emplace_back(right);
    splits.emplace_back(lo, right, hi);
  }

  Swap(arr, lhs_indices, rhs_indices);
  // Move each pivot to its final position.
  Swap(arr, pivot_indices, mid_indices);

  // Emit the sub-ranges on both sides of every pivot for the next round.
  while (!splits.empty()) {
    int64_t lo = 0;
    int64_t mid = 0;
    int64_t hi = 0;
    std::tie(lo, mid, hi) = splits.back();
    splits.pop_back();
    if (lo < mid) {
      intervals.emplace_back(lo, mid - 1);
    }
    if (mid < hi) {
      intervals.emplace_back(mid + 1, hi);
    }
  }
}

// Sorts every closed index range [lo, hi] listed in `intervals` in place,
// using an odd-even merge sorting network.
//
// The compare-exchange pairs of the network are grouped into layers. Layer i
// of every range is merged into one batch, so `comparator_body` is invoked
// once per layer for all ranges together. For each operand of `arr` the
// comparator receives the gathered left-hand elements followed by the
// gathered right-hand elements. Its result is revealed (made public) and
// every pair whose flag is false is swapped.
//
// The layer containers grow on demand, so a range of any length is handled
// without a precomputed layer bound. `intervals` is only read.
void mergesort(SPUContext *ctx, const CompFn &comparator_body,
               absl::Span<spu::Value> arr,
               std::vector<std::pair<int64_t, int64_t>> &intervals) {
  std::vector<Index> lhs_indices;
  std::vector<Index> rhs_indices;

  for (const auto &interval : intervals) {
    const int64_t lo = interval.first;
    const int64_t n = interval.second - lo + 1;

    size_t layer = 0;
    for (int64_t max_gap_in_stage = 1; max_gap_in_stage < n;
         max_gap_in_stage += max_gap_in_stage) {
      // Pairs are only compared inside blocks of this size.
      const int64_t range = max_gap_in_stage + max_gap_in_stage;

      for (int64_t step = max_gap_in_stage; step > 0; step /= 2) {
        // First range to reach this layer creates it.
        if (layer == lhs_indices.size()) {
          lhs_indices.emplace_back();
          rhs_indices.emplace_back();
        }

        for (int64_t j = step % max_gap_in_stage; j + step < n;
             j += step + step) {
          for (int64_t i = 0; i < step; i++) {
            const int64_t lhs_idx = i + j;
            const int64_t rhs_idx = i + j + step;

            if (rhs_idx >= n) break;

            if (lhs_idx / range == rhs_idx / range) {
              // Indices are local to the range; shift them into `arr`.
              lhs_indices[layer].emplace_back(lhs_idx + lo);
              rhs_indices[layer].emplace_back(rhs_idx + lo);
            }
          }
        }
        ++layer;
      }
    }
  }

  const size_t num_operands = arr.size();
  for (size_t layer = 0; layer < lhs_indices.size(); ++layer) {
    const Index &lhs = lhs_indices[layer];
    const Index &rhs = rhs_indices[layer];
    // Nothing to compare in this layer: skip the comparator round entirely.
    if (lhs.empty()) continue;

    std::vector<spu::Value> values;
    values.reserve(2 * num_operands);
    for (size_t j = 0; j < num_operands; ++j) {
      values.emplace_back(arr[j].data().linear_gather(lhs), arr[j].dtype());
      values.emplace_back(arr[j].data().linear_gather(rhs), arr[j].dtype());
    }

    auto predicate = comparator_body(values);
    auto _predicate = dump_public_as<bool>(ctx, hal::reveal(ctx, predicate));

    // Collect the pairs the comparator rejected and exchange them.
    Index swap_lhs;
    Index swap_rhs;
    for (size_t k = 0; k < _predicate.size(); k++) {
      if (!_predicate[k]) {
        swap_lhs.emplace_back(lhs[k]);
        swap_rhs.emplace_back(rhs[k]);
      }
    }
    Swap(arr, swap_lhs, swap_rhs);
  }
}

// Sorts the 1-D operands in `inputs` with a hybrid strategy: a bounded
// number of batched quicksort partition rounds, followed by a merge-network
// pass over whatever ranges are still unsorted.
//
// Returns copies of the operands (non-secret operands are converted to
// secret first) rearranged according to `comparator_body`. Empty input and
// operands with fewer than two elements are returned as-is, since there is
// nothing to reorder.
std::vector<spu::Value> QuickMergesort(SPUContext *ctx,
                                       const CompFn &comparator_body,
                                       absl::Span<spu::Value const> inputs) {
  std::vector<spu::Value> ret;
  ret.reserve(inputs.size());
  for (auto const &input : inputs) {
    if (!input.isSecret()) {
      // we can not linear_scatter a secret value to a public operand
      ret.emplace_back(_2s(ctx, input.clone()).setDtype(input.dtype()));
    } else {
      ret.emplace_back(input.clone());
    }
  }

  if (inputs.empty()) return ret;

  const auto n = inputs.front().numel();
  // With n == 0 the initial range would be (0, -1); with n == 1 there is
  // nothing to compare. Either way the copies are already the answer.
  if (n <= 1) return ret;

  // Upper bound on partition rounds before handing over to the merge pass.
  constexpr int64_t kMaxQuicksortRounds = 10;

  std::vector<std::pair<int64_t, int64_t>> intervals;
  intervals.emplace_back(0, n - 1);

  for (int64_t round = 0; round < kMaxQuicksortRounds && !intervals.empty();
       ++round) {
    Partition(ctx, comparator_body, absl::MakeSpan(ret), intervals);
  }

  // Ranges left over after the round budget are finished by the merge pass.
  if (!intervals.empty()) {
    mergesort(ctx, comparator_body, absl::MakeSpan(ret), intervals);
  }

  return ret;
}

// Builds the two operands handed to the sort: the first input permuted by a
// freshly drawn secret random permutation, and a secret random DT_F64 value
// of the same shape.
//
// NOTE(review): only inputs.front() is read here; any additional operands
// are not part of the returned vector - confirm callers expect that.
// NOTE(review): the random operand presumably serves as a secondary key for
// the comparator - confirm against the caller.
std::vector<spu::Value> PrepareSort(SPUContext *ctx, absl::Span<spu::Value const> inputs) {
  const auto &key = inputs.front();
  const auto &shape = key.shape();

  auto permutation = _rand_perm_s(ctx, shape);

  std::vector<spu::Value> prepared;
  // Shuffled copy of the key, keeping the key's dtype.
  prepared.push_back(_perm_ss(ctx, key, permutation).setDtype(key.dtype()));
  // Secret random companion operand with the key's shape.
  prepared.push_back(
      hal::random(ctx, Visibility::VIS_SECRET, DataType::DT_F64, shape));
  return prepared;
}

void TwoWayPartition(SPUContext *ctx, const CompFn &comparator_body,
absl::Span<spu::Value> arr, int64_t lo, int64_t hi,
const TopKConfig &config,
@@ -1023,7 +1230,9 @@ std::vector<spu::Value> sort1d(SPUContext *ctx,
SPU_ENFORCE(!is_stable,
"Stable sort is unsupported if comparator return is secret.");

ret = internal::odd_even_merge_sort(ctx, cmp, inputs);
// ret = internal::odd_even_merge_sort(ctx, cmp, inputs);
auto inp = internal::PrepareSort(ctx, inputs);
ret = internal::QuickMergesort(ctx, cmp, absl::MakeSpan(inp));
} else {
SPU_THROW("Should not reach here");
}
@@ -1048,7 +1257,8 @@ std::vector<spu::Value> simple_sort1d(SPUContext *ctx,
SPU_ENFORCE(num_keys > 0 && num_keys <= static_cast<int64_t>(inputs.size()),
"num_keys {} is not valid", num_keys);

bool fallback = false;
// bool fallback = false;
bool fallback = true;
// if all keys are public, fallback to public sort
if (std::all_of(inputs.begin(), inputs.begin() + num_keys,
[](const spu::Value &v) { return v.isPublic(); })) {

0 comments on commit fab3353

Please sign in to comment.