Skip to content

Commit

Permalink
accumulator/forest: Add AssertEqual() to Forest (#329)
Browse files Browse the repository at this point in the history
Refactors the test function for checking if forests equal and makes it
public. AssertEqual function is useful for testing if the logical
forests are equal and is good for general sanity checking.
  • Loading branch information
kcalvinalvin authored Dec 20, 2021
1 parent cc944c1 commit 07e9cf4
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 83 deletions.
99 changes: 99 additions & 0 deletions accumulator/forest.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package accumulator

import (
"encoding/binary"
"encoding/hex"
"fmt"
"os"
"sort"
Expand Down Expand Up @@ -769,3 +770,101 @@ func (f *Forest) FindLeaf(leaf Hash) bool {
_, found := f.positionMap[leaf.Mini()]
return found
}

// AssertEqual compares the two forests. Returns an error if the forests are not equal.
// The data meant for statics are not checked and the function will return true
// if all other fields are equal.
func (f *Forest) AssertEqual(compareForest *Forest) error {
// Return if the number of leaves are not equal.
if f.numLeaves != compareForest.numLeaves {
err := fmt.Errorf("number of leaves aren't equal"+
"forest: %d, compared forest : %d\n", f.numLeaves,
compareForest.numLeaves)
return err
}

// Preliminary check of the position map element count before looping
// through all the elements in the map.
if len(f.positionMap) != len(compareForest.positionMap) {
err := fmt.Errorf("position maps sizes aren't equal"+
"forest: %d, compared forest : %d\n", len(f.positionMap),
len(compareForest.positionMap))
return err
}

// Make sure that the two maps are equal.
for key, val := range f.positionMap {
compVal, ok := compareForest.positionMap[key]
if !ok {
err := fmt.Errorf("miniHash %s doesn't exist in the the compared forest",
hex.EncodeToString(key[:]))
return err
}

if val != compVal {
err := fmt.Errorf("miniHash %s returned position %d for "+
"forest but %d for the compared forest", hex.EncodeToString(key[:]),
val, compVal)
return err
}
}

// Each forest needs its own position tracking as they may differ in the
// actual forest rows allocated.
var fPos, compPos uint64

// Grab the logical rows as we're only interested in if the forests are
// logically the same.
logicalRows := logicalTreeRows(f.numLeaves)

// Iterate through all the rows in the forest. The idea is that we'll
// keep moving up, and compare all the nodes.
for h := uint8(0); h <= logicalRows; h++ {
// We need to re-calculate the offset as we move up each row.
// This is because we allow the forest to allocate more space
// than what is actually needed.
//
// Example: In the below tree, positions that have garbage values
// are marked with '*'. This means that once we're at position 5,
// we need to go to position 8 next. This is where the need to
// re-calculate the offset comes from.
//
// 14*
// |---------------\
// 12 13*
// |-------\ |-------\
// 08 09 10 11*
// |---\ |---\ |---\ |---\
// 00 01 02 03 04 05 06* 07*
//
// Grab the parent of 0 for each row that we're currently on. For
// row 1, we'll grab 08 in the above example. For row 2, we'll grab
// 12.
fPos = parentMany(0, h, f.rows)
compPos = parentMany(0, h, compareForest.rows)

// Calculate element count in the current row.
elementCountAtRow := uint8(f.numLeaves >> h)

// Loop through all the elements in the current row.
for i := uint8(0); i < elementCountAtRow; i++ {
// Read the hashes at the position from each of the forests.
hash := f.data.read(uint64(fPos))
compareHash := compareForest.data.read(uint64(compPos))

// If the read hashes are not the same, return error.
if hash != compareHash {
err := fmt.Errorf("hashes aren't equal at forest position: %d "+
"and compared forest position %d. "+
"forest hash: %s compared forest hash: %s\n",
fPos, compPos, hex.EncodeToString(hash[:]),
hex.EncodeToString(compareHash[:]))
return err
}
fPos++
compPos++
}
}

return nil
}
104 changes: 21 additions & 83 deletions accumulator/forest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,23 @@ func TestForestAddDel(t *testing.T) {
}

func TestCowForestAddDelComp(t *testing.T) {
numAdds := uint32(1000)
// Function for writing logs.
writeLog := func(cowF, memF *Forest) {
cowstring := fmt.Sprintf("cowForest: nl %d %s\n",
cowF.numLeaves, cowF.ToString())
fmt.Println(cowstring)

memstring := fmt.Sprintf("memForest: nl %d %s\n",
memF.numLeaves, memF.ToString())
fmt.Println(memstring)
}

tmpDir := os.TempDir()
defer os.RemoveAll(tmpDir)

cowF := NewForest(CowForest, nil, tmpDir, 2500)
memF := NewForest(RamForest, nil, "", 0)
numAdds := uint32(1000)

sc := newSimChain(0x07)
sc.lookahead = 400
Expand All @@ -80,93 +92,19 @@ func TestCowForestAddDelComp(t *testing.T) {
t.Fatal(err)
}
if b%100 == 0 {
equal, wrongPos, wrongPosH := checkIfEqual(cowF, memF)
if !equal {
cowFile, err := os.OpenFile("cowlog",
os.O_CREATE|os.O_RDWR, 666)
if err != nil {
panic(err)
}
cowstring := fmt.Sprintf("nl %d %s\n", cowF.numLeaves, cowF.ToString())
cowFile.WriteString(cowstring)

memFile, err := os.OpenFile("memlog",
os.O_CREATE|os.O_RDWR, 666)
if err != nil {
panic(err)
}

memstring := fmt.Sprintf("nl %d %s\n", memF.numLeaves, memF.ToString())
memFile.WriteString(memstring)
s := fmt.Sprintf("forests are not equal\n")
s += fmt.Sprintf("forestRows in f: %d\n: ", cowF.rows)
s += fmt.Sprintf("wrongPos: %x\n", wrongPos)
s += fmt.Sprintf("wrongPosH %x\n", wrongPosH)
t.Fatal(s)
err := cowF.AssertEqual(memF)
if err != nil {
writeLog(cowF, memF)
t.Fatal(err)
}
}
}
equal, wrongPos, wrongPosH := checkIfEqual(cowF, memF)
if !equal {
cowFile, err := os.OpenFile("cowlog",
os.O_CREATE|os.O_RDWR, 666)
if err != nil {
panic(err)
}
cowstring := fmt.Sprintf("nl %d %s\n", cowF.numLeaves, cowF.ToString())
cowFile.WriteString(cowstring)

memFile, err := os.OpenFile("memlog",
os.O_CREATE|os.O_RDWR, 666)
if err != nil {
panic(err)
}

memstring := fmt.Sprintf("nl %d %s\n", memF.numLeaves, memF.ToString())
memFile.WriteString(memstring)
s := fmt.Sprintf("forests are not equal\n")
s += fmt.Sprintf("forestRows in f: %d\n: ", cowF.rows)
s += fmt.Sprintf("wrongPos: %x\n", wrongPos)
s += fmt.Sprintf("wrongPosH %x\n", wrongPosH)
t.Fatal(s)
}
}

// checkIfEqual checks if the forest differ returns true for equal and if not, returns
// the positions and the hashes
func checkIfEqual(cowF, memF *Forest) (bool, []uint64, []Hash) {
cowFH := cowF.rows
memFH := memF.rows

if cowFH != memFH {
panic("forestRows don't equal")
}

var pos uint8
for h := uint8(0); h <= memFH; h++ {
rowlen := uint8(1 << (memFH - h))

for j := uint8(0); j < rowlen; j++ {
if cowF.data.size() != memF.data.size() {
s := fmt.Sprintf("sizes don't equal"+
"cow: %d, mem: %d\n", cowF.data.size(), memF.data.size())
panic(s)
}
ok := memF.data.size() >= uint64(pos)
if ok {
memH := memF.data.read(uint64(pos))
cowH := cowF.data.read(uint64(pos))
if memH != cowH {
s := fmt.Sprintf("hashes aren't equal at gpos: %d "+"mem: %x cow: %x\n", pos, memH, cowH)
panic(s)
}
}
pos++
}

err := cowF.AssertEqual(memF)
if err != nil {
writeLog(cowF, memF)
t.Fatal(err)
}

return true, []uint64{}, []Hash{}
}

func TestCowForestAddDel(t *testing.T) {
Expand Down
22 changes: 22 additions & 0 deletions accumulator/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,14 @@ func inForest(pos, numLeaves uint64, forestRows uint8) bool {
}

// treeRows returns the number of rows given n leaves.
// Example: The below tree will return 2 as the forest will allocate enough for
// 4 leaves.
//
// row 2:
// |-------\
// row 1: 04
// |---\ |---\
// row 0: 00 01 02
func treeRows(n uint64) uint8 {
// treeRows works by:
// 1. Find the next power of 2 from the given n leaves.
Expand Down Expand Up @@ -338,6 +346,20 @@ func treeRows(n uint64) uint8 {

}

// logicalTreeRows returns the number of
//
// Example: The below tree will return 1 as the logical number of rows is 1 for this
// forest.
//
// row 2:
// |-------\
// row 1: 04
// |---\ |---\
// row 0: 00 01 02
func logicalTreeRows(n uint64) uint8 {
return uint8(bits.Len64(n) - 1)
}

// numRoots returns the number of 1 bits in n.
func numRoots(n uint64) uint8 {
return uint8(bits.OnesCount64(n))
Expand Down

0 comments on commit 07e9cf4

Please sign in to comment.