diff --git a/mixing/mixclient/blame.go b/mixing/mixclient/blame.go index ccecc2352..81c197cce 100644 --- a/mixing/mixclient/blame.go +++ b/mixing/mixclient/blame.go @@ -11,6 +11,7 @@ import ( "fmt" "math/big" "sort" + "time" "github.com/decred/dcrd/chaincfg/chainhash" "github.com/decred/dcrd/dcrec/secp256k1/v4" @@ -20,6 +21,8 @@ import ( "github.com/decred/dcrd/wire" ) +var errBlameFailed = errors.New("blame failed") + // blamedIdentities identifies detected misbehaving peers. // // If a run returns a blamedIdentities error, these peers are immediately @@ -48,7 +51,7 @@ func (e blamedIdentities) String() string { } func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { - c.logf("Blaming for sid=%x", sesRun.sid[:]) + sesRun.logf("running blame assignment") mp := c.mixpool prs := sesRun.prs @@ -65,14 +68,11 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { } }() - err = c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - // Send initial secrets messages from any peers who detected - // misbehavior. - if !p.triggeredBlame { - return nil - } - return p.rs - }) + deadline := time.Now().Add(timeoutDuration) + + // Send initial secrets messages from any peers who detected + // misbehavior. + err = c.sendLocalPeerMsgs(ctx, deadline, sesRun, 0) if err != nil { return err } @@ -87,15 +87,17 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { rsHashes = append(rsHashes, rs.Hash()) } - // Send remaining secrets messages. - err = c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - if p.triggeredBlame { - p.triggeredBlame = false - return nil + // Send remaining secrets messages with observed RS hashes from the + // initial peers who published secrets. 
+ c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if !p.triggeredBlame { + if p.rs != nil { + p.rs.SeenSecrets = rsHashes + } } - p.rs.SeenSecrets = rsHashes - return p.rs + return nil }) + err = c.sendLocalPeerMsgs(ctx, deadline, sesRun, msgRS) if err != nil { return err } @@ -113,14 +115,14 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { } if len(rss) != len(sesRun.peers) { // Blame peers who did not send secrets - c.logf("received %d RSs for %d peers; blaming unresponsive peers", + sesRun.logf("received %d RSs for %d peers; blaming unresponsive peers", len(rss), len(sesRun.peers)) for _, p := range sesRun.peers { if p.rs != nil { continue } - c.logf("blaming %x for RS timeout", p.id[:]) + sesRun.logf("blaming %x for RS timeout", p.id[:]) blamed = append(blamed, *p.id) } return blamed @@ -142,10 +144,14 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { continue } id := &rs.Identity - c.logf("blaming %x for false failure accusation", id[:]) + sesRun.logf("blaming %x for false failure accusation", id[:]) blamed = append(blamed, *id) } - err = blamed + if len(blamed) > 0 { + err = blamed + } else { + err = errBlameFailed + } }() defer c.mu.Unlock() @@ -159,7 +165,7 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { KELoop: for _, p := range sesRun.peers { if p.ke == nil { - c.logf("blaming %x for missing messages", p.id[:]) + sesRun.logf("blaming %x for missing messages", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -169,7 +175,7 @@ KELoop: cm := p.rs.Commitment(c.blake256Hasher) c.blake256HasherMu.Unlock() if cm != p.ke.Commitment { - c.logf("blaming %x for false commitment, got %x want %x", + sesRun.logf("blaming %x for false commitment, got %x want %x", p.id[:], cm[:], p.ke.Commitment[:]) blamed = append(blamed, *p.id) continue @@ -177,7 +183,7 @@ KELoop: // Blame peers whose seed is not the correct length (will panic chacha20prng). if len(p.rs.Seed) != chacha20prng.SeedSize { - c.logf("blaming %x for bad seed size in RS message", p.id[:]) + sesRun.logf("blaming %x for bad seed size in RS message", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -187,7 +193,7 @@ KELoop: if mixing.InField(scratch.SetBytes(m)) { continue } - c.logf("blaming %x for SR message outside field", p.id[:]) + sesRun.logf("blaming %x for SR message outside field", p.id[:]) blamed = append(blamed, *p.id) continue KELoop } @@ -199,7 +205,7 @@ KELoop: // Recover derived key exchange from PRNG. 
p.kx, err = mixing.NewKX(p.prng) if err != nil { - c.logf("blaming %x for bad KX", p.id[:]) + sesRun.logf("blaming %x for bad KX", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -210,14 +216,14 @@ KELoop: case !bytes.Equal(p.ke.ECDH[:], p.kx.ECDHPublicKey.SerializeCompressed()): fallthrough case !bytes.Equal(p.ke.PQPK[:], p.kx.PQPublicKey[:]): - c.logf("blaming %x for KE public keys not derived from their PRNG", + sesRun.logf("blaming %x for KE public keys not derived from their PRNG", p.id[:]) blamed = append(blamed, *p.id) continue KELoop } publishedECDHPub, err := secp256k1.ParsePubKey(p.ke.ECDH[:]) if err != nil { - c.logf("blaming %x for unparsable pubkey") + sesRun.logf("blaming %x for unparsable pubkey") blamed = append(blamed, *p.id) continue } @@ -229,7 +235,7 @@ KELoop: start += mcount if uint32(len(p.rs.SlotReserveMsgs)) != mcount || uint32(len(p.rs.DCNetMsgs)) != mcount { - c.logf("blaming %x for bad message count", p.id[:]) + sesRun.logf("blaming %x for bad message count", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -261,19 +267,19 @@ KELoop: // from their PRNG. for i, p := range sesRun.peers { if p.ct == nil { - c.logf("blaming %x for missing messages", p.id[:]) + sesRun.logf("blaming %x for missing messages", p.id[:]) blamed = append(blamed, *p.id) continue } if len(recoveredCTs[i]) != len(p.ct.Ciphertexts) { - c.logf("blaming %x for different ciphertexts count %d != %d", + sesRun.logf("blaming %x for different ciphertexts count %d != %d", p.id[:], len(recoveredCTs[i]), len(p.ct.Ciphertexts)) blamed = append(blamed, *p.id) continue } for j := range p.ct.Ciphertexts { if !bytes.Equal(p.ct.Ciphertexts[j][:], recoveredCTs[i][j][:]) { - c.logf("blaming %x for different ciphertexts", p.id[:]) + sesRun.logf("blaming %x for different ciphertexts", p.id[:]) blamed = append(blamed, *p.id) break } @@ -294,7 +300,7 @@ KELoop: for _, pids := range shared { if len(pids) > 1 { for i := range pids { - c.logf("blaming %x for shared SR message", pids[i][:]) + sesRun.logf("blaming %x for shared SR message", pids[i][:]) } blamed = append(blamed, pids...) } @@ -306,7 +312,7 @@ KELoop: SRLoop: for i, p := range sesRun.peers { if p.sr == nil { - c.logf("blaming %x for missing messages", p.id[:]) + sesRun.logf("blaming %x for missing messages", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -325,7 +331,7 @@ SRLoop: var decapErr *mixing.DecapsulateError if errors.As(err, &decapErr) { submittingID := p.id - c.logf("blaming %x for unrecoverable secrets", submittingID[:]) + sesRun.logf("blaming %x for unrecoverable secrets", submittingID[:]) blamed = append(blamed, *submittingID) continue } @@ -343,7 +349,7 @@ SRLoop: // Blame when committed mix does not match provided. for k := range srMix { if srMix[k].Cmp(scratch.SetBytes(p.sr.DCMix[j][k])) != 0 { - c.logf("blaming %x for bad SR mix", p.id[:]) + sesRun.logf("blaming %x for bad SR mix", p.id[:]) blamed = append(blamed, *p.id) continue SRLoop } @@ -376,7 +382,7 @@ DCLoop: // deferred function) if no peers could be assigned blame is // not likely to be seen under this situation. if p.dc == nil { - c.logf("blaming %x for missing messages", p.id[:]) + sesRun.logf("blaming %x for missing messages", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -386,7 +392,7 @@ DCLoop: // message, and there must be mcount DC-net vectors. 
mcount := p.pr.MessageCount if uint32(len(p.dc.DCNet)) != mcount { - c.logf("blaming %x for missing DC mix vectors", p.id[:]) + sesRun.logf("blaming %x for missing DC mix vectors", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -406,7 +412,7 @@ DCLoop: // Blame when committed mix does not match provided. for k := 0; k < len(dcMix); k++ { if !dcMix.Equals(mixing.Vec(p.dc.DCNet[j])) { - c.logf("blaming %x for bad DC mix", p.id[:]) + sesRun.logf("blaming %x for bad DC mix", p.id[:]) blamed = append(blamed, *p.id) continue DCLoop } diff --git a/mixing/mixclient/client.go b/mixing/mixclient/client.go index 13b42353e..bf8422afd 100644 --- a/mixing/mixclient/client.go +++ b/mixing/mixclient/client.go @@ -64,7 +64,18 @@ const ( cmTimeout ) -func blameTimedOut(sesLog *sessionLogger, sesRun *sessionRun, timeoutMessage int) blamedIdentities { +// Constants specifying which peer messages to publish. +const ( + msgKE = 1 << iota + msgCT + msgSR + msgDC + msgFP + msgCM + msgRS +) + +func blameTimedOut(sesRun *sessionRun, timeoutMessage int) error { var blamed blamedIdentities var stage string for _, p := range sesRun.peers { @@ -91,8 +102,11 @@ func blameTimedOut(sesLog *sessionLogger, sesRun *sessionRun, timeoutMessage int } } } - sesLog.logf("blaming %x during run (%s timeout)", []identity(blamed), stage) - return blamed + if len(blamed) > 0 { + sesRun.logf("blaming %x during run (%s timeout)", []identity(blamed), stage) + return blamed + } + return errBlameFailed } // Wallet signs mix transactions and listens for and broadcasts mixing @@ -125,7 +139,6 @@ type Wallet interface { } type deadlines struct { - epoch time.Time recvKE time.Time recvCT time.Time recvSR time.Time @@ -146,41 +159,15 @@ func (d *deadlines) start(begin time.Time) { d.recvCM = add() } -func (d *deadlines) shift() { - d.recvKE = d.recvCT - d.recvCT = d.recvSR - d.recvSR = d.recvDC - d.recvDC = d.recvCM - d.recvCM = d.recvCM.Add(timeoutDuration) -} - -func (d *deadlines) restart() { - d.start(d.recvCM) -} - -// peer represents a participating client in a peer-to-peer mixing session. -// Some fields only pertain to peers created by this wallet, while the rest -// are used during blame assignment. -type peer struct { - ctx context.Context - client *Client - jitter time.Duration - - res chan error - - pub *secp256k1.PublicKey - priv *secp256k1.PrivateKey - id *identity // serialized pubkey - pr *wire.MsgMixPairReq - coinjoin *CoinJoin - kx *mixing.KX - +// peerRunState describes the peer state that changes across different +// sessions/runs. +type peerRunState struct { prngSeed [32]byte prng *chacha20prng.Reader + kx *mixing.KX // derived from PRNG - rs *wire.MsgMixSecrets - srMsg []*big.Int // random numbers for the exponential slot reservation mix - dcMsg wire.MixVect // anonymized messages to publish in XOR mix + srMsg []*big.Int // random (non-PRNG) numbers for the exponential slot reservation mix + dcMsg wire.MixVect // anonymized messages (HASH160s) to publish in XOR mix ke *wire.MsgMixKeyExchange ct *wire.MsgMixCiphertexts @@ -188,8 +175,9 @@ type peer struct { fp *wire.MsgMixFactoredPoly dc *wire.MsgMixDCNet cm *wire.MsgMixConfirm + rs *wire.MsgMixSecrets - // Unmixed positions. May change over multiple sessions/runs. + // Unmixed positions. myVk uint32 myStart uint32 @@ -204,12 +192,54 @@ type peer struct { // Whether peer misbehavior was detected by this peer, and initial // secrets will be revealed by c.blame(). triggeredBlame bool +} + +// peer represents a participating client in a peer-to-peer mixing session. 
+// Some fields only pertain to peers created by this wallet, while the rest +// are used during blame assignment. +type peer struct { + ctx context.Context + client *Client + jitter time.Duration + + res chan error + + pub *secp256k1.PublicKey + priv *secp256k1.PrivateKey + id *identity // serialized pubkey + pr *wire.MsgMixPairReq + coinjoin *CoinJoin + + peerRunState // Whether this peer represents a remote peer created from revealed secrets; // used during blaming. remote bool } +// cloneLocalPeer creates a new peer instance representing a local peer +// client, sharing the caller context and result channels, but resetting all +// per-session fields (if freshGen is true), or only copying the PRNG and +// fields directly derived from it (when freshGen is false). +func (p *peer) cloneLocalPeer(freshGen bool) *peer { + if p.remote { + panic("cloneLocalPeer: remote peer") + } + + p2 := *p + p2.peerRunState = peerRunState{} + + if !freshGen { + p2.prngSeed = p.prngSeed + p2.prng = p.prng + p2.kx = p.kx + p2.srMsg = p.srMsg + p2.dcMsg = p.dcMsg + } + + return &p2 +} + func newRemotePeer(pr *wire.MsgMixPairReq) *peer { return &peer{ id: &pr.Identity, @@ -229,32 +259,65 @@ func generateSecp256k1() (*secp256k1.PublicKey, *secp256k1.PrivateKey, error) { return publicKey, privateKey, nil } +type pendingPairing struct { + localPeers map[identity]*peer + pairing []byte +} + // pairedSessions tracks the waiting and in-progress mix sessions performed by // one or more local peers using compatible pairings. type pairedSessions struct { localPeers map[identity]*peer pairing []byte runs []sessionRun + + epoch time.Time + deadlines + + // Track state of pairing completion. + // + // An agreed-upon pairing means there was agreement on an initial set + // of peers/KEs. However, various situations (such as blaming + // misbehaving peers, or additional session formations after hitting + // size limits) must not consider the previous sessions with all + // received KEs. The peer agreement index clamps down on this by only + // considering runs in + // pairedSessions.runs[pairedSessions.peerAgreementRunIdx:] + peerAgreementRunIdx int + peerAgreement bool + + donePairingOnce sync.Once } type sessionRun struct { sid [32]byte + idx int // in pairedSessions.runs mtot uint32 // Whether this run must generate fresh KX keys, SR/DC messages. freshGen bool - deadlines - // Peers sorted by PR hashes. Each peer's myVk is its index in this // slice. - prs []*wire.MsgMixPairReq - peers []*peer - mcounts []uint32 - roots []*big.Int + localPeers map[identity]*peer + prs []*wire.MsgMixPairReq + kes []*wire.MsgMixKeyExchange // set by completePairing + peers []*peer + mcounts []uint32 + roots []*big.Int // Finalized coinjoin of a successful run. cj *CoinJoin + + logger slog.Logger +} + +func (s *sessionRun) logf(format string, args ...interface{}) { + if s.logger == nil { + return + } + + s.logger.Debugf("sid=%x/%d "+format, append([]interface{}{s.sid[:], s.idx}, args...)...) } type queueWork struct { @@ -273,9 +336,9 @@ type Client struct { // Pending and active sessions and peers (both local and, when // blaming, remote). 
- pairings map[string]*pairedSessions - height uint32 - mu sync.Mutex + pendingPairings map[string]*pendingPairing + height uint32 + mu sync.Mutex warming chan struct{} workQueue chan *queueWork @@ -303,15 +366,15 @@ func NewClient(w Wallet) *Client { height, _ := w.BestBlock() return &Client{ - atomicPRFlags: uint32(prFlags), - wallet: w, - mixpool: w.Mixpool(), - pairings: make(map[string]*pairedSessions), - warming: make(chan struct{}), - workQueue: make(chan *queueWork, runtime.NumCPU()), - blake256Hasher: blake256.New(), - epoch: w.Mixpool().Epoch(), - height: height, + atomicPRFlags: uint32(prFlags), + wallet: w, + mixpool: w.Mixpool(), + pendingPairings: make(map[string]*pendingPairing), + warming: make(chan struct{}), + workQueue: make(chan *queueWork, runtime.NumCPU()), + blake256Hasher: blake256.New(), + epoch: w.Mixpool().Epoch(), + height: height, } } @@ -343,23 +406,6 @@ func (c *Client) logerrf(format string, args ...interface{}) { c.logger.Errorf(format, args...) } -func (c *Client) sessionLog(sid [32]byte) *sessionLogger { - return &sessionLogger{sid: sid, logger: c.logger} -} - -type sessionLogger struct { - sid [32]byte - logger slog.Logger -} - -func (l *sessionLogger) logf(format string, args ...interface{}) { - if l.logger == nil { - return - } - - l.logger.Debugf("sid=%x "+format, append([]interface{}{l.sid[:]}, args...)...) -} - // Run runs the client manager, blocking until after the context is // cancelled. func (c *Client) Run(ctx context.Context) error { @@ -426,57 +472,109 @@ func (c *Client) forLocalPeers(ctx context.Context, s *sessionRun, f func(p *pee } type delayedMsg struct { - t time.Time - m mixing.Message - p *peer + sendTime time.Time + deadline time.Time + m mixing.Message + p *peer } -func (c *Client) sendLocalPeerMsgs(ctx context.Context, s *sessionRun, mayTriggerBlame bool, - m func(p *peer) mixing.Message) error { - - msgs := make([]delayedMsg, 0, len(s.peers)) - +func (c *Client) sendLocalPeerMsgs(ctx context.Context, deadline time.Time, s *sessionRun, msgMask uint) error { now := time.Now() + + msgs := make([]delayedMsg, 0, len(s.peers)*bits.OnesCount(msgMask)) for _, p := range s.peers { if p.remote || p.ctx.Err() != nil { continue } - msg := m(p) - if mayTriggerBlame && p.triggeredBlame { - msg = p.rs + msg := delayedMsg{ + sendTime: now.Add(p.msgJitter()), + deadline: deadline, + m: nil, + p: p, + } + msgMask := msgMask + if p.triggeredBlame { + msgMask |= msgRS + } + if msgMask&msgKE == msgKE && p.ke != nil { + msg.m = p.ke + msgs = append(msgs, msg) + } + if msgMask&msgCT == msgCT && p.ct != nil { + msg.m = p.ct + msgs = append(msgs, msg) + } + if msgMask&msgSR == msgSR && p.sr != nil { + msg.m = p.sr + msgs = append(msgs, msg) + } + if msgMask&msgFP == msgFP && p.fp != nil { + msg.m = p.fp + msgs = append(msgs, msg) + } + if msgMask&msgDC == msgDC && p.dc != nil { + msg.m = p.dc + msgs = append(msgs, msg) + } + if msgMask&msgCM == msgCM && p.cm != nil { + msg.m = p.cm + msgs = append(msgs, msg) + } + if msgMask&msgRS == msgRS && p.rs != nil { + msg.m = p.rs + msgs = append(msgs, msg) } - msgs = append(msgs, delayedMsg{ - t: now.Add(p.msgJitter()), - m: msg, - p: p, - }) } - sort.Slice(msgs, func(i, j int) bool { - return msgs[i].t.Before(msgs[j].t) + sort.SliceStable(msgs, func(i, j int) bool { + return msgs[i].sendTime.Before(msgs[j].sendTime) }) - resChans := make([]chan error, 0, len(s.peers)) + nilPeerMsg := func(p *peer, msg mixing.Message) { + switch msg.(type) { + case *wire.MsgMixKeyExchange: + p.ke = nil + case 
*wire.MsgMixCiphertexts: + p.ct = nil + case *wire.MsgMixSlotReserve: + p.sr = nil + case *wire.MsgMixFactoredPoly: + p.fp = nil + case *wire.MsgMixDCNet: + p.dc = nil + case *wire.MsgMixConfirm: + p.cm = nil + case *wire.MsgMixSecrets: + p.rs = nil + } + } + + errs := make([]error, 0, len(msgs)) + + var sessionCanceledState bool + sessionCanceled := func() { + if sessionCanceledState { + return + } + if err := ctx.Err(); err != nil { + err := fmt.Errorf("session cancelled: %w", err) + errs = append(errs, err) + sessionCanceledState = true + } + } + for i := range msgs { - res := make(chan error, 1) - resChans = append(resChans, res) m := msgs[i] - time.Sleep(time.Until(m.t)) - qsend := &queueWork{ - p: m.p, - f: func(p *peer) error { - return p.signAndSubmit(m.m) - }, - res: res, - } select { case <-ctx.Done(): - res <- ctx.Err() - case c.workQueue <- qsend: + sessionCanceled() + continue + case <-time.After(time.Until(m.sendTime)): + } + err := m.p.signAndSubmit(m.deadline, m.m) + if err != nil { + nilPeerMsg(m.p, m.m) + errs = append(errs, err) } - } - var errs = make([]error, len(resChans)) - for i := range errs { - errs[i] = <-resChans[i] } return errors.Join(errs...) } @@ -545,9 +643,9 @@ func (c *Client) testTick() { c.testTickC <- struct{}{} } -func (c *Client) testHook(stage hook, s *sessionRun, p *peer) { +func (c *Client) testHook(stage hook, ps *pairedSessions, s *sessionRun, p *peer) { if hook, ok := c.testHooks[stage]; ok { - hook(c, s, p) + hook(c, ps, s, p) } } @@ -564,11 +662,13 @@ func (p *peer) signAndHash(m mixing.Message) error { return nil } -func (p *peer) submit(m mixing.Message) error { - return p.client.wallet.SubmitMixMessage(p.ctx, m) +func (p *peer) submit(deadline time.Time, m mixing.Message) error { + ctx, cancel := context.WithDeadline(p.ctx, deadline) + defer cancel() + return p.client.wallet.SubmitMixMessage(ctx, m) } -func (p *peer) signAndSubmit(m mixing.Message) error { +func (p *peer) signAndSubmit(deadline time.Time, m mixing.Message) error { if m == nil { return nil } @@ -576,15 +676,22 @@ func (p *peer) signAndSubmit(m mixing.Message) error { if err != nil { return err } - return p.submit(m) + return p.submit(deadline, m) +} + +func (c *Client) newPendingPairing(pairing []byte) *pendingPairing { + return &pendingPairing{ + localPeers: make(map[identity]*peer), + pairing: pairing, + } } -func (c *Client) newPairings(pairing []byte, peers map[identity]*peer) *pairedSessions { - if peers == nil { - peers = make(map[identity]*peer) +func (c *Client) newPairedSessions(pairing []byte, pendingPeers map[identity]*peer) *pairedSessions { + if pendingPeers == nil { + pendingPeers = make(map[identity]*peer) } ps := &pairedSessions{ - localPeers: peers, + localPeers: pendingPeers, pairing: pairing, runs: nil, } @@ -655,38 +762,40 @@ func (c *Client) epochTicker(ctx context.Context) error { c.mixpool.RemoveConfirmedSessions() c.expireMessages() - for _, p := range c.pairings { + for _, p := range c.pendingPairings { prs := c.mixpool.CompatiblePRs(p.pairing) prsMap := make(map[identity]struct{}) for _, pr := range prs { prsMap[pr.Identity] = struct{}{} } - // Clone the p.localPeers map, only including PRs - // currently accepted to mixpool. Adding additional - // waiting local peers must not add more to the map in - // use by pairSession, and deleting peers in a formed - // session from the pending map must not inadvertently - // remove from pairSession's ps.localPeers map. 
+ // Clone the pending peers map, only including peers + // with PRs currently accepted to mixpool. Adding + // additional waiting local peers must not add more to + // the map in use by pairSession, and deleting peers + // in a formed session from the pending map must not + // inadvertently remove from pairSession's peers map. localPeers := make(map[identity]*peer) for id, peer := range p.localPeers { if _, ok := prsMap[id]; ok { - localPeers[id] = peer + localPeers[id] = peer.cloneLocalPeer(true) } } + + // Even when we know there are not enough total peers + // to meet the minimum peer requirement, a run is + // still formed so that KEs can be broadcast; this + // must be done so peers can fetch missing PRs. + c.logf("Have %d compatible/%d local PRs waiting for pairing %x", len(prs), len(localPeers), p.pairing) - ps := *p - ps.localPeers = localPeers - for id, peer := range p.localPeers { - ps.localPeers[id] = peer - } // pairSession calls Done once the session is formed // and the selected peers have been removed from then // pending pairing. c.pairingWG.Add(1) - go c.pairSession(ctx, &ps, prs, epoch) + ps := c.newPairedSessions(p.pairing, localPeers) + go c.pairSession(ctx, ps, prs, epoch) } c.mu.Unlock() } @@ -742,20 +851,21 @@ func (c *Client) Dicemix(ctx context.Context, cj *CoinJoin) error { c.logf("Created local peer id=%x PR=%s", p.id[:], p.pr.Hash()) c.mu.Lock() - pairing := c.pairings[string(pairingID)] - if pairing == nil { - pairing = c.newPairings(pairingID, nil) - c.pairings[string(pairingID)] = pairing + pending := c.pendingPairings[string(pairingID)] + if pending == nil { + pending = c.newPendingPairing(pairingID) + c.pendingPairings[string(pairingID)] = pending } - pairing.localPeers[*p.id] = p + pending.localPeers[*p.id] = p c.mu.Unlock() - err = p.submit(pr) + deadline := time.Now().Add(timeoutDuration) + err = p.submit(deadline, pr) if err != nil { c.mu.Lock() - delete(pairing.localPeers, *p.id) - if len(pairing.localPeers) == 0 { - delete(c.pairings, string(pairingID)) + delete(pending.localPeers, *p.id) + if len(pending.localPeers) == 0 { + delete(c.pendingPairings, string(pairingID)) } c.mu.Unlock() return err @@ -786,28 +896,37 @@ func (c *Client) ExpireMessages(height uint32) { func (c *Client) expireMessages() { c.mixpool.ExpireMessages(c.height) - for pairID, ps := range c.pairings { - for id, p := range ps.localPeers { - prHash := p.pr.Hash() + for pairID, p := range c.pendingPairings { + for id, peer := range p.localPeers { + prHash := peer.pr.Hash() if !c.mixpool.HaveMessage(&prHash) { - delete(ps.localPeers, id) + delete(p.localPeers, id) // p.res is buffered. If the write is // blocked, we have already served this peer // or sent another error. select { - case p.res <- expiredPRErr(p.pr): + case peer.res <- expiredPRErr(peer.pr): default: } } } - if len(ps.localPeers) == 0 { - delete(c.pairings, pairID) + if len(p.localPeers) == 0 { + delete(c.pendingPairings, pairID) } } } +// donePairing decrements the client pairing waitgroup for the paired +// sessions. This is protected by a sync.Once and safe to call multiple +// times. +func (c *Client) donePairing(ps *pairedSessions) { + ps.donePairingOnce.Do(c.pairingWG.Done) +} + func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wire.MsgMixPairReq, epoch time.Time) { + defer c.donePairing(ps) + // This session pairing attempt, and calling pairSession again with // fresh PRs, must end before the next call to pairSession for this // pairing type. 
@@ -824,10 +943,9 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir defer func() { c.removeUnresponsiveDuringEpoch(unresponsive, unixEpoch) - unmixedPeers := ps.localPeers if mixedSession != nil && mixedSession.cj != nil { for _, pr := range mixedSession.prs { - delete(unmixedPeers, pr.Identity) + delete(ps.localPeers, pr.Identity) } // XXX: Removing these later in the background is a hack to @@ -836,218 +954,302 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir // for CM messages, which will increment ban score. go func() { time.Sleep(10 * time.Second) - c.logf("sid=%x removing mixed session completed with transaction %v", - mixedSession.sid[:], mixedSession.cj.txHash) + mixedSession.logf("removing mixed session completed "+ + "with transaction %v", mixedSession.cj.txHash) c.mixpool.RemoveSession(mixedSession.sid) }() } - if len(unmixedPeers) == 0 { + if len(ps.localPeers) == 0 { return } - for _, p := range unmixedPeers { - p.ke = nil - p.ct = nil - p.sr = nil - p.fp = nil - p.dc = nil - p.cm = nil - p.rs = nil - } - c.mu.Lock() - pendingPairing := c.pairings[string(ps.pairing)] + pendingPairing := c.pendingPairings[string(ps.pairing)] if pendingPairing == nil { - pendingPairing = c.newPairings(ps.pairing, unmixedPeers) - c.pairings[string(ps.pairing)] = pendingPairing + pendingPairing = c.newPendingPairing(ps.pairing) + c.pendingPairings[string(ps.pairing)] = pendingPairing } else { - for id, p := range unmixedPeers { + for id, p := range ps.localPeers { prHash := p.pr.Hash() if p.ctx.Err() == nil && c.mixpool.HaveMessage(&prHash) { - pendingPairing.localPeers[id] = p + pendingPairing.localPeers[id] = p.cloneLocalPeer(true) } } } c.mu.Unlock() }() - var madePairing bool - defer func() { - if !madePairing { - c.pairingWG.Done() - } - }() + ps.epoch = epoch + ps.deadlines.start(epoch) - var sesLog *sessionLogger - var currentRun *sessionRun - var rerun *sessionRun - var d deadlines - d.epoch = epoch - d.start(epoch) - for { - if rerun == nil { - sid := mixing.SortPRsForSession(prs, unixEpoch) - sesLog = c.sessionLog(sid) - - sesRun := sessionRun{ - sid: sid, - prs: prs, - freshGen: true, - deadlines: d, - mcounts: make([]uint32, 0, len(prs)), - } - ps.runs = append(ps.runs, sesRun) - currentRun = &ps.runs[len(ps.runs)-1] + sid := mixing.SortPRsForSession(prs, unixEpoch) + ps.runs = append(ps.runs, sessionRun{ + sid: sid, + prs: prs, + freshGen: true, + mcounts: make([]uint32, 0, len(prs)), + }) + newRun := &ps.runs[len(ps.runs)-1] - } else { - ps.runs = append(ps.runs, *rerun) - currentRun = &ps.runs[len(ps.runs)-1] - sesLog = c.sessionLog(currentRun.sid) - // rerun is not assigned nil here to please the - // linter. All code paths that reenter this loop will - // set it again. - } + for { + if newRun != nil { + newRun.idx = len(ps.runs) - 1 + newRun.logger = c.logger + + prs = newRun.prs + prHashes := make([]chainhash.Hash, len(prs)) + newRun.localPeers = make(map[identity]*peer) + var m uint32 + var localPeerCount, localUncancelledCount int + for i, pr := range prs { + prHashes[i] = prs[i].Hash() + + // Peer clones must be made from the previous run's + // local peer objects (if any) to preserve existing + // PRNGs and derived secrets and keys. 
+ peerMap := ps.localPeers + if newRun.idx > 0 { + peerMap = ps.runs[newRun.idx-1].localPeers + } + p := peerMap[pr.Identity] + if p != nil { + p = p.cloneLocalPeer(newRun.freshGen) + localPeerCount++ + if p.ctx.Err() == nil { + localUncancelledCount++ + } + newRun.localPeers[*p.id] = p + } else { + p = newRemotePeer(pr) + } + p.myVk = uint32(i) + p.myStart = m - prs = currentRun.prs - prHashes := make([]chainhash.Hash, len(prs)) - for i := range prs { - prHashes[i] = prs[i].Hash() - } + newRun.peers = append(newRun.peers, p) + newRun.mcounts = append(newRun.mcounts, p.pr.MessageCount) - var m uint32 - var localPeerCount, localUncancelledCount int - for i, pr := range prs { - p := ps.localPeers[pr.Identity] - if p != nil { - localPeerCount++ - if p.ctx.Err() == nil { - localUncancelledCount++ - } - } else { - p = newRemotePeer(pr) + m += p.pr.MessageCount } - p.myVk = uint32(i) - p.myStart = m + newRun.mtot = m - currentRun.peers = append(currentRun.peers, p) - currentRun.mcounts = append(currentRun.mcounts, p.pr.MessageCount) + newRun.logf("created session for pairid=%x from %d total %d local PRs %s", + ps.pairing, len(prHashes), localPeerCount, prHashes) - m += p.pr.MessageCount - } - currentRun.mtot = m - - sesLog.logf("created session for pairid=%x from %d total %d local PRs %s", - ps.pairing, len(prHashes), localPeerCount, prHashes) + if localUncancelledCount == 0 { + newRun.logf("no more active local peers") + return + } - if localUncancelledCount == 0 { - sesLog.logf("no more active local peers") - return + newRun = nil } - sesLog.logf("len(ps.runs)=%d", len(ps.runs)) - - c.testHook(hookBeforeRun, currentRun, nil) - err := c.run(ctx, ps, &madePairing) + c.testHook(hookBeforeRun, ps, &ps.runs[len(ps.runs)-1], nil) + r, err := c.run(ctx, ps) + var rerun *sessionRun + var altses *alternateSession var sizeLimitedErr *sizeLimited - if errors.As(err, &sizeLimitedErr) { - if len(sizeLimitedErr.prs) < MinPeers { - sesLog.logf("Aborting session with too few remaining peers") - return - } - - d.shift() + var blamed blamedIdentities + var revealedSecrets bool + var requirePeerAgreement bool + switch { + case errors.Is(err, errOnlyKEsBroadcasted): + // When only KEs are broadcasted, the session was not viable + // due to lack of peers or a peer capable of solving the + // roots. Return without setting the mixed session. + return - sesLog.logf("Recreating as session %x due to standard tx size limits (pairid=%x)", - sizeLimitedErr.sid[:], ps.pairing) + case errors.As(err, &altses): + // If this errored or has too few peers, keep + // retrying previous attempts until next epoch, + // instead of just going away. + if altses.err != nil { + r.logf("Unable to recreate session: %v", altses.err) + ps.deadlines.start(time.Now()) + continue + } - rerun = &sessionRun{ - sid: sizeLimitedErr.sid, - prs: sizeLimitedErr.prs, - freshGen: false, - deadlines: d, + if r.sid != altses.sid { + r.logf("Recreating as session %x (pairid=%x)", altses.sid, ps.pairing) + unresponsive = append(unresponsive, altses.unresponsive...) } - continue - } - var altses *alternateSession - if errors.As(err, &altses) { - if altses.err != nil { - sesLog.logf("Unable to recreate session: %v", altses.err) - return + // Required minimum run index is not incremented for + // reformed sessions without peer agreement: we must + // consider receiving KEs from peers who reformed into + // the same session we previously attempted. 
+ rerun = &sessionRun{ + sid: altses.sid, + prs: altses.prs, + freshGen: false, } + err = nil - if len(altses.prs) < MinPeers { - sesLog.logf("Aborting session with too few remaining peers") + case errors.As(err, &sizeLimitedErr): + if len(sizeLimitedErr.prs) < MinPeers { + r.logf("Aborting session with too few remaining peers") return } - d.shift() - - if currentRun.sid != altses.sid { - sesLog.logf("Recreating as session %x (pairid=%x)", altses.sid, ps.pairing) - unresponsive = append(unresponsive, altses.unresponsive...) - } + r.logf("Recreating as session %x due to standard tx size limits (pairid=%x)", + sizeLimitedErr.sid[:], ps.pairing) rerun = &sessionRun{ - sid: altses.sid, - prs: altses.prs, - freshGen: false, - deadlines: d, + sid: sizeLimitedErr.sid, + prs: sizeLimitedErr.prs, + freshGen: false, } - continue - } + requirePeerAgreement = true + err = nil - var blamed blamedIdentities - revealedSecrets := false - if errors.Is(err, errTriggeredBlame) || errors.Is(err, mixpool.ErrSecretsRevealed) { + case errors.Is(err, errTriggeredBlame) || errors.Is(err, mixpool.ErrSecretsRevealed): revealedSecrets = true - err := c.blame(ctx, currentRun) + err := c.blame(ctx, r) if !errors.As(err, &blamed) { - sesLog.logf("Aborting session for failed blame assignment: %v", err) + r.logf("Aborting session for failed blame assignment: %v", err) return } + requirePeerAgreement = true + // err = nil would be an ineffectual assignment here; + // blamed is non-nil and the following if block will + // always be entered. } + if blamed != nil || errors.As(err, &blamed) { - sesLog.logf("Identified %d blamed peers %x", len(blamed), []identity(blamed)) + r.logf("Identified %d blamed peers %x", len(blamed), []identity(blamed)) + + if len(r.prs)-len(blamed) < MinPeers { + r.logf("Aborting session with too few remaining peers") + return + } // Blamed peers were identified, either during the run // in a way that all participants could have observed, // or following revealing secrets and blame // assignment. Begin a rerun excluding these peers. - rerun = excludeBlamed(currentRun, blamed, revealedSecrets) + rerun = excludeBlamed(r, unixEpoch, blamed, revealedSecrets) + requirePeerAgreement = true + err = nil + } + + if rerun != nil { + if ps.runs[len(ps.runs)-1].sid == rerun.sid { + r := &ps.runs[len(ps.runs)-1] + r.logf("recreated session matches previous try; " + + "not creating new run states") + if ps.peerAgreementRunIdx >= len(ps.runs) { + // XXX shouldn't happen but fix up anyways + r.logf("reverting incremented peer agreement index") + ps.peerAgreementRunIdx = r.idx + } + } else { + ps.runs = append(ps.runs, *rerun) + newRun = &ps.runs[len(ps.runs)-1] + } + ps.deadlines.start(time.Now()) + if requirePeerAgreement { + ps.peerAgreementRunIdx = len(ps.runs) - 1 + } continue } - // When only KEs are broadcasted, the session was not viable - // due to lack of peers or a peer capable of solving the - // roots. Return without setting the mixed session. - if errors.Is(err, errOnlyKEsBroadcasted) { - return - } - // Any other run error is not actionable. if err != nil { - sesLog.logf("Run error: %v", err) + r.logf("Run error: %v", err) return } - mixedSession = currentRun + mixedSession = r return } } -func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) error { +var errIncompletePairing = errors.New("incomplete pairing") + +// completePairing waits for all KEs to form a completed session. Completed +// pairings are checked for in the order the sessions were attempted. 
+func (c *Client) completePairing(ctx context.Context, ps *pairedSessions) (*sessionRun, error) { + mp := c.mixpool + res := make(chan *sessionRun, len(ps.runs)) + errs := make(chan error, len(ps.runs)) + var errCount int + + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + wrappedCtx, cancel := context.WithCancel(ctx) + defer cancel() + + recvKEs := func(ctx context.Context, sesRun *sessionRun) ([]*wire.MsgMixKeyExchange, error) { + rcv := new(mixpool.Received) + rcv.Sid = sesRun.sid + rcv.KEs = make([]*wire.MsgMixKeyExchange, 0, len(sesRun.prs)) + ctx, cancel := context.WithDeadline(ctx, ps.deadlines.recvKE) + defer cancel() + err := mp.Receive(ctx, rcv) + if len(rcv.KEs) == len(sesRun.prs) { + return rcv.KEs, nil + } + if err == nil { + err = errIncompletePairing + } + return nil, err + } + + var wg sync.WaitGroup + defer func() { + cancel() + wg.Wait() + }() + + for i := ps.peerAgreementRunIdx; i < len(ps.runs); i++ { + sr := &ps.runs[i] + + // Initially check the mempool with the pre-canceled context + // to return immediately with KEs received so far. + kes, err := recvKEs(canceledCtx, sr) + if err == nil { + sr.kes = kes + return sr, nil + } + + wg.Add(1) + go func() { + defer wg.Done() + kes, err := recvKEs(wrappedCtx, sr) + if err == nil { + sr.kes = kes + res <- sr + } else { + errs <- err + } + }() + } + + for { + select { + case sesRun := <-res: + return sesRun, nil + case err := <-errs: + errCount++ + if errCount == len(ps.runs[ps.peerAgreementRunIdx:]) { + return nil, err + } + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +func (c *Client) run(ctx context.Context, ps *pairedSessions) (sesRun *sessionRun, err error) { var blamed blamedIdentities mp := c.wallet.Mixpool() - sesRun := &ps.runs[len(ps.runs)-1] + sesRun = &ps.runs[len(ps.runs)-1] prs := sesRun.prs - d := &sesRun.deadlines - unixEpoch := uint64(d.epoch.Unix()) - - sesLog := c.sessionLog(sesRun.sid) + d := &ps.deadlines + unixEpoch := uint64(ps.epoch.Unix()) // A map of identity public keys to their PR position sort all // messages in the same order as the PRs are ordered. 
@@ -1056,15 +1258,11 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) identityIndices[pr.Identity] = i } - seenPRs := make([]chainhash.Hash, len(prs)) - for i := range prs { - seenPRs[i] = prs[i].Hash() - } - - err := c.forLocalPeers(ctx, sesRun, func(p *peer) error { - p.coinjoin.resetUnmixed(prs) + freshGen := sesRun.freshGen + err = c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if freshGen { + p.ke = nil - if sesRun.freshGen { // Generate a new PRNG seed rand.Read(p.prngSeed[:]) p.prng = chacha20prng.New(p.prngSeed[:], 0) @@ -1102,45 +1300,54 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } - // Perform key exchange - srMsgBytes := make([][]byte, len(p.srMsg)) - for i := range p.srMsg { - srMsgBytes[i] = p.srMsg[i].Bytes() - } - rs := wire.NewMsgMixSecrets(*p.id, sesRun.sid, 0, - p.prngSeed, srMsgBytes, p.dcMsg) - c.blake256HasherMu.Lock() - commitment := rs.Commitment(c.blake256Hasher) - c.blake256HasherMu.Unlock() - ecdhPub := *(*[33]byte)(p.kx.ECDHPublicKey.SerializeCompressed()) - pqPub := *p.kx.PQPublicKey - ke := wire.NewMsgMixKeyExchange(*p.id, sesRun.sid, unixEpoch, 0, - uint32(identityIndices[*p.id]), ecdhPub, pqPub, commitment, - seenPRs) - - p.ke = ke - p.rs = rs + if p.ke == nil { + seenPRs := make([]chainhash.Hash, len(prs)) + for i := range prs { + seenPRs[i] = prs[i].Hash() + } + + // Perform key exchange + srMsgBytes := make([][]byte, len(p.srMsg)) + for i := range p.srMsg { + srMsgBytes[i] = p.srMsg[i].Bytes() + } + rs := wire.NewMsgMixSecrets(*p.id, sesRun.sid, 0, + p.prngSeed, srMsgBytes, p.dcMsg) + c.blake256HasherMu.Lock() + commitment := rs.Commitment(c.blake256Hasher) + c.blake256HasherMu.Unlock() + ecdhPub := *(*[33]byte)(p.kx.ECDHPublicKey.SerializeCompressed()) + pqPub := *p.kx.PQPublicKey + ke := wire.NewMsgMixKeyExchange(*p.id, sesRun.sid, unixEpoch, 0, + uint32(identityIndices[*p.id]), ecdhPub, pqPub, commitment, + seenPRs) + + p.ke = ke + p.ct = nil + p.sr = nil + p.fp = nil + p.dc = nil + p.cm = nil + p.rs = rs + } return nil }) if err != nil { - sesLog.logf("%v", err) + sesRun.logf("%v", err) + } else { + sesRun.freshGen = false } - err = c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - if p.ke == nil { - return nil - } - return p.ke - }) + err = c.sendLocalPeerMsgs(ctx, ps.deadlines.recvKE, sesRun, msgKE) if err != nil { - sesLog.logf("%v", err) + sesRun.logf("%v", err) } // Only continue attempting to form the session if there are minimum // peers available and at least one of them is capable of solving the // roots. if len(prs) < MinPeers { - sesLog.logf("Pairing %x: minimum peer requirement unmet", ps.pairing) - return errOnlyKEsBroadcasted + sesRun.logf("pairing %x: minimum peer requirement unmet", ps.pairing) + return sesRun, errOnlyKEsBroadcasted } haveSolver := false for _, pr := range prs { @@ -1150,8 +1357,8 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if !haveSolver { - sesLog.logf("Pairing %x: no solver available", ps.pairing) - return errOnlyKEsBroadcasted + sesRun.logf("pairing %x: no solver available", ps.pairing) + return sesRun, errOnlyKEsBroadcasted } // Receive key exchange messages. @@ -1168,57 +1375,51 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // All KE messages that match the pairing ID are received, and each // seen PRs slice is checked. PRs that were never followed up by a KE // are immediately excluded. 
- var kes []*wire.MsgMixKeyExchange - recvKEs := func(sesRun *sessionRun) (kes []*wire.MsgMixKeyExchange, err error) { - rcv := new(mixpool.Received) - rcv.Sid = sesRun.sid - rcv.KEs = make([]*wire.MsgMixKeyExchange, 0, len(sesRun.prs)) - ctx, cancel := context.WithDeadline(ctx, d.recvKE) - defer cancel() - err = mp.Receive(ctx, rcv) - if ctx.Err() != nil { - err = fmt.Errorf("session %x KE receive context cancelled: %w", - sesRun.sid[:], ctx.Err()) - } - return rcv.KEs, err - } - - switch { - case !*madePairing: - // Receive KEs for the last attempted session. Local - // peers may have been modified (new keys generated, and myVk - // indexes changed) if this is a recreated session, and we - // cannot continue mix using these messages. - // - // XXX: do we need to keep peer info available for previous - // session attempts? It is possible that previous sessions - // may be operable now if all wallets have come to agree on a - // previous session we also tried to form. - kes, err = recvKEs(sesRun) - if err == nil && len(kes) == len(sesRun.prs) { - break - } - - // Alternate session needs to be attempted. Do not form an + completedSesRun, err := c.completePairing(ctx, ps) + if err != nil { + // Alternate session may need to be attempted. Do not form an // alternate session if we are about to enter into the next // epoch. The session forming will be performed by a new // goroutine started by the epoch ticker, possibly with // additional PRs. - nextEpoch := d.epoch.Add(c.epoch) + nextEpoch := ps.epoch.Add(c.epoch) if time.Now().Add(timeoutDuration).After(nextEpoch) { c.logf("Aborting session %x after %d attempts", sesRun.sid[:], len(ps.runs)) - return errOnlyKEsBroadcasted + return sesRun, errOnlyKEsBroadcasted } - return c.alternateSession(ps.pairing, sesRun.prs, d) + // If peer agreement was never established, alternate sessions + // based on the seen PRs must be formed. + if !ps.peerAgreement { + return sesRun, c.alternateSession(ps, sesRun.prs) + } - default: - kes, err = recvKEs(sesRun) - if err != nil { - return err + return sesRun, err + } + + if completedSesRun != sesRun { + completedSesRun.logf("replacing previous session attempt %x/%d", + sesRun.sid[:], sesRun.idx) + + // Reset variables created from assuming the + // final session run. + sesRun = completedSesRun + prs = sesRun.prs + identityIndices = make(map[identity]int) + for i, pr := range prs { + identityIndices[pr.Identity] = i } } + kes := sesRun.kes + sesRun.logf("received all %d KEs", len(kes)) + + // The coinjoin structure is commonly referenced by all instances of + // the local peers; reset all of them from the initial paired session + // attempt, and not only those in the last session run. + for _, p := range ps.localPeers { + p.coinjoin.resetUnmixed(prs) + } // Before confirming the pairing, check all of the agreed-upon PRs // that they will not result in a coinjoin transaction that exceeds @@ -1227,7 +1428,10 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // PRs are randomly ordered in each epoch based on the session ID, so // they can be iterated in order to discover any PR that would // increase the final coinjoin size above the limits. 
- if !*madePairing { + if !ps.peerAgreement { + ps.peerAgreement = true + ps.peerAgreementRunIdx = sesRun.idx + var sizeExcluded []*wire.MsgMixPairReq var cjSize coinjoinSize for _, pr := range sesRun.prs { @@ -1256,7 +1460,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } sid := mixing.SortPRsForSession(kept, unixEpoch) - return &sizeLimited{ + return sesRun, &sizeLimited{ prs: kept, sid: sid, excluded: sizeExcluded, @@ -1270,22 +1474,23 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } - // Remove paired local peers from waiting pairing. - if !*madePairing { - c.mu.Lock() - if waiting := c.pairings[string(ps.pairing)]; waiting != nil { - for id := range ps.localPeers { - delete(waiting.localPeers, id) - } - if len(waiting.localPeers) == 0 { - delete(c.pairings, string(ps.pairing)) - } + // Remove paired local peers from pending pairings. + // + // XXX might want to keep these instead of racing to add them back if + // this mix doesn't run to completion, and we start next epoch without + // some of our own peers. + c.mu.Lock() + if pending := c.pendingPairings[string(ps.pairing)]; pending != nil { + for id := range sesRun.localPeers { + delete(pending.localPeers, id) + } + if len(pending.localPeers) == 0 { + delete(c.pendingPairings, string(ps.pairing)) } - c.mu.Unlock() - - *madePairing = true - c.pairingWG.Done() } + c.mu.Unlock() + + c.donePairing(ps) sort.Slice(kes, func(i, j int) bool { a := identityIndices[kes[i].Identity] @@ -1299,13 +1504,13 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) involvesLocalPeers := false for _, ke := range kes { - if ps.localPeers[ke.Identity] != nil { + if sesRun.localPeers[ke.Identity] != nil { involvesLocalPeers = true break } } if !involvesLocalPeers { - return errors.New("excluded all local peers") + return sesRun, errors.New("excluded all local peers") } ecdhPublicKeys := make([]*secp256k1.PublicKey, 0, len(prs)) @@ -1320,11 +1525,15 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) pqpk = append(pqpk, &ke.PQPK) } if len(blamed) > 0 { - sesLog.logf("blaming %x during run (invalid ECDH pubkeys)", []identity(blamed)) - return blamed + sesRun.logf("blaming %x during run (invalid ECDH pubkeys)", []identity(blamed)) + return sesRun, blamed } err = c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if p.ct != nil { + return nil + } + // Create shared keys and ciphextexts for each peer pqct, err := p.kx.Encapsulate(p.prng, pqpk, int(p.myVk)) if err != nil { @@ -1334,20 +1543,15 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // Send ciphertext messages ct := wire.NewMsgMixCiphertexts(*p.id, sesRun.sid, 0, pqct, seenKEs) p.ct = ct - c.testHook(hookBeforePeerCTPublish, sesRun, p) + c.testHook(hookBeforePeerCTPublish, ps, sesRun, p) return nil }) if err != nil { - sesLog.logf("%v", err) + sesRun.logf("%v", err) } - err = c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - if p.ct == nil { - return nil - } - return p.ct - }) + err = c.sendLocalPeerMsgs(ctx, d.recvCT, sesRun, msgCT) if err != nil { - sesLog.logf("%v", err) + sesRun.logf("%v", err) } // Receive all ciphertext messages @@ -1365,12 +1569,12 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if err != nil { - return err + return sesRun, err } if len(cts) != len(prs) { // Blame peers - sesLog.logf("Received %d CTs for %d peers; rerunning", len(cts), len(prs)) 
- return blameTimedOut(sesLog, sesRun, ctTimeout) + sesRun.logf("received %d CTs for %d peers; rerunning", len(cts), len(prs)) + return sesRun, blameTimedOut(sesRun, ctTimeout) } sort.Slice(cts, func(i, j int) bool { a := identityIndices[cts[i].Identity] @@ -1385,6 +1589,10 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) blamedMap := make(map[identity]struct{}) var blamedMapMu sync.Mutex err = c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if p.sr != nil { + return nil + } + revealed := &mixing.RevealedKeys{ ECDHPublicKeys: ecdhPublicKeys, Ciphertexts: make([]mixing.PQCiphertext, 0, len(prs)), @@ -1422,28 +1630,23 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // reservations. sr := wire.NewMsgMixSlotReserve(*p.id, sesRun.sid, 0, srMixBytes, seenCTs) p.sr = sr - c.testHook(hookBeforePeerSRPublish, sesRun, p) + c.testHook(hookBeforePeerSRPublish, ps, sesRun, p) return nil }) if len(blamedMap) > 0 { for id := range blamedMap { blamed = append(blamed, id) } - sesLog.logf("blaming %x during run (wrong ciphertext count)", []identity(blamed)) - return blamed + sesRun.logf("blaming %x during run (wrong ciphertext count)", []identity(blamed)) + return sesRun, blamed } - sendErr := c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - if p.sr == nil { - return nil - } - return p.sr - }) + sendErr := c.sendLocalPeerMsgs(ctx, d.recvSR, sesRun, msgSR) if sendErr != nil { - sesLog.logf("%v", sendErr) + sesRun.logf("%v", sendErr) } if err != nil { - sesLog.logf("%v", err) - return err + sesRun.logf("%v", err) + return sesRun, err } // Receive all slot reservation messages @@ -1459,12 +1662,12 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if err != nil { - return err + return sesRun, err } if len(srs) != len(prs) { // Blame peers - sesLog.logf("Received %d SRs for %d peers; rerunning", len(srs), len(prs)) - return blameTimedOut(sesLog, sesRun, srTimeout) + sesRun.logf("received %d SRs for %d peers; rerunning", len(srs), len(prs)) + return sesRun, blameTimedOut(sesRun, srTimeout) } sort.Slice(srs, func(i, j int) bool { a := identityIndices[srs[i].Identity] @@ -1484,16 +1687,18 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) powerSums := mixing.AddVectors(mixing.IntVectorsFromBytes(vs)...) coeffs := mixing.Coefficients(powerSums) rcvCtx, rcvCtxCancel = context.WithDeadline(ctx, d.recvSR) - publishedRootsC := make(chan struct{}) - defer func() { <-publishedRootsC }() - roots, err := c.roots(rcvCtx, seenSRs, sesRun, coeffs, mixing.F, publishedRootsC) + roots, err := c.roots(rcvCtx, seenSRs, sesRun, coeffs, mixing.F) rcvCtxCancel() if err != nil { - return err + return sesRun, err } sesRun.roots = roots err = c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if p.dc != nil { + return nil + } + // Find reserved slots slots := make([]uint32, 0, p.pr.MessageCount) for _, m := range p.srMsg { @@ -1516,21 +1721,16 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // Broadcast XOR DC-net vectors. 
dc := wire.NewMsgMixDCNet(*p.id, sesRun.sid, 0, p.dcNet, seenSRs) p.dc = dc - c.testHook(hookBeforePeerDCPublish, sesRun, p) + c.testHook(hookBeforePeerDCPublish, ps, sesRun, p) return nil }) - sendErr = c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - if p.dc == nil { - return nil - } - return p.dc - }) + sendErr = c.sendLocalPeerMsgs(ctx, d.recvDC, sesRun, msgFP|msgDC) if sendErr != nil { - sesLog.logf("%v", err) + sesRun.logf("%v", err) } if err != nil { - sesLog.logf("DC-net error: %v", err) - return err + sesRun.logf("DC-net error: %v", err) + return sesRun, err } // Receive all DC messages @@ -1546,12 +1746,12 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if err != nil { - return err + return sesRun, err } if len(dcs) != len(prs) { // Blame peers - sesLog.logf("Received %d DCs for %d peers; rerunning", len(dcs), len(prs)) - return blameTimedOut(sesLog, sesRun, dcTimeout) + sesRun.logf("received %d DCs for %d peers; rerunning", len(dcs), len(prs)) + return sesRun, blameTimedOut(sesRun, dcTimeout) } sort.Slice(dcs, func(i, j int) bool { a := identityIndices[dcs[i].Identity] @@ -1575,11 +1775,15 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if len(blamed) > 0 { - sesLog.logf("blaming %x during run (wrong DC-net count)", []identity(blamed)) - return blamed + sesRun.logf("blaming %x during run (wrong DC-net count)", []identity(blamed)) + return sesRun, blamed } mixedMsgs := mixing.XorVectors(dcVecs) err = c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if p.cm != nil { + return nil + } + // Add outputs for each mixed message for i := range mixedMsgs { mixedMsg := mixedMsgs[i][:] @@ -1591,7 +1795,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // provided inputs. 
err := p.coinjoin.confirm(c.wallet) if errors.Is(err, errMissingGen) { - sesLog.logf("Missing message; blaming and rerunning") + sesRun.logf("missing message; blaming and rerunning") p.triggeredBlame = true return errTriggeredBlame } @@ -1606,18 +1810,13 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) p.cm = cm return nil }) - sendErr = c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - if p.cm == nil { - return nil - } - return p.cm - }) + sendErr = c.sendLocalPeerMsgs(ctx, d.recvCM, sesRun, msgCM) if sendErr != nil { - sesLog.logf("%v", sendErr) + sesRun.logf("%v", sendErr) } if err != nil { - sesLog.logf("Confirm error: %v", err) - return err + sesRun.logf("confirm error: %v", err) + return sesRun, err } // Receive all CM messages @@ -1633,12 +1832,12 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if err != nil { - return err + return sesRun, err } if len(cms) != len(prs) { // Blame peers - sesLog.logf("Received %d CMs for %d peers; rerunning", len(cms), len(prs)) - return blameTimedOut(sesLog, sesRun, cmTimeout) + sesRun.logf("received %d CMs for %d peers; rerunning", len(cms), len(prs)) + return sesRun, blameTimedOut(sesRun, cmTimeout) } sort.Slice(cms, func(i, j int) bool { a := identityIndices[cms[i].Identity] @@ -1670,18 +1869,18 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if len(blamed) > 0 { - sesLog.logf("blaming %x during run (confirmed wrong coinjoin)", []identity(blamed)) - return blamed + sesRun.logf("blaming %x during run (confirmed wrong coinjoin)", []identity(blamed)) + return sesRun, blamed } err = c.validateMergedCoinjoin(cj, prs, utxos) if err != nil { - return err + return sesRun, err } time.Sleep(lowestJitter + rand.Duration(msgJitter)) err = c.wallet.PublishTransaction(context.Background(), cj.tx) if err != nil { - return err + return sesRun, err } c.forLocalPeers(ctx, sesRun, func(p *peer) error { @@ -1694,7 +1893,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) sesRun.cj = cj - return nil + return sesRun, nil } func (c *Client) solvesRoots() bool { @@ -1707,7 +1906,7 @@ func (c *Client) solvesRoots() bool { // the result to all peers. If the client is incapable of solving the roots, // it waits for a solution. func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, - sesRun *sessionRun, a []*big.Int, F *big.Int, publishedRoots chan struct{}) ([]*big.Int, error) { + sesRun *sessionRun, a []*big.Int, F *big.Int) ([]*big.Int, error) { switch { case c.solvesRoots(): @@ -1728,7 +1927,6 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, p.triggeredBlame = true return nil }) - close(publishedRoots) return nil, errTriggeredBlame } sort.Slice(roots, func(i, j int) bool { @@ -1739,25 +1937,16 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, rootBytes[i] = root.Bytes() } c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if p.fp != nil { + return nil + } p.fp = wire.NewMsgMixFactoredPoly(*p.id, sesRun.sid, 0, rootBytes, seenSRs) return nil }) - // Don't wait for these messages to send. - go func() { - err := c.sendLocalPeerMsgs(ctx, sesRun, false, func(p *peer) mixing.Message { - return p.fp - }) - if ctx.Err() == nil && err != nil { - c.logf("sid=%x %v", sesRun.sid[:], err) - } - publishedRoots <- struct{}{} - }() return roots, nil } - close(publishedRoots) - // Clients unable to solve their own roots must wait for solutions. 
// We can return a result as soon as we read any valid factored // polynomial message that provides the solutions for this SR @@ -1780,6 +1969,7 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, if _, ok := checkedFPByIdentity[fp.Identity]; ok { continue } + checkedFPByIdentity[fp.Identity] = struct{}{} roots = roots[:0] duplicateRoots := make(map[string]struct{}) @@ -1814,8 +2004,6 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, if len(roots) == len(a)-1 { return roots, nil } - - checkedFPByIdentity[fp.Identity] = struct{}{} } rcv.FPs = make([]*wire.MsgMixFactoredPoly, 0, len(rcv.FPs)+1) @@ -1928,10 +2116,10 @@ func (e *alternateSession) Unwrap() error { return e.err } -func (c *Client) alternateSession(pairing []byte, prs []*wire.MsgMixPairReq, d *deadlines) *alternateSession { - unixEpoch := uint64(d.epoch.Unix()) +func (c *Client) alternateSession(ps *pairedSessions, prs []*wire.MsgMixPairReq) *alternateSession { + unixEpoch := uint64(ps.epoch.Unix()) - kes := c.mixpool.ReceiveKEsByPairing(pairing, unixEpoch) + kes := c.mixpool.ReceiveKEsByPairing(ps.pairing, unixEpoch) // Sort KEs by identity first (just to group these together) followed // by the total referenced PR counts in increasing order (most recent @@ -2112,7 +2300,7 @@ func (c *Client) alternateSession(pairing []byte, prs []*wire.MsgMixPairReq, d * } -func excludeBlamed(prevRun *sessionRun, blamed blamedIdentities, revealedSecrets bool) *sessionRun { +func excludeBlamed(prevRun *sessionRun, epoch uint64, blamed blamedIdentities, revealedSecrets bool) *sessionRun { blamedMap := make(map[identity]struct{}) for _, id := range blamed { blamedMap[id] = struct{}{} @@ -2132,19 +2320,15 @@ func excludeBlamed(prevRun *sessionRun, blamed blamedIdentities, revealedSecrets prs = append(prs, p.pr) } - d := prevRun.deadlines - d.restart() - - unixEpoch := prevRun.epoch.Unix() - sid := mixing.SortPRsForSession(prs, uint64(unixEpoch)) + sid := mixing.SortPRsForSession(prs, epoch) // mtot, peers, mcounts are all recalculated from the prs before // calling run() nextRun := &sessionRun{ - sid: sid, - freshGen: revealedSecrets, - prs: prs, - deadlines: d, + sid: sid, + idx: prevRun.idx + 1, + freshGen: revealedSecrets, + prs: prs, } return nextRun } diff --git a/mixing/mixclient/client_test.go b/mixing/mixclient/client_test.go index a267b7060..1f4d88d0e 100644 --- a/mixing/mixclient/client_test.go +++ b/mixing/mixclient/client_test.go @@ -267,8 +267,8 @@ func testDisruption(t *testing.T, misbehavingID *identity, h hook, f hookFunc) { w := newTestWallet(bc) c := newTestClient(w, l) c.testHooks = map[hook]hookFunc{ - hookBeforeRun: func(c *Client, s *sessionRun, _ *peer) { - s.deadlines.recvKE = time.Now().Add(5 * time.Second) + hookBeforeRun: func(c *Client, ps *pairedSessions, _ *sessionRun, _ *peer) { + ps.deadlines.recvKE = time.Now().Add(5 * time.Second) }, h: f, } @@ -397,73 +397,78 @@ func testDisruption(t *testing.T, misbehavingID *identity, h hook, f hookFunc) { func TestCTDisruption(t *testing.T) { var misbehavingID identity - testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) { - if p.myVk != 0 { - return - } - if misbehavingID != [33]byte{} { - return - } - t.Logf("malicious peer %x: flipping CT bit", p.id[:]) - misbehavingID = *p.id - p.ct.Ciphertexts[1][0] ^= 1 - }) + testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, + func(c *Client, ps *pairedSessions, s *sessionRun, p *peer) { + if p.myVk != 0 { + return + } + if misbehavingID 
!= [33]byte{} { + return + } + t.Logf("malicious peer %x: flipping CT bit", p.id[:]) + misbehavingID = *p.id + p.ct.Ciphertexts[1][0] ^= 1 + }) } func TestCTLength(t *testing.T) { var misbehavingID identity - testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) { - if p.myVk != 0 { - return - } - if misbehavingID != [33]byte{} { - return - } - t.Logf("malicious peer %x: sending too few ciphertexts", p.id[:]) - misbehavingID = *p.id - p.ct.Ciphertexts = p.ct.Ciphertexts[:len(p.ct.Ciphertexts)-1] - }) + testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, + func(c *Client, ps *pairedSessions, s *sessionRun, p *peer) { + if p.myVk != 0 { + return + } + if misbehavingID != [33]byte{} { + return + } + t.Logf("malicious peer %x: sending too few ciphertexts", p.id[:]) + misbehavingID = *p.id + p.ct.Ciphertexts = p.ct.Ciphertexts[:len(p.ct.Ciphertexts)-1] + }) misbehavingID = identity{} - testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) { - if p.myVk != 0 { - return - } - if misbehavingID != [33]byte{} { - return - } - t.Logf("malicious peer %x: sending too many ciphertexts", p.id[:]) - misbehavingID = *p.id - p.ct.Ciphertexts = append(p.ct.Ciphertexts, p.ct.Ciphertexts[0]) - }) + testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, + func(c *Client, ps *pairedSessions, s *sessionRun, p *peer) { + if p.myVk != 0 { + return + } + if misbehavingID != [33]byte{} { + return + } + t.Logf("malicious peer %x: sending too many ciphertexts", p.id[:]) + misbehavingID = *p.id + p.ct.Ciphertexts = append(p.ct.Ciphertexts, p.ct.Ciphertexts[0]) + }) } func TestSRDisruption(t *testing.T) { var misbehavingID identity - testDisruption(t, &misbehavingID, hookBeforePeerSRPublish, func(c *Client, s *sessionRun, p *peer) { - if p.myVk != 0 { - return - } - if misbehavingID != [33]byte{} { - return - } - t.Logf("malicious peer %x: flipping SR bit", p.id[:]) - misbehavingID = *p.id - p.sr.DCMix[0][1][0] ^= 1 - }) + testDisruption(t, &misbehavingID, hookBeforePeerSRPublish, + func(c *Client, ps *pairedSessions, s *sessionRun, p *peer) { + if p.myVk != 0 { + return + } + if misbehavingID != [33]byte{} { + return + } + t.Logf("malicious peer %x: flipping SR bit", p.id[:]) + misbehavingID = *p.id + p.sr.DCMix[0][1][0] ^= 1 + }) } func TestDCDisruption(t *testing.T) { var misbehavingID identity - testDisruption(t, &misbehavingID, hookBeforePeerDCPublish, func(c *Client, s *sessionRun, p *peer) { - if p.myVk != 0 { - return - } - if misbehavingID != [33]byte{} { - return - } - t.Logf("malicious peer %x: flipping DC bit", p.id[:]) - misbehavingID = *p.id - p.dc.DCNet[0][1][0] ^= 1 - }) + testDisruption(t, &misbehavingID, hookBeforePeerDCPublish, + func(c *Client, ps *pairedSessions, s *sessionRun, p *peer) { + if p.myVk != 0 { + return + } + if misbehavingID != [33]byte{} { + return + } + t.Logf("malicious peer %x: flipping DC bit", p.id[:]) + misbehavingID = *p.id + p.dc.DCNet[0][1][0] ^= 1 + }) } diff --git a/mixing/mixclient/testhooks.go b/mixing/mixclient/testhooks.go index de19db94d..d564e2e3c 100644 --- a/mixing/mixclient/testhooks.go +++ b/mixing/mixclient/testhooks.go @@ -6,7 +6,7 @@ package mixclient type hook string -type hookFunc func(*Client, *sessionRun, *peer) +type hookFunc func(*Client, *pairedSessions, *sessionRun, *peer) const ( hookBeforeRun hook = "before run"
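
The new sendLocalPeerMsgs in client.go selects which prepared peer messages to publish using a bit mask of protocol stages (msgKE through msgRS) instead of a per-peer callback, and it forces the RS secrets message into the selection whenever the peer has triggered blame. The stand-alone sketch below reduces that selection logic to a minimal form; everything in it (the peer struct, selectMessages, the msg type) is hypothetical and only mirrors the shape of the patched code.

package main

import (
	"fmt"
	"math/bits"
)

// Stage bits, in the same declaration order as the patch.
const (
	msgKE = 1 << iota
	msgCT
	msgSR
	msgDC
	msgFP
	msgCM
	msgRS
)

// peer holds one prepared message per protocol stage; nil means the peer
// has nothing to send for that stage.
type peer struct {
	ke, ct, sr, fp, dc, cm, rs fmt.Stringer
	triggeredBlame             bool
}

// selectMessages mirrors the mask handling in sendLocalPeerMsgs: a peer
// that triggered blame always has its secrets (RS) included, and only
// non-nil messages whose stage bit is set are queued for sending.
func selectMessages(p *peer, mask uint) []fmt.Stringer {
	if p.triggeredBlame {
		mask |= msgRS
	}
	stages := []struct {
		bit uint
		m   fmt.Stringer
	}{
		{msgKE, p.ke}, {msgCT, p.ct}, {msgSR, p.sr},
		{msgFP, p.fp}, {msgDC, p.dc}, {msgCM, p.cm}, {msgRS, p.rs},
	}
	out := make([]fmt.Stringer, 0, bits.OnesCount(mask))
	for _, s := range stages {
		if mask&s.bit == s.bit && s.m != nil {
			out = append(out, s.m)
		}
	}
	return out
}

type msg string

func (m msg) String() string { return string(m) }

func main() {
	p := &peer{fp: msg("FP"), dc: msg("DC")}
	// The DC stage publishes both the factored-polynomial and DC-net
	// messages, analogous to sendLocalPeerMsgs(ctx, d.recvDC, sesRun,
	// msgFP|msgDC) in the patch.
	for _, m := range selectMessages(p, msgFP|msgDC) {
		fmt.Println(m)
	}
}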
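
completePairing waits for a full set of KEs for any of the previously attempted session runs, checking them in the order they were attempted: a pre-canceled context gives an immediate, non-blocking pass over messages already in the mixpool, and the remaining attempts are then awaited concurrently until the KE deadline. The sketch below is not part of the patch; attempt, receiveAll, and firstComplete are hypothetical stand-ins that reproduce only the channel fan-out pattern.

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"
)

var errIncomplete = errors.New("incomplete pairing")

// attempt stands in for one previously tried session run.
type attempt struct {
	id   int
	want int
}

// receiveAll blocks until ctx is done, then reports whether all expected
// messages for the attempt had arrived (simulated by the have function).
func receiveAll(ctx context.Context, a attempt, have func(attempt) int) error {
	<-ctx.Done()
	if have(a) == a.want {
		return nil
	}
	return errIncomplete
}

// firstComplete mirrors the shape of completePairing: each attempt is first
// checked with a pre-canceled context (an immediate pass over whatever has
// already arrived), the rest are awaited concurrently until the deadline,
// and the earliest attempt to complete wins.
func firstComplete(ctx context.Context, attempts []attempt, deadline time.Time,
	have func(attempt) int) (attempt, error) {

	canceled, cancel := context.WithCancel(context.Background())
	cancel()

	wctx, wcancel := context.WithDeadline(ctx, deadline)
	defer wcancel()

	res := make(chan attempt, len(attempts))
	errs := make(chan error, len(attempts))
	var wg sync.WaitGroup
	for _, a := range attempts {
		if receiveAll(canceled, a, have) == nil {
			return a, nil
		}
		a := a
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := receiveAll(wctx, a, have); err != nil {
				errs <- err
				return
			}
			res <- a
		}()
	}
	defer wg.Wait()

	var errCount int
	for {
		select {
		case a := <-res:
			wcancel()
			return a, nil
		case err := <-errs:
			if errCount++; errCount == len(attempts) {
				return attempt{}, err
			}
		case <-ctx.Done():
			return attempt{}, ctx.Err()
		}
	}
}

func main() {
	have := func(a attempt) int {
		if a.id == 1 {
			return a.want // only the second attempt ever completes
		}
		return a.want - 1
	}
	as := []attempt{{id: 0, want: 3}, {id: 1, want: 3}}
	a, err := firstComplete(context.Background(), as,
		time.Now().Add(200*time.Millisecond), have)
	fmt.Println(a.id, err)
}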