Skip to content

Commit

Permalink
Faster Baby Bear multiplication on ARM/NEON (#280)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbgl authored Mar 24, 2024
1 parent 2edbd19 commit ad1abf0
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 93 deletions.
132 changes: 56 additions & 76 deletions baby-bear/src/aarch64_neon.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use core::arch::aarch64::{self, uint32x4_t};
use core::arch::aarch64::{self, int32x4_t, uint32x4_t};
use core::arch::asm;
use core::hint::unreachable_unchecked;
use core::iter::{Product, Sum};
Expand All @@ -13,8 +13,8 @@ use crate::BabyBear;

const WIDTH: usize = 4;
const P: uint32x4_t = unsafe { transmute::<[u32; WIDTH], _>([0x78000001; WIDTH]) };
const MU: uint32x4_t = unsafe { transmute::<[u32; WIDTH], _>([0x08000001; WIDTH]) };
const TOP_BIT: uint32x4_t = unsafe { transmute::<[u32; WIDTH], _>([0x80000000; WIDTH]) };
// This MU is the same 0x88000001 as elsewhere, just interpreted as an `i32`.
const MU: int32x4_t = unsafe { transmute::<[i32; WIDTH], _>([-0x77ffffff; WIDTH]) };

/// Vectorized NEON implementation of `BabyBear` arithmetic.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -170,96 +170,76 @@ fn add(lhs: uint32x4_t, rhs: uint32x4_t) -> uint32x4_t {
}
}

/// Lane-wise high half of a 31-bit by 31-bit multiplication: each pair of lanes is multiplied into
/// a 62-bit intermediate product, and the high 31 bits of that product are returned. The output is
/// arbitrary for any lane whose inputs do not fit in 31 bits.
#[inline]
#[must_use]
fn mul_31x31_to_hi_31(lhs: uint32x4_t, rhs: uint32x4_t) -> uint32x4_t {
    // Thin wrapper over `aarch64::vqdmulhq_s32` that keeps the signed/unsigned reinterpretation in
    // one place, so callers can stay in `uint32x4_t`.
    unsafe {
        // Safety: If this code got compiled then NEON intrinsics are available.
        let lhs_signed = aarch64::vreinterpretq_s32_u32(lhs);
        let rhs_signed = aarch64::vreinterpretq_s32_u32(rhs);
        let hi_signed = aarch64::vqdmulhq_s32(lhs_signed, rhs_signed);
        aarch64::vreinterpretq_u32_s32(hi_signed)
    }
}

// MONTGOMERY MULTIPLICATION
// This implementation is based on [1] but with minor changes. The reduction is as follows:
// This implementation is based on [1] but with changes. The reduction is as follows:
//
// Constants: P = 2^31 - 2^27 + 1
// B = 2^31
// mu = P^-1 mod B
// Input: 0 <= C < P B
// Output: 0 <= R < P such that R = C B^-1 (mod P)
// 1. Q := mu C mod B
// 2. T := (C - Q P) / B
// 3. R := if T < 0 then T + P else T
// B = 2^32
// μ = P^-1 mod B
// Input: -P^2 <= C <= P^2
// Output: -P < D < P such that D = C B^-1 (mod P)
// Define:
// smod_B(a) = r, where -B/2 <= r <= B/2 - 1 and r = a (mod B).
// Algorithm:
// 1. Q := smod_B(μ C)
// 2. D := (C - Q P) / B
//
// We first show that the division in step 2. is exact. It suffices to show that C = Q P (mod B). By
// definition of Q and mu, we have Q P = mu C P = P^-1 C P = C (mod B). We also have
// C - Q P = C (mod P), so thus T = C B^-1 (mod P).
// definition of Q, smod_B, and μ, we have Q P = smod_B(μ C) P = μ C P = P^-1 C P = C (mod B).
//
// It remains to show that R is in the correct range. It suffices to show that -P <= T < P. We know
// that 0 <= C < P B and 0 <= Q P < P B. Then -P B < C - QP < P B and -P < T < P, as desired.
// We also have C - Q P = C (mod P), and therefore D = C B^-1 (mod P).
//
// In practice, we take advantage of the fact that C = Q P (mod B) to avoid a long multiplication
// when computing Q P: we only need the top half of the product. A more practical implementation is
// as follows:
// 1. Q := mu C mod B
// 2. T := C // B - Q P // B
// 3. R := if T < 0 then T + P else T
// "//" denotes truncated division.
// It remains to show that D is in the correct range. It suffices to show that -P B < C - Q P < P B.
// We know that -P^2 <= C <= P^2 and (-B / 2) P <= Q P <= (B/2 - 1) P. Then
// (1 - B/2) P - P^2 <= C - Q P <= (B/2) P + P^2. Now, P < B/2, so B/2 + P < B and
// (B/2) P + P^2 < P B; also B/2 - 1 + P < B, so -P B < (1 - B/2) P - P^2.
// Hence, -P B < C - Q P < P B as desired.
//
// [1] Modern Computer Arithmetic, Richard Brent and Paul Zimmermann, Cambridge University Press,
// 2010, algorithm 2.7.

