From 07e9cf4fc8879ac7fc4c376220ee07a38801efb0 Mon Sep 17 00:00:00 2001 From: Calvin Kim Date: Mon, 20 Dec 2021 23:47:43 +0900 Subject: [PATCH] accumulator/forest: Add AssertEqual() to Forest (#329) Refactors the test function for checking if forests equal and makes it public. AssertEqual function is useful for testing if the logical forests are equal and is good for general sanity checking. --- accumulator/forest.go | 99 +++++++++++++++++++++++++++++++++++ accumulator/forest_test.go | 104 ++++++++----------------------------- accumulator/utils.go | 22 ++++++++ 3 files changed, 142 insertions(+), 83 deletions(-) diff --git a/accumulator/forest.go b/accumulator/forest.go index ab1d6a6e..d92a936e 100644 --- a/accumulator/forest.go +++ b/accumulator/forest.go @@ -2,6 +2,7 @@ package accumulator import ( "encoding/binary" + "encoding/hex" "fmt" "os" "sort" @@ -769,3 +770,101 @@ func (f *Forest) FindLeaf(leaf Hash) bool { _, found := f.positionMap[leaf.Mini()] return found } + +// AssertEqual compares the two forests. Returns an error if the forests are not equal. +// The data meant for statics are not checked and the function will return true +// if all other fields are equal. +func (f *Forest) AssertEqual(compareForest *Forest) error { + // Return if the number of leaves are not equal. + if f.numLeaves != compareForest.numLeaves { + err := fmt.Errorf("number of leaves aren't equal"+ + "forest: %d, compared forest : %d\n", f.numLeaves, + compareForest.numLeaves) + return err + } + + // Preliminary check of the position map element count before looping + // through all the elements in the map. + if len(f.positionMap) != len(compareForest.positionMap) { + err := fmt.Errorf("position maps sizes aren't equal"+ + "forest: %d, compared forest : %d\n", len(f.positionMap), + len(compareForest.positionMap)) + return err + } + + // Make sure that the two maps are equal. + for key, val := range f.positionMap { + compVal, ok := compareForest.positionMap[key] + if !ok { + err := fmt.Errorf("miniHash %s doesn't exist in the the compared forest", + hex.EncodeToString(key[:])) + return err + } + + if val != compVal { + err := fmt.Errorf("miniHash %s returned position %d for "+ + "forest but %d for the compared forest", hex.EncodeToString(key[:]), + val, compVal) + return err + } + } + + // Each forest needs its own position tracking as they may differ in the + // actual forest rows allocated. + var fPos, compPos uint64 + + // Grab the logical rows as we're only interested in if the forests are + // logically the same. + logicalRows := logicalTreeRows(f.numLeaves) + + // Iterate through all the rows in the forest. The idea is that we'll + // keep moving up, and compare all the nodes. + for h := uint8(0); h <= logicalRows; h++ { + // We need to re-calculate the offset as we move up each row. + // This is because we allow the forest to allocate more space + // than what is actually needed. + // + // Example: In the below tree, positions that have garbage values + // are marked with '*'. This means that once we're at position 5, + // we need to go to position 8 next. This is where the need to + // re-calculate the offset comes from. + // + // 14* + // |---------------\ + // 12 13* + // |-------\ |-------\ + // 08 09 10 11* + // |---\ |---\ |---\ |---\ + // 00 01 02 03 04 05 06* 07* + // + // Grab the parent of 0 for each row that we're currently on. For + // row 1, we'll grab 08 in the above example. For row 2, we'll grab + // 12. + fPos = parentMany(0, h, f.rows) + compPos = parentMany(0, h, compareForest.rows) + + // Calculate element count in the current row. + elementCountAtRow := uint8(f.numLeaves >> h) + + // Loop through all the elements in the current row. + for i := uint8(0); i < elementCountAtRow; i++ { + // Read the hashes at the position from each of the forests. + hash := f.data.read(uint64(fPos)) + compareHash := compareForest.data.read(uint64(compPos)) + + // If the read hashes are not the same, return error. + if hash != compareHash { + err := fmt.Errorf("hashes aren't equal at forest position: %d "+ + "and compared forest position %d. "+ + "forest hash: %s compared forest hash: %s\n", + fPos, compPos, hex.EncodeToString(hash[:]), + hex.EncodeToString(compareHash[:])) + return err + } + fPos++ + compPos++ + } + } + + return nil +} diff --git a/accumulator/forest_test.go b/accumulator/forest_test.go index d843e769..c0c2b778 100644 --- a/accumulator/forest_test.go +++ b/accumulator/forest_test.go @@ -50,11 +50,23 @@ func TestForestAddDel(t *testing.T) { } func TestCowForestAddDelComp(t *testing.T) { - numAdds := uint32(1000) + // Function for writing logs. + writeLog := func(cowF, memF *Forest) { + cowstring := fmt.Sprintf("cowForest: nl %d %s\n", + cowF.numLeaves, cowF.ToString()) + fmt.Println(cowstring) + + memstring := fmt.Sprintf("memForest: nl %d %s\n", + memF.numLeaves, memF.ToString()) + fmt.Println(memstring) + } tmpDir := os.TempDir() + defer os.RemoveAll(tmpDir) + cowF := NewForest(CowForest, nil, tmpDir, 2500) memF := NewForest(RamForest, nil, "", 0) + numAdds := uint32(1000) sc := newSimChain(0x07) sc.lookahead = 400 @@ -80,93 +92,19 @@ func TestCowForestAddDelComp(t *testing.T) { t.Fatal(err) } if b%100 == 0 { - equal, wrongPos, wrongPosH := checkIfEqual(cowF, memF) - if !equal { - cowFile, err := os.OpenFile("cowlog", - os.O_CREATE|os.O_RDWR, 666) - if err != nil { - panic(err) - } - cowstring := fmt.Sprintf("nl %d %s\n", cowF.numLeaves, cowF.ToString()) - cowFile.WriteString(cowstring) - - memFile, err := os.OpenFile("memlog", - os.O_CREATE|os.O_RDWR, 666) - if err != nil { - panic(err) - } - - memstring := fmt.Sprintf("nl %d %s\n", memF.numLeaves, memF.ToString()) - memFile.WriteString(memstring) - s := fmt.Sprintf("forests are not equal\n") - s += fmt.Sprintf("forestRows in f: %d\n: ", cowF.rows) - s += fmt.Sprintf("wrongPos: %x\n", wrongPos) - s += fmt.Sprintf("wrongPosH %x\n", wrongPosH) - t.Fatal(s) + err := cowF.AssertEqual(memF) + if err != nil { + writeLog(cowF, memF) + t.Fatal(err) } } } - equal, wrongPos, wrongPosH := checkIfEqual(cowF, memF) - if !equal { - cowFile, err := os.OpenFile("cowlog", - os.O_CREATE|os.O_RDWR, 666) - if err != nil { - panic(err) - } - cowstring := fmt.Sprintf("nl %d %s\n", cowF.numLeaves, cowF.ToString()) - cowFile.WriteString(cowstring) - - memFile, err := os.OpenFile("memlog", - os.O_CREATE|os.O_RDWR, 666) - if err != nil { - panic(err) - } - - memstring := fmt.Sprintf("nl %d %s\n", memF.numLeaves, memF.ToString()) - memFile.WriteString(memstring) - s := fmt.Sprintf("forests are not equal\n") - s += fmt.Sprintf("forestRows in f: %d\n: ", cowF.rows) - s += fmt.Sprintf("wrongPos: %x\n", wrongPos) - s += fmt.Sprintf("wrongPosH %x\n", wrongPosH) - t.Fatal(s) - } -} - -// checkIfEqual checks if the forest differ returns true for equal and if not, returns -// the positions and the hashes -func checkIfEqual(cowF, memF *Forest) (bool, []uint64, []Hash) { - cowFH := cowF.rows - memFH := memF.rows - - if cowFH != memFH { - panic("forestRows don't equal") - } - - var pos uint8 - for h := uint8(0); h <= memFH; h++ { - rowlen := uint8(1 << (memFH - h)) - - for j := uint8(0); j < rowlen; j++ { - if cowF.data.size() != memF.data.size() { - s := fmt.Sprintf("sizes don't equal"+ - "cow: %d, mem: %d\n", cowF.data.size(), memF.data.size()) - panic(s) - } - ok := memF.data.size() >= uint64(pos) - if ok { - memH := memF.data.read(uint64(pos)) - cowH := cowF.data.read(uint64(pos)) - if memH != cowH { - s := fmt.Sprintf("hashes aren't equal at gpos: %d "+"mem: %x cow: %x\n", pos, memH, cowH) - panic(s) - } - } - pos++ - } + err := cowF.AssertEqual(memF) + if err != nil { + writeLog(cowF, memF) + t.Fatal(err) } - - return true, []uint64{}, []Hash{} } func TestCowForestAddDel(t *testing.T) { diff --git a/accumulator/utils.go b/accumulator/utils.go index 1cf06372..e9669f9b 100644 --- a/accumulator/utils.go +++ b/accumulator/utils.go @@ -307,6 +307,14 @@ func inForest(pos, numLeaves uint64, forestRows uint8) bool { } // treeRows returns the number of rows given n leaves. +// Example: The below tree will return 2 as the forest will allocate enough for +// 4 leaves. +// +// row 2: +// |-------\ +// row 1: 04 +// |---\ |---\ +// row 0: 00 01 02 func treeRows(n uint64) uint8 { // treeRows works by: // 1. Find the next power of 2 from the given n leaves. @@ -338,6 +346,20 @@ func treeRows(n uint64) uint8 { } +// logicalTreeRows returns the number of +// +// Example: The below tree will return 1 as the logical number of rows is 1 for this +// forest. +// +// row 2: +// |-------\ +// row 1: 04 +// |---\ |---\ +// row 0: 00 01 02 +func logicalTreeRows(n uint64) uint8 { + return uint8(bits.Len64(n) - 1) +} + // numRoots returns the number of 1 bits in n. func numRoots(n uint64) uint8 { return uint8(bits.OnesCount64(n))