diff --git a/bmr/network.go b/bmr/network.go index 42bf71b1..59f361b8 100644 --- a/bmr/network.go +++ b/bmr/network.go @@ -14,5 +14,6 @@ type Operand byte // Network protocol messages. const ( OpInit Operand = iota - OpFx + OpFxLambda + OpFxR ) diff --git a/bmr/operand_string.go b/bmr/operand_string.go index 4c870c94..e5a4b170 100644 --- a/bmr/operand_string.go +++ b/bmr/operand_string.go @@ -9,12 +9,13 @@ func _() { // Re-run the stringer command to generate them again. var x [1]struct{} _ = x[OpInit-0] - _ = x[OpFx-1] + _ = x[OpFxLambda-1] + _ = x[OpFxR-2] } -const _Operand_name = "InitFx" +const _Operand_name = "InitFxLambdaFxR" -var _Operand_index = [...]uint8{0, 4, 6} +var _Operand_index = [...]uint8{0, 4, 12, 15} func (i Operand) String() string { if i >= Operand(len(_Operand_index)-1) { diff --git a/bmr/peer.go b/bmr/peer.go index f2bac478..c0a7bd6c 100644 --- a/bmr/peer.go +++ b/bmr/peer.go @@ -9,7 +9,9 @@ package bmr import ( "fmt" "io" + "math/big" + "github.com/markkurossi/mpc/circuit" "github.com/markkurossi/mpc/ot" "github.com/markkurossi/text/superscript" ) @@ -83,12 +85,12 @@ func (peer *Peer) consumerMsgLoop(id string) error { if err != nil { return err } - case OpFx: + case OpFxLambda: gid, err := peer.from.ReceiveUint32() if err != nil { return err } - peer.this.Debugf("%s: %s: id=%v\n", id, op, gid) + peer.this.Debugf("%s: %s: gid=%v\n", id, op, gid) gate := peer.this.circ.Gates[gid] lv := peer.this.lambda.Bit(int(gate.Input1)) @@ -101,13 +103,50 @@ func (peer *Peer) consumerMsgLoop(id string) error { v ^= xb peer.this.luv.SetBit(peer.this.luv, gid, v) peer.this.completions++ - if peer.this.completions == peer.this.circ.NumGates { + if peer.this.completions == peer.this.syncBarrier(1) { + peer.this.c.Signal() + } + peer.this.m.Unlock() + + case OpFxR: + gid, err := peer.from.ReceiveUint32() + if err != nil { + return err + } + peer.this.Debugf("%s: %s: gid=%v\n", id, op, gid) + gate := peer.this.circ.Gates[gid] + luvws := []*big.Int{ + peer.this.luvw0, + peer.this.luvw1, + peer.this.luvw2, + peer.this.luvw3, + } + // XXX patch luvws based on gate.Op + switch gate.Op { + case circuit.AND: + default: + return fmt.Errorf("gate %v not implemented yet", gate.Op) + } + var xbs []Label + for _, luvw := range luvws { + xb, err := FxkReceive(peer.otReceiver, luvw.Bit(gid)) + if err != nil { + return err + } + xbs = append(xbs, xb) + } + peer.this.m.Lock() + for i := 0; i < len(xbs); i++ { + peer.this.rj[gid][i].Xor(xbs[i]) + } + peer.this.completions++ + if peer.this.completions == peer.this.syncBarrier(2) { peer.this.c.Signal() } peer.this.m.Unlock() default: - return fmt.Errorf("%s: %s: not implemented\n", id, op) + return fmt.Errorf("%s: %s: not implemented", id, op) } peer.to.Flush() diff --git a/bmr/player.go b/bmr/player.go index 4cc076bd..7fd14764 100644 --- a/bmr/player.go +++ b/bmr/player.go @@ -43,6 +43,9 @@ type Player struct { luvw1 *big.Int luvw2 *big.Int luvw3 *big.Int + + // The XOR shares of Rj matching ρij,α,β + rj [][]Label } // NewPlayer creates a new multi-party player. @@ -99,6 +102,17 @@ func (p *Player) Play() error { count, p.numPlayers-1) } + // Init circuit-dependent fields. + p.rj = make([][]Label, p.circ.NumGates) + for i := 0; i < p.circ.NumGates; i++ { + switch p.circ.Gates[i].Op { + case circuit.AND: + p.rj[i] = make([]Label, 4) + default: + return fmt.Errorf("gate %v not implemented yet", p.circ.Gates[i].Op) + } + } + p.Debugf("BMR: #gates=%v\n", p.circ.NumGates) p.Debugf("Offline Phase...\n") @@ -216,6 +230,10 @@ func (p *Player) offlinePhase() error { return nil } +func (p *Player) syncBarrier(nth int) int { + return p.circ.NumGates * (len(p.peers) - 1) * nth +} + // fgc computes the multiparty garbled circuit (3.1.2 The Protocol for // Fgc - Page 7). func (p *Player) fgc() (err error) { @@ -238,7 +256,7 @@ func (p *Player) fgc() (err error) { if peer == nil { continue } - err = peer.to.SendByte(byte(OpFx)) + err = peer.to.SendByte(byte(OpFxLambda)) if err != nil { return err } @@ -261,8 +279,7 @@ func (p *Player) fgc() (err error) { p.m.Lock() p.luv.Xor(p.luv, luv) - - for p.completions < p.circ.NumGates { + for p.completions < p.syncBarrier(1) { p.c.Wait() } p.m.Unlock() @@ -297,7 +314,6 @@ func (p *Player) fgc() (err error) { return fmt.Errorf("gate %v not implemented yet", gate.Op) } } - fmt.Printf("Player%s: %cuvw=%v\n", p.IDString(), symbols.Lambda, lambda(p.luvw0, p.circ.NumGates)) fmt.Printf("Player%s: %cuv̄w=%v\n", p.IDString(), symbols.Lambda, @@ -307,6 +323,48 @@ func (p *Player) fgc() (err error) { fmt.Printf("Player%s: %cūv̄w=%v\n", p.IDString(), symbols.Lambda, lambda(p.luvw3, p.circ.NumGates)) + // Step 3: for i!=j, run Fxk(R,luvw) + for gid := 0; gid < p.circ.NumGates; gid++ { + for _, peer := range p.peers { + if peer == nil { + continue + } + err = peer.to.SendByte(byte(OpFxR)) + if err != nil { + return err + } + err = peer.to.SendUint32(gid) + if err != nil { + return err + } + for n := 0; n < len(p.rj[gid]); n++ { + r, err := FxkSend(peer.otSender, p.r) + if err != nil { + return err + } + p.m.Lock() + p.rj[gid][n].Xor(r) + p.m.Unlock() + } + } + } + p.m.Lock() + for p.completions < p.syncBarrier(2) { + p.c.Wait() + } + p.m.Unlock() + + for gid := 0; gid < p.circ.NumGates; gid++ { + p.Debugf("Player%s: rj[%v]:\t", p.IDString(), gid) + for idx, l := range p.rj[gid] { + if idx > 0 { + p.Debugf(" ") + } + p.Debugf("%v", l) + } + p.Debugf("\n") + } + return nil }