/// Compute the high 31 bits of the long product. This is `C // B` in the description above.
#[inline]
#[must_use]
fn monty_mul_hi(lhs: uint32x4_t, rhs: uint32x4_t) -> uint32x4_t {
fn mul(lhs: uint32x4_t, rhs: uint32x4_t) -> uint32x4_t {
// We want this to compile to:
// sqdmulh res.4s, lhs.4s, rhs.4s
// throughput: .25 cyc/vec (16 els/cyc)
// latency: 3 cyc
mul_31x31_to_hi_31(lhs, rhs)
}
// sqdmulh c_hi.4s, lhs.4s, rhs.4s
// mul mu_rhs.4s, rhs.4s, MU.4s
// mul q.4s, lhs.4s, mu_rhs.4s
// sqdmulh qp_hi.4s, q.4s, P.4s
// shsub res.4s, c_hi.4s, qp_hi.4s
// cmgt underflow.4s, qp_hi.4s, c_hi.4s
// mls res.4s, underflow.4s, P.4s
// throughput: 1.75 cyc/vec (2.29 els/cyc)
// latency: (lhs->) 11 cyc, (rhs->) 14 cyc

/// Compute `Q P // B` in the description above.
#[allow(non_snake_case)]
#[inline]
#[must_use]
fn monty_mul_lo(lhs: uint32x4_t, rhs: uint32x4_t) -> uint32x4_t {
// We want this to compile to:
// mul rhs_mu_mod_2pow32.rs, rhs.4s, MU.4s
// mul mu_C_mod_2pow32.rs, lhs.4s, rhs_mu_mod_2pow32.4s
// bic mu_C_mod_2pow31.rs, mu_C_mod_2pow32.rs, 0x80, lsl #24
// sqdmulh res.4s, mu_C_mod_2pow31.4s, P.4s
// throughput: 1 cyc/vec (4 els/cyc)
// latency: (1->1) 8 cyc
// (2->1) 11 cyc
unsafe {
// Safety: If this code got compiled then NEON intrinsics are available.
let rhs_mu_mod_2pow32 = aarch64::vmulq_u32(rhs, MU);
let mu_C_mod_2pow32 = aarch64::vmulq_u32(lhs, rhs_mu_mod_2pow32);
let mu_C_mod_2pow31 = aarch64::vbicq_u32(mu_C_mod_2pow32, TOP_BIT);
mul_31x31_to_hi_31(mu_C_mod_2pow31, P)
}
}
// No-op. The inputs are non-negative so we're free to interpret them as signed numbers.
let lhs = aarch64::vreinterpretq_s32_u32(lhs);
let rhs = aarch64::vreinterpretq_s32_u32(rhs);

/// Multiply vectors of Baby Bear field elements in canonical form.
/// If the inputs are not in canonical form, the result is undefined.
#[inline]
#[must_use]
fn mul(lhs: uint32x4_t, rhs: uint32x4_t) -> uint32x4_t {
// throughput: 2 cyc/vec (2 els/cyc)
// latency: (1->1) 13 cyc
// (2->1) 16 cyc
let hi = monty_mul_hi(lhs, rhs);
let lo = monty_mul_lo(lhs, rhs);
sub(hi, lo)
// Get bits 31, ..., 62 of C. Note that `sqdmulh` saturates when the product doesn't fit in
// an `i63`, but this cannot happen here due to our bounds on `lhs` and `rhs`.
let c_hi = aarch64::vqdmulhq_s32(lhs, rhs);

// Form `Q`, but reverse the order of multiplications to lower the latency on `lhs`.
let mu_rhs = aarch64::vmulq_s32(rhs, MU);
let q = aarch64::vmulq_s32(lhs, mu_rhs);

// Get bits 31, ..., 62 of Q P. Again, saturation is not an issue: `sqdmulh` only saturates
// when both operands are -2**31, and `P` is not -2**31.
let qp_hi = aarch64::vqdmulhq_s32(q, aarch64::vreinterpretq_s32_u32(P));

// Form D. Note that `c_hi` is C >> 31 and `qp_hi` is (Q P) >> 31, whereas we want
// (C - Q P) >> 32, so we need to subtract and divide by 2. Luckily NEON has an instruction
// for that! The lowest bit of `c_hi` and `qp_hi` is the same, so the division is exact.
let d = aarch64::vreinterpretq_u32_s32(aarch64::vhsubq_s32(c_hi, qp_hi));

// Finally we reduce D to canonical form. D is negative iff `c_hi < qp_hi`, so if that's the
// case then we add P. Note that if `c_hi < qp_hi` then `underflow` is -1, so we must
// _subtract_ `underflow` * P.
let underflow = aarch64::vcltq_s32(c_hi, qp_hi);
aarch64::vmlsq_u32(d, confuse_compiler(underflow), P)
}
}

