From c3085ba3dfd89ad1ec4e5eabedbff31624864d53 Mon Sep 17 00:00:00 2001 From: Markku Rossi Date: Sun, 19 Feb 2023 11:52:03 +0100 Subject: [PATCH] Oblivious Transfer with the OT interface. Changed the default OT to CO OT. --- apps/garbled/main.go | 23 +++++++----- apps/garbled/streaming.go | 14 +++++--- benchmarks.md | 32 +++++++++++++++++ circuit/evaluator.go | 45 +++++++++++------------- circuit/garbler.go | 64 +++++++-------------------------- circuit/stream_evaluator.go | 42 +++++++++------------- circuit/stream_garble.go | 7 +++- compiler/arithmetic_test.go | 17 +++++---- compiler/compiler.go | 13 +++---- compiler/ssa/streamer.go | 70 ++++++------------------------------- 10 files changed, 138 insertions(+), 189 deletions(-) diff --git a/apps/garbled/main.go b/apps/garbled/main.go index ef3219e9..517f08a7 100644 --- a/apps/garbled/main.go +++ b/apps/garbled/main.go @@ -1,7 +1,7 @@ // // main.go // -// Copyright (c) 2019-2022 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -25,6 +25,7 @@ import ( "github.com/markkurossi/mpc/circuit" "github.com/markkurossi/mpc/compiler" "github.com/markkurossi/mpc/compiler/utils" + "github.com/markkurossi/mpc/ot" "github.com/markkurossi/mpc/p2p" ) @@ -130,11 +131,14 @@ func main() { return } + //oti := ot.NewRSA(2048) + oti := ot.NewCO() + if *stream { if *evaluator { - err = streamEvaluatorMode(params, inputFlag, len(*cpuprofile) > 0) + err = streamEvaluatorMode(oti, inputFlag, len(*cpuprofile) > 0) } else { - err = streamGarblerMode(params, inputFlag, flag.Args()) + err = streamGarblerMode(params, oti, inputFlag, flag.Args()) } memProfile(*memprofile) if err != nil { @@ -289,14 +293,14 @@ func main() { fmt.Printf("%s\n", err) os.Exit(1) } - err = evaluatorMode(circ, input, len(*cpuprofile) > 0) + err = evaluatorMode(oti, circ, input, len(*cpuprofile) > 0) } else { input, err = circ.Inputs[0].Parse(inputFlag) if err != nil { fmt.Printf("%s\n", err) os.Exit(1) } - err = garblerMode(circ, input) + err = garblerMode(oti, circ, input) } if err != nil { log.Fatal(err) @@ -321,7 +325,8 @@ func memProfile(file string) { } } -func evaluatorMode(circ *circuit.Circuit, input *big.Int, once bool) error { +func evaluatorMode(oti ot.OT, circ *circuit.Circuit, input *big.Int, + once bool) error { ln, err := net.Listen("tcp", port) if err != nil { return err @@ -336,7 +341,7 @@ func evaluatorMode(circ *circuit.Circuit, input *big.Int, once bool) error { fmt.Printf("New connection from %s\n", nc.RemoteAddr()) conn := p2p.NewConn(nc) - result, err := circuit.Evaluator(conn, circ, input, verbose) + result, err := circuit.Evaluator(conn, oti, circ, input, verbose) conn.Close() if err != nil && err != io.EOF { @@ -350,7 +355,7 @@ func evaluatorMode(circ *circuit.Circuit, input *big.Int, once bool) error { } } -func garblerMode(circ *circuit.Circuit, input *big.Int) error { +func garblerMode(oti ot.OT, circ *circuit.Circuit, input *big.Int) error { nc, err := net.Dial("tcp", port) if err != nil { return err @@ -358,7 +363,7 @@ func garblerMode(circ *circuit.Circuit, input *big.Int) error { conn := p2p.NewConn(nc) defer conn.Close() - result, err := circuit.Garbler(conn, circ, input, verbose) + result, err := circuit.Garbler(conn, oti, circ, input, verbose) if err != nil { return err } diff --git a/apps/garbled/streaming.go b/apps/garbled/streaming.go index 4b74ffc1..9cee6332 100644 --- a/apps/garbled/streaming.go +++ b/apps/garbled/streaming.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2021 Markku Rossi +// Copyright (c) 2020-2023 Markku Rossi // // All rights reserved. // @@ -15,10 +15,11 @@ import ( "github.com/markkurossi/mpc/circuit" "github.com/markkurossi/mpc/compiler" "github.com/markkurossi/mpc/compiler/utils" + "github.com/markkurossi/mpc/ot" "github.com/markkurossi/mpc/p2p" ) -func streamEvaluatorMode(params *utils.Params, input input, once bool) error { +func streamEvaluatorMode(oti ot.OT, input input, once bool) error { ln, err := net.Listen("tcp", port) if err != nil { return err @@ -33,7 +34,8 @@ func streamEvaluatorMode(params *utils.Params, input input, once bool) error { fmt.Printf("New connection from %s\n", nc.RemoteAddr()) conn := p2p.NewConn(nc) - outputs, result, err := circuit.StreamEvaluator(conn, input, verbose) + outputs, result, err := circuit.StreamEvaluator(conn, oti, input, + verbose) conn.Close() if err != nil && err != io.EOF { @@ -47,7 +49,9 @@ func streamEvaluatorMode(params *utils.Params, input input, once bool) error { } } -func streamGarblerMode(params *utils.Params, input input, args []string) error { +func streamGarblerMode(params *utils.Params, oti ot.OT, input input, + args []string) error { + if len(args) != 1 || !strings.HasSuffix(args[0], ".mpcl") { return fmt.Errorf("streaming mode takes single MPCL file") } @@ -59,7 +63,7 @@ func streamGarblerMode(params *utils.Params, input input, args []string) error { defer conn.Close() outputs, result, err := compiler.New(params).StreamFile( - conn, args[0], input) + conn, oti, args[0], input) if err != nil { return err } diff --git a/benchmarks.md b/benchmarks.md index 62db5e10..5e33ce54 100644 --- a/benchmarks.md +++ b/benchmarks.md @@ -208,6 +208,22 @@ Circuit: #gates=5539148 (XOR=3996414 XNOR=48825 AND=1493909 OR=0 INV=0) #w=55393 └────────┴──────────────┴────────┴──────┘ ``` +CO OT: + +``` +Circuit: #gates=5539117 (XOR=3996381 XNOR=48824 AND=1493910 OR=1 INV=1 xor=4045205 !xor=1493912 levels=1604812 width=8259) #w=5539277 +┌────────┬──────────────┬────────┬──────┐ +│ Op │ Time │ % │ Xfer │ +├────────┼──────────────┼────────┼──────┤ +│ Wait │ 544.037201ms │ 69.10% │ │ +│ Recv │ 119.178002ms │ 15.14% │ 69MB │ +│ Inputs │ 5.523428ms │ 0.70% │ 3kB │ +│ Eval │ 118.35074ms │ 15.03% │ │ +│ Result │ 241.858µs │ 0.03% │ 1kB │ +│ Total │ 787.331229ms │ │ 69MB │ +└────────┴──────────────┴────────┴──────┘ +``` + ## Ed25519 signature computation The first signature computation without SHA-512: @@ -453,6 +469,22 @@ Parallel garbling/write: └─────────────┴─────────────────┴────────┴───────┘ ``` +CO OT: + +``` +┌─────────────┬─────────────────┬────────┬──────┐ +│ Op │ Time │ % │ Xfer │ +├─────────────┼─────────────────┼────────┼──────┤ +│ Compile │ 2.192778168s │ 2.82% │ │ +│ Init │ 2.127053ms │ 0.00% │ 0B │ +│ OT Init │ 10.404µs │ 0.00% │ 0B │ +│ Peer Inputs │ 86.187868ms │ 0.11% │ 74kB │ +│ Garble │ 1m15.571077934s │ 97.07% │ 15GB │ +│ Result │ 338.085µs │ 0.00% │ 55kB │ +│ Total │ 1m17.852519512s │ │ 15GB │ +└─────────────┴─────────────────┴────────┴──────┘ +``` + ## RSA signature computation diff --git a/circuit/evaluator.go b/circuit/evaluator.go index b4ee4ef5..818e29a6 100644 --- a/circuit/evaluator.go +++ b/circuit/evaluator.go @@ -9,7 +9,6 @@ package circuit import ( - "crypto/rsa" "fmt" "math/big" @@ -22,8 +21,8 @@ var ( ) // Evaluator runs the evaluator on the P2P network. -func Evaluator(conn *p2p.Conn, circ *Circuit, inputs *big.Int, verbose bool) ( - []*big.Int, error) { +func Evaluator(conn *p2p.Conn, oti ot.OT, circ *Circuit, inputs *big.Int, + verbose bool) ([]*big.Int, error) { timing := NewTiming() @@ -82,19 +81,7 @@ func Evaluator(conn *p2p.Conn, circ *Circuit, inputs *big.Int, verbose bool) ( } // Init oblivious transfer. - pubN, err := conn.ReceiveData() - if err != nil { - return nil, err - } - pubE, err := conn.ReceiveUint32() - if err != nil { - return nil, err - } - pub := &rsa.PublicKey{ - N: big.NewInt(0).SetBytes(pubN), - E: pubE, - } - receiver, err := ot.NewReceiver(pub) + err = oti.InitReceiver(conn) if err != nil { return nil, err } @@ -105,16 +92,26 @@ func Evaluator(conn *p2p.Conn, circ *Circuit, inputs *big.Int, verbose bool) ( if verbose { fmt.Printf(" - Querying our inputs...\n") } + if err := conn.SendUint32(OpOT); err != nil { + return nil, err + } + // Wire offset. + if err := conn.SendUint32(circ.Inputs[0].Size); err != nil { + return nil, err + } + // Wire count. + if err := conn.SendUint32(circ.Inputs[1].Size); err != nil { + return nil, err + } + conn.Flush() + flags := make([]bool, circ.Inputs[1].Size) for i := 0; i < circ.Inputs[1].Size; i++ { - if err := conn.SendUint32(OpOT); err != nil { - return nil, err + if inputs.Bit(i) == 1 { + flags[i] = true } - n, err := conn.Receive(receiver, uint(circ.Inputs[0].Size+i), - inputs.Bit(i)) - if err != nil { - return nil, err - } - wires[Wire(circ.Inputs[0].Size+i)].SetBytes(n) + } + if err := oti.Receive(flags, wires[circ.Inputs[0].Size:]); err != nil { + return nil, err } xfer := conn.Stats.Sub(ioStats) ioStats = conn.Stats diff --git a/circuit/garbler.go b/circuit/garbler.go index 7dd09e36..8d8a87ed 100644 --- a/circuit/garbler.go +++ b/circuit/garbler.go @@ -1,7 +1,7 @@ // // garbler.go // -// Copyright (c) 2019-2021 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -44,8 +44,8 @@ func (s FileSize) String() string { } // Garbler runs the garbler on the P2P network. -func Garbler(conn *p2p.Conn, circ *Circuit, inputs *big.Int, verbose bool) ( - []*big.Int, error) { +func Garbler(conn *p2p.Conn, oti ot.OT, circ *Circuit, inputs *big.Int, + verbose bool) ([]*big.Int, error) { timing := NewTiming() if verbose { @@ -120,22 +120,10 @@ func Garbler(conn *p2p.Conn, circ *Circuit, inputs *big.Int, verbose bool) ( } // Init oblivious transfer. - sender, err := ot.NewSender(2048) + err = oti.InitSender(conn) if err != nil { return nil, err } - - // Send our public key. - pub := sender.PublicKey() - data := pub.N.Bytes() - if err := conn.SendData(data); err != nil { - return nil, err - } - if err := conn.SendUint32(pub.E); err != nil { - return nil, err - } - conn.Flush() - ioStats = conn.Stats.Sub(ioStats) timing.Sample("OT Init", []string{FileSize(ioStats.Sum()).String()}) @@ -161,52 +149,24 @@ func Garbler(conn *p2p.Conn, circ *Circuit, inputs *big.Int, verbose bool) ( switch op { case OpOT: - bit, err := conn.ReceiveUint32() + offset, err := conn.ReceiveUint32() if err != nil { return nil, err } - if !allowedOTs[bit] { - return nil, fmt.Errorf("peer can't OT wire %d", bit) - } - allowedOTs[bit] = false - - wire := garbled.Wires[bit] - - var m0Buf, m1Buf ot.LabelData - m0Data := wire.L0.Bytes(&m0Buf) - m1Data := wire.L1.Bytes(&m1Buf) - - xfer, err := sender.NewTransfer(m0Data, m1Data) + count, err := conn.ReceiveUint32() if err != nil { return nil, err } - - x0, x1 := xfer.RandomMessages() - if err := conn.SendData(x0); err != nil { - return nil, err - } - if err := conn.SendData(x1); err != nil { - return nil, err + for i := 0; i < count; i++ { + if !allowedOTs[offset+i] { + return nil, fmt.Errorf("peer can't OT wire %d", offset+i) + } + allowedOTs[offset+i] = false } - conn.Flush() - - v, err := conn.ReceiveData() + err = oti.Send(garbled.Wires[offset : offset+count]) if err != nil { return nil, err } - xfer.ReceiveV(v) - - m0p, m1p, err := xfer.Messages() - if err != nil { - return nil, err - } - if err := conn.SendData(m0p); err != nil { - return nil, err - } - if err := conn.SendData(m1p); err != nil { - return nil, err - } - conn.Flush() lastOT = time.Now() case OpResult: diff --git a/circuit/stream_evaluator.go b/circuit/stream_evaluator.go index 43e0702d..de3bcff2 100644 --- a/circuit/stream_evaluator.go +++ b/circuit/stream_evaluator.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2021 Markku Rossi +// Copyright (c) 2020-2023 Markku Rossi // // All rights reserved. // @@ -9,7 +9,6 @@ package circuit import ( "crypto/aes" "crypto/cipher" - "crypto/rsa" "fmt" "math/big" "time" @@ -47,6 +46,11 @@ func (stream *StreamEval) Get(tmp bool, w int) ot.Label { return stream.wires[w] } +// GetInputs gets the specified input wire range. +func (stream *StreamEval) GetInputs(offset, count int) []ot.Label { + return stream.wires[offset : offset+count] +} + // Set sets the value of the wire. func (stream *StreamEval) Set(tmp bool, w int, label ot.Label) { if tmp { @@ -75,8 +79,8 @@ func (stream *StreamEval) InitCircuit(numWires, numTmpWires int) { } // StreamEvaluator runs the stream evaluator on the connection. -func StreamEvaluator(conn *p2p.Conn, inputFlag []string, verbose bool) ( - IO, []*big.Int, error) { +func StreamEvaluator(conn *p2p.Conn, oti ot.OT, inputFlag []string, + verbose bool) (IO, []*big.Int, error) { timing := NewTiming() @@ -147,23 +151,10 @@ func StreamEvaluator(conn *p2p.Conn, inputFlag []string, verbose bool) ( } // Init oblivious transfer. - pubN, err := conn.ReceiveData() - if err != nil { - return nil, nil, err - } - pubE, err := conn.ReceiveUint32() - if err != nil { - return nil, nil, err - } - pub := &rsa.PublicKey{ - N: big.NewInt(0).SetBytes(pubN), - E: pubE, - } - receiver, err := ot.NewReceiver(pub) + err = oti.InitReceiver(conn) if err != nil { return nil, nil, err } - ioStats := conn.Stats timing.Sample("Init", []string{FileSize(ioStats.Sum()).String()}) @@ -171,14 +162,15 @@ func StreamEvaluator(conn *p2p.Conn, inputFlag []string, verbose bool) ( if verbose { fmt.Printf(" - Querying our inputs...\n") } - for w := 0; w < in2.Size; w++ { - n, err := conn.Receive(receiver, uint(in1.Size+w), inputs.Bit(w)) - if err != nil { - return nil, nil, err + flags := make([]bool, in2.Size) + for i := 0; i < in2.Size; i++ { + if inputs.Bit(i) == 1 { + flags[i] = true } - var label ot.Label - label.SetBytes(n) - streaming.Set(false, in1.Size+w, label) + } + inputLabels := streaming.GetInputs(in1.Size, in2.Size) + if err := oti.Receive(flags, inputLabels); err != nil { + return nil, nil, err } xfer := conn.Stats.Sub(ioStats) ioStats = conn.Stats diff --git a/circuit/stream_garble.go b/circuit/stream_garble.go index fbcd561c..a7a435ef 100644 --- a/circuit/stream_garble.go +++ b/circuit/stream_garble.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2021 Markku Rossi +// Copyright (c) 2020-2021, 2023 Markku Rossi // // All rights reserved. // @@ -110,6 +110,11 @@ func (stream *Streaming) GetInput(w Wire) ot.Wire { return stream.wires[w] } +// GetInputs gets the specified input wire range. +func (stream *Streaming) GetInputs(offset, count int) []ot.Wire { + return stream.wires[offset : offset+count] +} + // Get gets the value of the wire. func (stream *Streaming) Get(w Wire) (ot.Wire, Wire, bool) { if w < stream.firstTmp { diff --git a/compiler/arithmetic_test.go b/compiler/arithmetic_test.go index 6e034c33..35b6d12b 100644 --- a/compiler/arithmetic_test.go +++ b/compiler/arithmetic_test.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2019-2021 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -14,6 +14,7 @@ import ( "github.com/markkurossi/mpc/circuit" "github.com/markkurossi/mpc/compiler/utils" + "github.com/markkurossi/mpc/ot" "github.com/markkurossi/mpc/p2p" ) @@ -125,13 +126,13 @@ func TestArithmetics(t *testing.T) { gerr := make(chan error) go func() { - _, err := circuit.Garbler(p2p.NewConn(gio), circ, - gInput, false) + _, err := circuit.Garbler(p2p.NewConn(gio), ot.NewCO(), + circ, gInput, false) gerr <- err }() - result, err := circuit.Evaluator(p2p.NewConn(eio), circ, - eInput, false) + result, err := circuit.Evaluator(p2p.NewConn(eio), + ot.NewCO(), circ, eInput, false) if err != nil { t.Fatalf("Evaluator failed: %s\n", err) } @@ -178,11 +179,13 @@ func BenchmarkMult(b *testing.B) { gerr := make(chan error) go func() { - _, err := circuit.Garbler(p2p.NewConn(gio), circ, gInput, false) + _, err := circuit.Garbler(p2p.NewConn(gio), ot.NewCO(), circ, gInput, + false) gerr <- err }() - _, err = circuit.Evaluator(p2p.NewConn(eio), circ, eInput, false) + _, err = circuit.Evaluator(p2p.NewConn(eio), ot.NewCO(), circ, eInput, + false) if err != nil { b.Fatalf("Evaluator failed: %s\n", err) } diff --git a/compiler/compiler.go b/compiler/compiler.go index c5eff978..0c05f253 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2019-2022 Markku Rossi +// Copyright (c) 2019-2023 Markku Rossi // // All rights reserved. // @@ -17,6 +17,7 @@ import ( "github.com/markkurossi/mpc/circuit" "github.com/markkurossi/mpc/compiler/ast" "github.com/markkurossi/mpc/compiler/utils" + "github.com/markkurossi/mpc/ot" "github.com/markkurossi/mpc/p2p" ) @@ -97,7 +98,7 @@ func (c *Compiler) compile(source string, in io.Reader) ( // StreamFile compiles the input program and uses the streaming mode // to garble and stream the circuit to the evaluator node. -func (c *Compiler) StreamFile(conn *p2p.Conn, file string, +func (c *Compiler) StreamFile(conn *p2p.Conn, oti ot.OT, file string, input []string) (circuit.IO, []*big.Int, error) { f, err := os.Open(file) @@ -105,11 +106,11 @@ func (c *Compiler) StreamFile(conn *p2p.Conn, file string, return nil, nil, err } defer f.Close() - return c.stream(conn, file, f, input) + return c.stream(conn, oti, file, f, input) } -func (c *Compiler) stream(conn *p2p.Conn, source string, in io.Reader, - inputFlag []string) (circuit.IO, []*big.Int, error) { +func (c *Compiler) stream(conn *p2p.Conn, oti ot.OT, source string, + in io.Reader, inputFlag []string) (circuit.IO, []*big.Int, error) { timing := circuit.NewTiming() @@ -143,7 +144,7 @@ func (c *Compiler) stream(conn *p2p.Conn, source string, in io.Reader, fmt.Printf(" - Out: %s\n", program.Outputs) fmt.Printf(" - In: %s\n", inputFlag) - return program.StreamCircuit(conn, c.params, input, timing) + return program.StreamCircuit(conn, oti, c.params, input, timing) } func (c *Compiler) parse(source string, in io.Reader, logger *utils.Logger, diff --git a/compiler/ssa/streamer.go b/compiler/ssa/streamer.go index 50b210f5..daad2a4c 100644 --- a/compiler/ssa/streamer.go +++ b/compiler/ssa/streamer.go @@ -1,5 +1,5 @@ // -// Copyright (c) 2020-2022 Markku Rossi +// Copyright (c) 2020-2023 Markku Rossi // // All rights reserved. // @@ -24,8 +24,9 @@ import ( ) // StreamCircuit streams the program circuit into the P2P connection. -func (prog *Program) StreamCircuit(conn *p2p.Conn, params *utils.Params, - inputs *big.Int, timing *circuit.Timing) (circuit.IO, []*big.Int, error) { +func (prog *Program) StreamCircuit(conn *p2p.Conn, oti ot.OT, + params *utils.Params, inputs *big.Int, timing *circuit.Timing) ( + circuit.IO, []*big.Int, error) { var key [32]byte _, err := rand.Read(key[:]) @@ -105,71 +106,20 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, params *utils.Params, timing.Sample("Init", []string{circuit.FileSize(ioStats.Sum()).String()}) // Init oblivious transfer. - sender, err := ot.NewSender(2048) + err = oti.InitSender(conn) if err != nil { return nil, nil, err } - - // Send our public key. - pub := sender.PublicKey() - data := pub.N.Bytes() - if err := conn.SendData(data); err != nil { - return nil, nil, err - } - if err := conn.SendUint32(pub.E); err != nil { - return nil, nil, err - } - conn.Flush() - xfer := conn.Stats.Sub(ioStats) ioStats = conn.Stats timing.Sample("OT Init", []string{circuit.FileSize(xfer.Sum()).String()}) // Peer OTs its inputs. - for i := 0; i < prog.Inputs[1].Size; i++ { - bit, err := conn.ReceiveUint32() - if err != nil { - return nil, nil, err - } - wire := streaming.GetInput(circuit.Wire(bit)) - - var m0Buf, m1Buf ot.LabelData - m0Data := wire.L0.Bytes(&m0Buf) - m1Data := wire.L1.Bytes(&m1Buf) - - xfer, err := sender.NewTransfer(m0Data, m1Data) - if err != nil { - return nil, nil, err - } - - x0, x1 := xfer.RandomMessages() - if err := conn.SendData(x0); err != nil { - return nil, nil, err - } - if err := conn.SendData(x1); err != nil { - return nil, nil, err - } - conn.Flush() - - v, err := conn.ReceiveData() - if err != nil { - return nil, nil, err - } - xfer.ReceiveV(v) - - m0p, m1p, err := xfer.Messages() - if err != nil { - return nil, nil, err - } - if err := conn.SendData(m0p); err != nil { - return nil, nil, err - } - if err := conn.SendData(m1p); err != nil { - return nil, nil, err - } - conn.Flush() + err = oti.Send(streaming.GetInputs(prog.Inputs[0].Size, + prog.Inputs[1].Size)) + if err != nil { + return nil, nil, err } - xfer = conn.Stats.Sub(ioStats) ioStats = conn.Stats timing.Sample("Peer Inputs", @@ -577,7 +527,7 @@ func (prog *Program) StreamCircuit(conn *p2p.Conn, params *utils.Params, } result.SetBit(result, i, bit) } - data = result.Bytes() + data := result.Bytes() if err := conn.SendData(data); err != nil { return nil, nil, err }