Skip to content

Commit

Permalink
decimal: improve QuoRem performance
Browse files Browse the repository at this point in the history
  • Loading branch information
eapenkin committed Jun 29, 2024
1 parent d0c79a0 commit a7ad10d
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 35 deletions.
10 changes: 10 additions & 0 deletions coefficient.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ func (x fint) quo(y fint) (z fint, ok bool) {
return z, true
}

// quoRem calculates x div y and x mod y.
func (x fint) quoRem(y fint) (q, r fint, ok bool) {
if y == 0 {
return 0, 0, false
}
q = x / y
r = x - q*y
return q, r, true
}

// dist calculates abs(x - y).
func (x fint) dist(y fint) fint {
if x > y {
Expand Down
21 changes: 21 additions & 0 deletions coefficient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,27 @@ func TestFint_quo(t *testing.T) {
}
}

func TestFint_quoRem(t *testing.T) {
cases := []struct {
x, y, wantQuo, wantRem fint
wantOk bool
}{
{1, 0, 0, 0, false},
{1, 1, 1, 0, true},
{2, 4, 0, 2, true},
{20, 4, 5, 0, true},
{20, 3, 6, 2, true},
{maxFint, 2, maxFint / 2, 1, true},
}
for _, tt := range cases {
x, y := tt.x, tt.y
gotQuo, gotRem, gotOk := x.quoRem(y)
if gotQuo != tt.wantQuo || gotRem != tt.wantRem || gotOk != tt.wantOk {
t.Errorf("%v.quoRem(%v) = %v, %v, %v, want %v, %v, %v", x, y, gotQuo, gotRem, gotOk, tt.wantQuo, tt.wantRem, tt.wantOk)
}
}
}

func TestFint_dist(t *testing.T) {
cases := []struct {
x, y, wantCoef fint
Expand Down
101 changes: 90 additions & 11 deletions decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -1876,33 +1876,112 @@ func (d Decimal) quoBint(e Decimal, minScale int) (Decimal, error) {
// - the divisor is 0;
// - the integer part of the quotient has more than [MaxPrec] digits.
func (d Decimal) QuoRem(e Decimal) (q, r Decimal, err error) {
q, r, err = d.quoRem(e)
// Special case: zero divisor
if e.IsZero() {
return Decimal{}, Decimal{}, fmt.Errorf("computing [%v div %v] and [%v mod %v]: %w", d, e, d, e, errDivisionByZero)
}

// General case
q, r, err = d.quoRemFint(e)
if err != nil {
return Decimal{}, Decimal{}, fmt.Errorf("computing [%v div %v] and [%v mod %v]: %w", d, e, d, e, err)
q, r, err = d.quoRemBint(e)
if err != nil {
return Decimal{}, Decimal{}, fmt.Errorf("computing [%v div %v] and [%v mod %v]: %w", d, e, d, e, err)
}
}

return q, r, nil
}

func (d Decimal) quoRem(e Decimal) (q, r Decimal, err error) {
// Quotient
q, err = d.Quo(e)
// quoRemFint computes the quotient and remainder of two decimals using uint64 arithmetic.
func (d Decimal) quoRemFint(e Decimal) (q, r Decimal, err error) {
dcoef, ecoef := d.coef, e.coef

// Alignment and rscale
var rscale int
var ok bool
switch {
case d.Scale() == e.Scale():
rscale = d.Scale()
case d.Scale() > e.Scale():
rscale = d.Scale()
ecoef, ok = ecoef.lsh(d.Scale() - e.Scale())
if !ok {
return Decimal{}, Decimal{}, errDecimalOverflow
}
case d.Scale() < e.Scale():
rscale = e.Scale()
dcoef, ok = dcoef.lsh(e.Scale() - d.Scale())
if !ok {
return Decimal{}, Decimal{}, errDecimalOverflow
}
}

// Coefficients
qcoef, rcoef, ok := dcoef.quoRem(ecoef)
if !ok {
return Decimal{}, Decimal{}, errDivisionByZero // Should never happen
}

// Signs
qsign := d.IsNeg() != e.IsNeg()
rsign := d.IsNeg()

q, err = newFromFint(qsign, qcoef, 0, 0)
if err != nil {
return Decimal{}, Decimal{}, err
}
r, err = newFromFint(rsign, rcoef, rscale, rscale)
if err != nil {
return Decimal{}, Decimal{}, err
}
return q, r, nil
}

// quoRemBint computes the quotient and remainder of two decimals using *big.Int arithmetic.
func (d Decimal) quoRemBint(e Decimal) (q, r Decimal, err error) {
dcoef := getBint()
defer putBint(dcoef)
dcoef.setFint(d.coef)

ecoef := getBint()
defer putBint(ecoef)
ecoef.setFint(e.coef)

// T-Division
q = q.Trunc(0)
qcoef := getBint()
defer putBint(qcoef)

// Reminder
r, err = e.Mul(q)
rcoef := getBint()
defer putBint(rcoef)

// Alignment and scale
var rscale int
switch {
case d.Scale() == e.Scale():
rscale = d.Scale()
case d.Scale() > e.Scale():
rscale = d.Scale()
ecoef.lsh(ecoef, d.Scale()-e.Scale())
case d.Scale() < e.Scale():
rscale = e.Scale()
dcoef.lsh(dcoef, e.Scale()-d.Scale())
}

// Coefficients
qcoef.quoRem(dcoef, ecoef, rcoef)

// Signs
qsign := d.IsNeg() != e.IsNeg()
rsign := d.IsNeg()

q, err = newFromBint(qsign, qcoef, 0, 0)
if err != nil {
return Decimal{}, Decimal{}, err
}
r, err = d.Sub(r)
r, err = newFromBint(rsign, rcoef, rscale, rscale)
if err != nil {
return Decimal{}, Decimal{}, err
}

return q, r, nil
}

Expand Down
41 changes: 17 additions & 24 deletions decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2731,6 +2731,8 @@ func TestDecimal_QuoRem(t *testing.T) {
{"4.2", "3.1000003", "1", "1.0999997"},
{"1.000000000000000000", "0.000000000000000003", "333333333333333333", "0.000000000000000001"},
{"1.000000000000000001", "0.000000000000000003", "333333333333333333", "0.000000000000000002"},
{"3", "0.9999999999999999999", "3", "0.0000000000000000003"},
{"0.9999999999999999999", "3", "0", "0.9999999999999999999"},
}
for _, tt := range tests {
d := MustParse(tt.d)
Expand All @@ -2752,7 +2754,8 @@ func TestDecimal_QuoRem(t *testing.T) {
tests := map[string]struct {
d, e string
}{
"zero 1": {"1", "0"},
"zero 1": {"1", "0"},
"overflow 1": {"9999999999999999999", "0.0000000000000000001"},
}
for _, tt := range tests {
d := MustParse(tt.d)
Expand Down Expand Up @@ -3374,7 +3377,7 @@ func FuzzDecimal_Quo(f *testing.F) {
t.Skip()
return
}
if dcoef == 0 || ecoef == 0 {
if ecoef == 0 {
t.Skip()
return
}
Expand Down Expand Up @@ -3417,23 +3420,17 @@ func FuzzDecimal_Quo(f *testing.F) {
func FuzzDecimal_QuoRem(f *testing.F) {
for _, d := range corpus {
for _, e := range corpus {
for s := 0; s <= MaxScale; s++ {
f.Add(d.neg, d.scale, d.coef, e.neg, e.scale, e.coef, s)
}
f.Add(d.neg, d.scale, d.coef, e.neg, e.scale, e.coef)
}
}

f.Fuzz(
func(t *testing.T, dneg bool, dscale int, dcoef uint64, eneg bool, escale int, ecoef uint64, scale int) {
if scale < 0 || MaxScale < scale {
t.Skip()
return
}
if dcoef == 0 || ecoef == 0 {
func(t *testing.T, dneg bool, dscale int, dcoef uint64, eneg bool, escale int, ecoef uint64) {
if ecoef == 0 {
t.Skip()
return
}
want, err := newSafe(dneg, fint(dcoef), dscale)
d, err := newSafe(dneg, fint(dcoef), dscale)
if err != nil {
t.Skip()
return
Expand All @@ -3444,29 +3441,25 @@ func FuzzDecimal_QuoRem(f *testing.F) {
return
}

q, r, err := want.QuoRem(e)
gotQ, gotR, err := d.quoRemFint(e)
if err != nil {
switch {
case errors.Is(err, errDecimalOverflow):
t.Skip() // Decimal overflow is an expected error in division
t.Skip() // Decimal overflow is an expected error in fast division
default:
t.Errorf("QuoRem(%q, %q) failed: %v", want, e, err)
t.Errorf("quoRemFint(%q, %q) failed: %v", d, e, err)
}
return
}

got, err := e.Mul(q)
wantQ, wantR, err := d.quoRemBint(e)
if err != nil {
t.Errorf("%q.Mul(%q) failed: %v", e, q, err)
t.Errorf("quoRemBint(%q, %q) failed: %v", d, e, err)
return
}
got, err = got.Add(r)
if err != nil {
t.Errorf("%q.Add(%q) failed: %v", got, r, err)
return
}
if got.Cmp(want) != 0 {
t.Errorf("%q.Mul(%q).Add(%q) = %q, want %q", e, q, r, got, want)

if gotQ.Cmp(wantQ) != 0 || gotR.Cmp(wantR) != 0 {
t.Errorf("quoRemBint(%q, %q) = (%q, %q), whereas quoRemFint(%q, %q) = (%q, %q)", d, e, wantQ, wantR, d, e, gotQ, gotR)
}
},
)
Expand Down

0 comments on commit a7ad10d

Please sign in to comment.