/// Negate a vector of Baby Bear field elements in canonical form.
Expand Down
17 changes: 2 additions & 15 deletions baby-bear/src/baby_bear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,10 @@ use serde::{Deserialize, Deserializer, Serialize};

/// The Baby Bear prime
const P: u32 = 0x78000001;

// We want a different set of parameters on ARM/NEON than elsewhere. In particular, we want ARM to
// use 31 bits for the limb size, because that lets us use the SQDMULH instruction to do really fast
// multiplications in NEON. However, other architectures don't have this instruction, so 32-bit
// limbs are more convenient, being a nice power of 2.
const MONTY_BITS: u32 = if cfg!(all(target_arch = "aarch64", target_feature = "neon")) {
31
} else {
32
};
const MONTY_BITS: u32 = 32;
// We are defining MU = P^-1 (mod 2^MONTY_BITS). This is different from the usual convention
// (MU = -P^-1 (mod 2^MONTY_BITS)) but it avoids a carry.
const MONTY_MU: u32 = if cfg!(all(target_arch = "aarch64", target_feature = "neon")) {
0x08000001
} else {
0x88000001
};
const MONTY_MU: u32 = 0x88000001;

// This is derived from above.
const MONTY_MASK: u32 = ((1u64 << MONTY_BITS) - 1) as u32;
Expand Down
1 change: 0 additions & 1 deletion baby-bear/src/x86_64_avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use crate::BabyBear;

const WIDTH: usize = 8;
const P: __m256i = unsafe { transmute::<[u32; WIDTH], _>([0x78000001; WIDTH]) };
// On x86 MONTY_BITS is always 32, so MU = P^-1 (mod 2^32) = 0x88000001.
const MU: __m256i = unsafe { transmute::<[u32; WIDTH], _>([0x88000001; WIDTH]) };

/// Vectorized AVX2 implementation of `BabyBear` arithmetic.
Expand Down
1 change: 0 additions & 1 deletion baby-bear/src/x86_64_avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use crate::BabyBear;

const WIDTH: usize = 16;
const P: __m512i = unsafe { transmute::<[u32; WIDTH], _>([0x78000001; WIDTH]) };
// On x86 MONTY_BITS is always 32, so MU = P^-1 (mod 2^32) = 0x88000001.
const MU: __m512i = unsafe { transmute::<[u32; WIDTH], _>([0x88000001; WIDTH]) };
const EVENS: __mmask16 = 0b0101010101010101;
const EVENS4: __mmask16 = 0x0f0f;
Expand Down

0 comments on commit ad1abf0

Please sign in to comment.