diff --git a/src/bignum.nr b/src/bignum.nr index 5c27555f..f2ca559b 100644 --- a/src/bignum.nr +++ b/src/bignum.nr @@ -1,3 +1,4 @@ +use std::cmp::Ordering; use crate::utils::map::map; use crate::params::BigNumParamsGetter; @@ -5,9 +6,10 @@ use crate::params::BigNumParamsGetter; use crate::fns::{ constrained_ops::{ add, assert_is_not_equal, conditional_select, derive_from_seed, div, eq, from_field, mul, - neg, sub, udiv, udiv_mod, umod, validate_in_field, validate_in_range, + neg, sub, udiv, udiv_mod, umod, validate_in_field, validate_in_range, cmp, }, expressions::{__compute_quadratic_expression, evaluate_quadratic_expression}, + serialization::{from_be_bytes, to_le_bytes}, unconstrained_ops::{ __add, __batch_invert, __batch_invert_slice, __derive_from_seed, __div, __eq, __invmod, @@ -377,3 +379,20 @@ where } } +impl std::cmp::Ord for BigNum +where + Params: BigNumParamsGetter, +{ + fn cmp(self, other: Self) -> Ordering { + let comp_result: bool = cmp::<_, MOD_BITS>(self.limbs, other.limbs); + if self.limbs == other.limbs { + Ordering::equal() + } else if comp_result { + // there was an underflow + Ordering::less() + } else { + // there was no underflow + Ordering::greater() + } + } +} diff --git a/src/fns/constrained_ops.nr b/src/fns/constrained_ops.nr index fcdf2147..f312fc33 100644 --- a/src/fns/constrained_ops.nr +++ b/src/fns/constrained_ops.nr @@ -2,9 +2,10 @@ use crate::fns::{ expressions::evaluate_quadratic_expression, unconstrained_helpers::{ __add_with_flags, __from_field, __neg_with_flags, __sub_with_flags, __validate_gt_remainder, - __validate_in_field_compute_borrow_flags, + __validate_in_field_compute_borrow_flags, __cmp_remainder, }, unconstrained_ops::{__div, __mul, __udiv_mod}, + }; use crate::params::BigNumParams as P; @@ -329,6 +330,58 @@ pub(crate) fn validate_gt(lhs: [Field; N], rhs: [ assert(result_limb == 0); } +pub(crate) fn cmp(lhs: [Field; N], rhs: [Field; N]) -> bool { + // so we do... p - x - r = 0 and there might be borrow flags + // a - b = r + // p + a - b - r = 0 + let (comp_result, result, carry_flags, borrow_flags) = unsafe { __cmp_remainder(lhs, rhs) }; + validate_in_range::<_, MOD_BITS>(result); + + let borrow_shift = 0x1000000000000000000000000000000; + let carry_shift = 0x1000000000000000000000000000000; + + let mut addend: [Field; N] = [0; N]; + if comp_result == true { + let result_limb = rhs[0] - lhs[0] + addend[0] - result[0] - 1 + + (borrow_flags[0] as Field * borrow_shift) + - (carry_flags[0] as Field * carry_shift); + assert(result_limb == 0); + + for i in 1..N - 1 { + let result_limb = rhs[i] - lhs[i] + addend[i] - result[i] - borrow_flags[i - 1] as Field + + carry_flags[i - 1] as Field + + ((borrow_flags[i] as Field - carry_flags[i] as Field) * borrow_shift); + assert(result_limb == 0); + } + + let result_limb = rhs[N - 1] - lhs[N - 1] + addend[N - 1] + - result[N - 1] + - borrow_flags[N - 2] as Field + + carry_flags[N - 2] as Field; + assert(result_limb == 0); + true + } else { + let result_limb = lhs[0] - rhs[0] + addend[0] - result[0] - 1 + + (borrow_flags[0] as Field * borrow_shift) + - (carry_flags[0] as Field * carry_shift); + assert(result_limb == 0); + + for i in 1..N - 1 { + let result_limb = lhs[i] - rhs[i] + addend[i] - result[i] - borrow_flags[i - 1] as Field + + carry_flags[i - 1] as Field + + ((borrow_flags[i] as Field - carry_flags[i] as Field) * borrow_shift); + assert(result_limb == 0); + } + + let result_limb = lhs[N - 1] - rhs[N - 1] + addend[N - 1] + - result[N - 1] + - borrow_flags[N - 2] as Field + + carry_flags[N - 2] as Field; + assert(result_limb == 0); + false + } +} + pub(crate) fn neg( params: P, val: [Field; N], diff --git a/src/fns/unconstrained_helpers.nr b/src/fns/unconstrained_helpers.nr index c900958a..d817ac60 100644 --- a/src/fns/unconstrained_helpers.nr +++ b/src/fns/unconstrained_helpers.nr @@ -46,37 +46,28 @@ pub(crate) unconstrained fn __validate_gt_remainder( let mut b_u60: U60Repr = From::from(rhs); let underflow = b_u60.gte(a_u60); - b_u60 += U60Repr::one(); assert(underflow == false, "BigNum::validate_gt check fails"); - let mut result_u60: U60Repr = U60Repr { limbs: [0; 2 * N] }; - - let mut carry_in: u64 = 0; - let mut borrow_in: u64 = 0; - let mut borrow_flags: [bool; N] = [false; N]; - let mut carry_flags: [bool; N] = [false; N]; - for i in 0..2 * N { - let mut add_term: u64 = a_u60.limbs[i] + carry_in; - let mut carry = (add_term >= TWO_POW_60) as u64; - add_term -= carry * TWO_POW_60; - carry_in = carry; + // calls a function that calcuates the lhs - rhs and the carry/borrow flags derived from it + let (result, carry_flags, borrow_flags) = __compute_carry_and_borrow::(a_u60, b_u60); + (result, carry_flags, borrow_flags) +} - let sub_term = b_u60.limbs[i] + borrow_in; - let mut borrow = (sub_term > add_term) as u64; - result_u60.limbs[i] = borrow * TWO_POW_60 + add_term - sub_term; +pub(crate) unconstrained fn __cmp_remainder( + lhs: [Field; N], + rhs: [Field; N], +) -> (bool, [Field; N], [bool; N], [bool; N]) { + let mut a_u60: U60Repr = U60Repr::from(lhs); + let mut b_u60: U60Repr = U60Repr::from(rhs); + let underflow = b_u60.gte(a_u60); - borrow_in = borrow; + let (a_u60, b_u60) = if underflow { + (b_u60, a_u60) + } else { + (a_u60, b_u60) + }; - if ((i & 1) == 1) { - if (carry & borrow == 1) { - carry = 0; - borrow = 0; - } - carry_flags[i / 2] = carry as bool; - borrow_flags[i / 2] = borrow as bool; - } - } - let result = U60Repr::into(result_u60); - (result, carry_flags, borrow_flags) + let (result, carry_flags, borrow_flags) = __compute_carry_and_borrow::(a_u60, b_u60); + (underflow, result, carry_flags, borrow_flags) } pub(crate) unconstrained fn __neg_with_flags( @@ -367,3 +358,40 @@ pub(crate) unconstrained fn __tonelli_shanks_sqrt_inner_loop_check( + a_u60: U60Repr, + mut b_u60: U60Repr, +) -> ([Field; N], [bool; N], [bool; N]) { + b_u60 += U60Repr::one(); + let mut result_u60: U60Repr = U60Repr { limbs: [0; 2 * N] }; + let mut carry_in: u64 = 0; + let mut borrow_in: u64 = 0; + let mut borrow_flags: [bool; N] = [false; N]; + let mut carry_flags: [bool; N] = [false; N]; + for i in 0..2 * N { + let mut add_term: u64 = a_u60.limbs[i] + carry_in; + let mut carry = (add_term >= TWO_POW_60) as u64; + add_term -= carry * TWO_POW_60; + carry_in = carry; + + let sub_term = b_u60.limbs[i] + borrow_in; + let mut borrow = (sub_term > add_term) as u64; + result_u60.limbs[i] = borrow * TWO_POW_60 + add_term - sub_term; + + borrow_in = borrow; + + if ((i & 1) == 1) { + if (carry & borrow == 1) { + carry = 0; + borrow = 0; + } + carry_flags[i / 2] = carry as bool; + borrow_flags[i / 2] = borrow as bool; + } + } + let result = U60Repr::into(result_u60); + (result, carry_flags, borrow_flags) +} diff --git a/src/runtime_bignum.nr b/src/runtime_bignum.nr index 64d48461..2ae3771f 100644 --- a/src/runtime_bignum.nr +++ b/src/runtime_bignum.nr @@ -1,10 +1,11 @@ use crate::params::BigNumParams; use crate::utils::map::map; +use std::cmp::Ordering; use crate::fns::{ constrained_ops::{ add, assert_is_not_equal, conditional_select, derive_from_seed, div, eq, mul, neg, sub, - udiv, udiv_mod, umod, validate_in_field, validate_in_range, + udiv, udiv_mod, umod, validate_in_field, validate_in_range, cmp, }, expressions::{__compute_quadratic_expression, evaluate_quadratic_expression}, serialization::{from_be_bytes, to_le_bytes}, @@ -429,3 +430,20 @@ impl std::cmp::Eq for RuntimeBigNum eq::<_, MOD_BITS>(params, self.limbs, other.limbs) } } + +impl Ord for RuntimeBigNum { + fn cmp(self, other: Self) -> std::cmp::Ordering { + let params = self.params; + assert(params == other.params); + let comp_result: bool = cmp::<_, MOD_BITS>(self.limbs, other.limbs); + if self.limbs == other.limbs { + Ordering::equal() + } else if comp_result { + // there was an underflow + Ordering::less() + } else { + // there was no underflow + Ordering::greater() + } + } +} diff --git a/src/tests/bignum_test.nr b/src/tests/bignum_test.nr index bc00bda5..8cd43ece 100644 --- a/src/tests/bignum_test.nr +++ b/src/tests/bignum_test.nr @@ -791,6 +791,48 @@ fn test_expressions() { } #[test] +fn test_cmp_BN() { + let mut a: Fq = BigNum::modulus(); + let mut b: Fq = BigNum::modulus(); + + a.limbs[0] -= 2; + b.limbs[0] -= 1; + + assert(a < b); +} + +#[test(should_fail_with = "Failed constraint")] +fn test_cmp_BN_fail() { + let mut a: Fq = BigNum::modulus(); + let mut b: Fq = BigNum::modulus(); + + a.limbs[0] -= 1; + b.limbs[0] -= 2; + + assert(a < b); +} + +#[test] +fn test_cmp_BN_2() { + let mut a: Fq = BigNum::modulus(); + let mut b: Fq = BigNum::modulus(); + + a.limbs[0] -= 1; + b.limbs[0] -= 2; + + assert(a > b); +} + +#[test(should_fail_with = "Failed constraint")] +fn test_cmp_BN_fail_2() { + let mut a: Fq = BigNum::modulus(); + let mut b: Fq = BigNum::modulus(); + + a.limbs[0] -= 2; + b.limbs[0] -= 1; + + assert(a > b); + fn test_from_field_1_digit() { let field: Field = 1; let result = Fq::from(field);