From d81384ab93b0620e7e97370908ac5048e3b67e69 Mon Sep 17 00:00:00 2001 From: Josh Rickmar Date: Thu, 23 Jan 2025 21:17:15 +0000 Subject: [PATCH] mixclient: Wait for KEs from all attempted sessions This commit modifies the mixing client to consider fully formed sessions (where all peers have sent a KE with matching sessions), even if other sessions have already been attempted after. Previously, only the most recently attempted session would ever be used, and clients would only ratchet-down the peers they would mix with. With this change, initial session disagreement can be recovered from more gracefully, as some peers will form an alternate session matching the session that other peers originally tried. This commit also introduces a fix for a race that could result in root-solving wallets not publishing their solutions in a timely manner, which would result in non-solving wallets eventually being blamed for DC timeout. --- mixing/mixclient/blame.go | 82 ++- mixing/mixclient/client.go | 1146 ++++++++++++++++++------------- mixing/mixclient/client_test.go | 119 ++-- mixing/mixclient/testhooks.go | 2 +- 4 files changed, 772 insertions(+), 577 deletions(-) diff --git a/mixing/mixclient/blame.go b/mixing/mixclient/blame.go index ccecc2352f..81c197ccef 100644 --- a/mixing/mixclient/blame.go +++ b/mixing/mixclient/blame.go @@ -11,6 +11,7 @@ import ( "fmt" "math/big" "sort" + "time" "github.com/decred/dcrd/chaincfg/chainhash" "github.com/decred/dcrd/dcrec/secp256k1/v4" @@ -20,6 +21,8 @@ import ( "github.com/decred/dcrd/wire" ) +var errBlameFailed = errors.New("blame failed") + // blamedIdentities identifies detected misbehaving peers. // // If a run returns a blamedIdentities error, these peers are immediately @@ -48,7 +51,7 @@ func (e blamedIdentities) String() string { } func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { - c.logf("Blaming for sid=%x", sesRun.sid[:]) + sesRun.logf("running blame assignment") mp := c.mixpool prs := sesRun.prs @@ -65,14 +68,11 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { } }() - err = c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - // Send initial secrets messages from any peers who detected - // misbehavior. - if !p.triggeredBlame { - return nil - } - return p.rs - }) + deadline := time.Now().Add(timeoutDuration) + + // Send initial secrets messages from any peers who detected + // misbehavior. + err = c.sendLocalPeerMsgs(ctx, deadline, sesRun, 0) if err != nil { return err } @@ -87,15 +87,17 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { rsHashes = append(rsHashes, rs.Hash()) } - // Send remaining secrets messages. - err = c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - if p.triggeredBlame { - p.triggeredBlame = false - return nil + // Send remaining secrets messages with observed RS hashes from the + // initial peers who published secrets. 
+ c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if !p.triggeredBlame { + if p.rs != nil { + p.rs.SeenSecrets = rsHashes + } } - p.rs.SeenSecrets = rsHashes - return p.rs + return nil }) + err = c.sendLocalPeerMsgs(ctx, deadline, sesRun, msgRS) if err != nil { return err } @@ -113,14 +115,14 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { } if len(rss) != len(sesRun.peers) { // Blame peers who did not send secrets - c.logf("received %d RSs for %d peers; blaming unresponsive peers", + sesRun.logf("received %d RSs for %d peers; blaming unresponsive peers", len(rss), len(sesRun.peers)) for _, p := range sesRun.peers { if p.rs != nil { continue } - c.logf("blaming %x for RS timeout", p.id[:]) + sesRun.logf("blaming %x for RS timeout", p.id[:]) blamed = append(blamed, *p.id) } return blamed @@ -142,10 +144,14 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { continue } id := &rs.Identity - c.logf("blaming %x for false failure accusation", id[:]) + sesRun.logf("blaming %x for false failure accusation", id[:]) blamed = append(blamed, *id) } - err = blamed + if len(blamed) > 0 { + err = blamed + } else { + err = errBlameFailed + } }() defer c.mu.Unlock() @@ -159,7 +165,7 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { KELoop: for _, p := range sesRun.peers { if p.ke == nil { - c.logf("blaming %x for missing messages", p.id[:]) + sesRun.logf("blaming %x for missing messages", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -169,7 +175,7 @@ KELoop: cm := p.rs.Commitment(c.blake256Hasher) c.blake256HasherMu.Unlock() if cm != p.ke.Commitment { - c.logf("blaming %x for false commitment, got %x want %x", + sesRun.logf("blaming %x for false commitment, got %x want %x", p.id[:], cm[:], p.ke.Commitment[:]) blamed = append(blamed, *p.id) continue @@ -177,7 +183,7 @@ KELoop: // Blame peers whose seed is not the correct length (will panic chacha20prng). if len(p.rs.Seed) != chacha20prng.SeedSize { - c.logf("blaming %x for bad seed size in RS message", p.id[:]) + sesRun.logf("blaming %x for bad seed size in RS message", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -187,7 +193,7 @@ KELoop: if mixing.InField(scratch.SetBytes(m)) { continue } - c.logf("blaming %x for SR message outside field", p.id[:]) + sesRun.logf("blaming %x for SR message outside field", p.id[:]) blamed = append(blamed, *p.id) continue KELoop } @@ -199,7 +205,7 @@ KELoop: // Recover derived key exchange from PRNG. 
p.kx, err = mixing.NewKX(p.prng) if err != nil { - c.logf("blaming %x for bad KX", p.id[:]) + sesRun.logf("blaming %x for bad KX", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -210,14 +216,14 @@ KELoop: case !bytes.Equal(p.ke.ECDH[:], p.kx.ECDHPublicKey.SerializeCompressed()): fallthrough case !bytes.Equal(p.ke.PQPK[:], p.kx.PQPublicKey[:]): - c.logf("blaming %x for KE public keys not derived from their PRNG", + sesRun.logf("blaming %x for KE public keys not derived from their PRNG", p.id[:]) blamed = append(blamed, *p.id) continue KELoop } publishedECDHPub, err := secp256k1.ParsePubKey(p.ke.ECDH[:]) if err != nil { - c.logf("blaming %x for unparsable pubkey") + sesRun.logf("blaming %x for unparsable pubkey") blamed = append(blamed, *p.id) continue } @@ -229,7 +235,7 @@ KELoop: start += mcount if uint32(len(p.rs.SlotReserveMsgs)) != mcount || uint32(len(p.rs.DCNetMsgs)) != mcount { - c.logf("blaming %x for bad message count", p.id[:]) + sesRun.logf("blaming %x for bad message count", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -261,19 +267,19 @@ KELoop: // from their PRNG. for i, p := range sesRun.peers { if p.ct == nil { - c.logf("blaming %x for missing messages", p.id[:]) + sesRun.logf("blaming %x for missing messages", p.id[:]) blamed = append(blamed, *p.id) continue } if len(recoveredCTs[i]) != len(p.ct.Ciphertexts) { - c.logf("blaming %x for different ciphertexts count %d != %d", + sesRun.logf("blaming %x for different ciphertexts count %d != %d", p.id[:], len(recoveredCTs[i]), len(p.ct.Ciphertexts)) blamed = append(blamed, *p.id) continue } for j := range p.ct.Ciphertexts { if !bytes.Equal(p.ct.Ciphertexts[j][:], recoveredCTs[i][j][:]) { - c.logf("blaming %x for different ciphertexts", p.id[:]) + sesRun.logf("blaming %x for different ciphertexts", p.id[:]) blamed = append(blamed, *p.id) break } @@ -294,7 +300,7 @@ KELoop: for _, pids := range shared { if len(pids) > 1 { for i := range pids { - c.logf("blaming %x for shared SR message", pids[i][:]) + sesRun.logf("blaming %x for shared SR message", pids[i][:]) } blamed = append(blamed, pids...) } @@ -306,7 +312,7 @@ KELoop: SRLoop: for i, p := range sesRun.peers { if p.sr == nil { - c.logf("blaming %x for missing messages", p.id[:]) + sesRun.logf("blaming %x for missing messages", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -325,7 +331,7 @@ SRLoop: var decapErr *mixing.DecapsulateError if errors.As(err, &decapErr) { submittingID := p.id - c.logf("blaming %x for unrecoverable secrets", submittingID[:]) + sesRun.logf("blaming %x for unrecoverable secrets", submittingID[:]) blamed = append(blamed, *submittingID) continue } @@ -343,7 +349,7 @@ SRLoop: // Blame when committed mix does not match provided. for k := range srMix { if srMix[k].Cmp(scratch.SetBytes(p.sr.DCMix[j][k])) != 0 { - c.logf("blaming %x for bad SR mix", p.id[:]) + sesRun.logf("blaming %x for bad SR mix", p.id[:]) blamed = append(blamed, *p.id) continue SRLoop } @@ -376,7 +382,7 @@ DCLoop: // deferred function) if no peers could be assigned blame is // not likely to be seen under this situation. if p.dc == nil { - c.logf("blaming %x for missing messages", p.id[:]) + sesRun.logf("blaming %x for missing messages", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -386,7 +392,7 @@ DCLoop: // message, and there must be mcount DC-net vectors. 
mcount := p.pr.MessageCount if uint32(len(p.dc.DCNet)) != mcount { - c.logf("blaming %x for missing DC mix vectors", p.id[:]) + sesRun.logf("blaming %x for missing DC mix vectors", p.id[:]) blamed = append(blamed, *p.id) continue } @@ -406,7 +412,7 @@ DCLoop: // Blame when committed mix does not match provided. for k := 0; k < len(dcMix); k++ { if !dcMix.Equals(mixing.Vec(p.dc.DCNet[j])) { - c.logf("blaming %x for bad DC mix", p.id[:]) + sesRun.logf("blaming %x for bad DC mix", p.id[:]) blamed = append(blamed, *p.id) continue DCLoop } diff --git a/mixing/mixclient/client.go b/mixing/mixclient/client.go index 13b42353e9..bf8422afdc 100644 --- a/mixing/mixclient/client.go +++ b/mixing/mixclient/client.go @@ -64,7 +64,18 @@ const ( cmTimeout ) -func blameTimedOut(sesLog *sessionLogger, sesRun *sessionRun, timeoutMessage int) blamedIdentities { +// Constants specifying which peer messages to publish. +const ( + msgKE = 1 << iota + msgCT + msgSR + msgDC + msgFP + msgCM + msgRS +) + +func blameTimedOut(sesRun *sessionRun, timeoutMessage int) error { var blamed blamedIdentities var stage string for _, p := range sesRun.peers { @@ -91,8 +102,11 @@ func blameTimedOut(sesLog *sessionLogger, sesRun *sessionRun, timeoutMessage int } } } - sesLog.logf("blaming %x during run (%s timeout)", []identity(blamed), stage) - return blamed + if len(blamed) > 0 { + sesRun.logf("blaming %x during run (%s timeout)", []identity(blamed), stage) + return blamed + } + return errBlameFailed } // Wallet signs mix transactions and listens for and broadcasts mixing @@ -125,7 +139,6 @@ type Wallet interface { } type deadlines struct { - epoch time.Time recvKE time.Time recvCT time.Time recvSR time.Time @@ -146,41 +159,15 @@ func (d *deadlines) start(begin time.Time) { d.recvCM = add() } -func (d *deadlines) shift() { - d.recvKE = d.recvCT - d.recvCT = d.recvSR - d.recvSR = d.recvDC - d.recvDC = d.recvCM - d.recvCM = d.recvCM.Add(timeoutDuration) -} - -func (d *deadlines) restart() { - d.start(d.recvCM) -} - -// peer represents a participating client in a peer-to-peer mixing session. -// Some fields only pertain to peers created by this wallet, while the rest -// are used during blame assignment. -type peer struct { - ctx context.Context - client *Client - jitter time.Duration - - res chan error - - pub *secp256k1.PublicKey - priv *secp256k1.PrivateKey - id *identity // serialized pubkey - pr *wire.MsgMixPairReq - coinjoin *CoinJoin - kx *mixing.KX - +// peerRunState describes the peer state that changes across different +// sessions/runs. +type peerRunState struct { prngSeed [32]byte prng *chacha20prng.Reader + kx *mixing.KX // derived from PRNG - rs *wire.MsgMixSecrets - srMsg []*big.Int // random numbers for the exponential slot reservation mix - dcMsg wire.MixVect // anonymized messages to publish in XOR mix + srMsg []*big.Int // random (non-PRNG) numbers for the exponential slot reservation mix + dcMsg wire.MixVect // anonymized messages (HASH160s) to publish in XOR mix ke *wire.MsgMixKeyExchange ct *wire.MsgMixCiphertexts @@ -188,8 +175,9 @@ type peer struct { fp *wire.MsgMixFactoredPoly dc *wire.MsgMixDCNet cm *wire.MsgMixConfirm + rs *wire.MsgMixSecrets - // Unmixed positions. May change over multiple sessions/runs. + // Unmixed positions. myVk uint32 myStart uint32 @@ -204,12 +192,54 @@ type peer struct { // Whether peer misbehavior was detected by this peer, and initial // secrets will be revealed by c.blame(). triggeredBlame bool +} + +// peer represents a participating client in a peer-to-peer mixing session. 
+// Some fields only pertain to peers created by this wallet, while the rest +// are used during blame assignment. +type peer struct { + ctx context.Context + client *Client + jitter time.Duration + + res chan error + + pub *secp256k1.PublicKey + priv *secp256k1.PrivateKey + id *identity // serialized pubkey + pr *wire.MsgMixPairReq + coinjoin *CoinJoin + + peerRunState // Whether this peer represents a remote peer created from revealed secrets; // used during blaming. remote bool } +// cloneLocalPeer creates a new peer instance representing a local peer +// client, sharing the caller context and result channels, but resetting all +// per-session fields (if freshGen is true), or only copying the PRNG and +// fields directly derived from it (when freshGen is false). +func (p *peer) cloneLocalPeer(freshGen bool) *peer { + if p.remote { + panic("cloneLocalPeer: remote peer") + } + + p2 := *p + p2.peerRunState = peerRunState{} + + if !freshGen { + p2.prngSeed = p.prngSeed + p2.prng = p.prng + p2.kx = p.kx + p2.srMsg = p.srMsg + p2.dcMsg = p.dcMsg + } + + return &p2 +} + func newRemotePeer(pr *wire.MsgMixPairReq) *peer { return &peer{ id: &pr.Identity, @@ -229,32 +259,65 @@ func generateSecp256k1() (*secp256k1.PublicKey, *secp256k1.PrivateKey, error) { return publicKey, privateKey, nil } +type pendingPairing struct { + localPeers map[identity]*peer + pairing []byte +} + // pairedSessions tracks the waiting and in-progress mix sessions performed by // one or more local peers using compatible pairings. type pairedSessions struct { localPeers map[identity]*peer pairing []byte runs []sessionRun + + epoch time.Time + deadlines + + // Track state of pairing completion. + // + // An agreed-upon pairing means there was agreement on an initial set + // of peers/KEs. However, various situations (such as blaming + // misbehaving peers, or additional session formations after hitting + // size limits) must not consider the previous sessions with all + // received KEs. The peer agreement index clamps down on this by only + // considering runs in + // pairedSessions.runs[pairedSessions.peerAgreementRunIdx:] + peerAgreementRunIdx int + peerAgreement bool + + donePairingOnce sync.Once } type sessionRun struct { sid [32]byte + idx int // in pairedSessions.runs mtot uint32 // Whether this run must generate fresh KX keys, SR/DC messages. freshGen bool - deadlines - // Peers sorted by PR hashes. Each peer's myVk is its index in this // slice. - prs []*wire.MsgMixPairReq - peers []*peer - mcounts []uint32 - roots []*big.Int + localPeers map[identity]*peer + prs []*wire.MsgMixPairReq + kes []*wire.MsgMixKeyExchange // set by completePairing + peers []*peer + mcounts []uint32 + roots []*big.Int // Finalized coinjoin of a successful run. cj *CoinJoin + + logger slog.Logger +} + +func (s *sessionRun) logf(format string, args ...interface{}) { + if s.logger == nil { + return + } + + s.logger.Debugf("sid=%x/%d "+format, append([]interface{}{s.sid[:], s.idx}, args...)...) } type queueWork struct { @@ -273,9 +336,9 @@ type Client struct { // Pending and active sessions and peers (both local and, when // blaming, remote). 
- pairings map[string]*pairedSessions - height uint32 - mu sync.Mutex + pendingPairings map[string]*pendingPairing + height uint32 + mu sync.Mutex warming chan struct{} workQueue chan *queueWork @@ -303,15 +366,15 @@ func NewClient(w Wallet) *Client { height, _ := w.BestBlock() return &Client{ - atomicPRFlags: uint32(prFlags), - wallet: w, - mixpool: w.Mixpool(), - pairings: make(map[string]*pairedSessions), - warming: make(chan struct{}), - workQueue: make(chan *queueWork, runtime.NumCPU()), - blake256Hasher: blake256.New(), - epoch: w.Mixpool().Epoch(), - height: height, + atomicPRFlags: uint32(prFlags), + wallet: w, + mixpool: w.Mixpool(), + pendingPairings: make(map[string]*pendingPairing), + warming: make(chan struct{}), + workQueue: make(chan *queueWork, runtime.NumCPU()), + blake256Hasher: blake256.New(), + epoch: w.Mixpool().Epoch(), + height: height, } } @@ -343,23 +406,6 @@ func (c *Client) logerrf(format string, args ...interface{}) { c.logger.Errorf(format, args...) } -func (c *Client) sessionLog(sid [32]byte) *sessionLogger { - return &sessionLogger{sid: sid, logger: c.logger} -} - -type sessionLogger struct { - sid [32]byte - logger slog.Logger -} - -func (l *sessionLogger) logf(format string, args ...interface{}) { - if l.logger == nil { - return - } - - l.logger.Debugf("sid=%x "+format, append([]interface{}{l.sid[:]}, args...)...) -} - // Run runs the client manager, blocking until after the context is // cancelled. func (c *Client) Run(ctx context.Context) error { @@ -426,57 +472,109 @@ func (c *Client) forLocalPeers(ctx context.Context, s *sessionRun, f func(p *pee } type delayedMsg struct { - t time.Time - m mixing.Message - p *peer + sendTime time.Time + deadline time.Time + m mixing.Message + p *peer } -func (c *Client) sendLocalPeerMsgs(ctx context.Context, s *sessionRun, mayTriggerBlame bool, - m func(p *peer) mixing.Message) error { - - msgs := make([]delayedMsg, 0, len(s.peers)) - +func (c *Client) sendLocalPeerMsgs(ctx context.Context, deadline time.Time, s *sessionRun, msgMask uint) error { now := time.Now() + + msgs := make([]delayedMsg, 0, len(s.peers)*bits.OnesCount(msgMask)) for _, p := range s.peers { if p.remote || p.ctx.Err() != nil { continue } - msg := m(p) - if mayTriggerBlame && p.triggeredBlame { - msg = p.rs + msg := delayedMsg{ + sendTime: now.Add(p.msgJitter()), + deadline: deadline, + m: nil, + p: p, + } + msgMask := msgMask + if p.triggeredBlame { + msgMask |= msgRS + } + if msgMask&msgKE == msgKE && p.ke != nil { + msg.m = p.ke + msgs = append(msgs, msg) + } + if msgMask&msgCT == msgCT && p.ct != nil { + msg.m = p.ct + msgs = append(msgs, msg) + } + if msgMask&msgSR == msgSR && p.sr != nil { + msg.m = p.sr + msgs = append(msgs, msg) + } + if msgMask&msgFP == msgFP && p.fp != nil { + msg.m = p.fp + msgs = append(msgs, msg) + } + if msgMask&msgDC == msgDC && p.dc != nil { + msg.m = p.dc + msgs = append(msgs, msg) + } + if msgMask&msgCM == msgCM && p.cm != nil { + msg.m = p.cm + msgs = append(msgs, msg) + } + if msgMask&msgRS == msgRS && p.rs != nil { + msg.m = p.rs + msgs = append(msgs, msg) } - msgs = append(msgs, delayedMsg{ - t: now.Add(p.msgJitter()), - m: msg, - p: p, - }) } - sort.Slice(msgs, func(i, j int) bool { - return msgs[i].t.Before(msgs[j].t) + sort.SliceStable(msgs, func(i, j int) bool { + return msgs[i].sendTime.Before(msgs[j].sendTime) }) - resChans := make([]chan error, 0, len(s.peers)) + nilPeerMsg := func(p *peer, msg mixing.Message) { + switch msg.(type) { + case *wire.MsgMixKeyExchange: + p.ke = nil + case 
*wire.MsgMixCiphertexts: + p.ct = nil + case *wire.MsgMixSlotReserve: + p.sr = nil + case *wire.MsgMixFactoredPoly: + p.fp = nil + case *wire.MsgMixDCNet: + p.dc = nil + case *wire.MsgMixConfirm: + p.cm = nil + case *wire.MsgMixSecrets: + p.rs = nil + } + } + + errs := make([]error, 0, len(msgs)) + + var sessionCanceledState bool + sessionCanceled := func() { + if sessionCanceledState { + return + } + if err := ctx.Err(); err != nil { + err := fmt.Errorf("session cancelled: %w", err) + errs = append(errs, err) + sessionCanceledState = true + } + } + for i := range msgs { - res := make(chan error, 1) - resChans = append(resChans, res) m := msgs[i] - time.Sleep(time.Until(m.t)) - qsend := &queueWork{ - p: m.p, - f: func(p *peer) error { - return p.signAndSubmit(m.m) - }, - res: res, - } select { case <-ctx.Done(): - res <- ctx.Err() - case c.workQueue <- qsend: + sessionCanceled() + continue + case <-time.After(time.Until(m.sendTime)): + } + err := m.p.signAndSubmit(m.deadline, m.m) + if err != nil { + nilPeerMsg(m.p, m.m) + errs = append(errs, err) } - } - var errs = make([]error, len(resChans)) - for i := range errs { - errs[i] = <-resChans[i] } return errors.Join(errs...) } @@ -545,9 +643,9 @@ func (c *Client) testTick() { c.testTickC <- struct{}{} } -func (c *Client) testHook(stage hook, s *sessionRun, p *peer) { +func (c *Client) testHook(stage hook, ps *pairedSessions, s *sessionRun, p *peer) { if hook, ok := c.testHooks[stage]; ok { - hook(c, s, p) + hook(c, ps, s, p) } } @@ -564,11 +662,13 @@ func (p *peer) signAndHash(m mixing.Message) error { return nil } -func (p *peer) submit(m mixing.Message) error { - return p.client.wallet.SubmitMixMessage(p.ctx, m) +func (p *peer) submit(deadline time.Time, m mixing.Message) error { + ctx, cancel := context.WithDeadline(p.ctx, deadline) + defer cancel() + return p.client.wallet.SubmitMixMessage(ctx, m) } -func (p *peer) signAndSubmit(m mixing.Message) error { +func (p *peer) signAndSubmit(deadline time.Time, m mixing.Message) error { if m == nil { return nil } @@ -576,15 +676,22 @@ func (p *peer) signAndSubmit(m mixing.Message) error { if err != nil { return err } - return p.submit(m) + return p.submit(deadline, m) +} + +func (c *Client) newPendingPairing(pairing []byte) *pendingPairing { + return &pendingPairing{ + localPeers: make(map[identity]*peer), + pairing: pairing, + } } -func (c *Client) newPairings(pairing []byte, peers map[identity]*peer) *pairedSessions { - if peers == nil { - peers = make(map[identity]*peer) +func (c *Client) newPairedSessions(pairing []byte, pendingPeers map[identity]*peer) *pairedSessions { + if pendingPeers == nil { + pendingPeers = make(map[identity]*peer) } ps := &pairedSessions{ - localPeers: peers, + localPeers: pendingPeers, pairing: pairing, runs: nil, } @@ -655,38 +762,40 @@ func (c *Client) epochTicker(ctx context.Context) error { c.mixpool.RemoveConfirmedSessions() c.expireMessages() - for _, p := range c.pairings { + for _, p := range c.pendingPairings { prs := c.mixpool.CompatiblePRs(p.pairing) prsMap := make(map[identity]struct{}) for _, pr := range prs { prsMap[pr.Identity] = struct{}{} } - // Clone the p.localPeers map, only including PRs - // currently accepted to mixpool. Adding additional - // waiting local peers must not add more to the map in - // use by pairSession, and deleting peers in a formed - // session from the pending map must not inadvertently - // remove from pairSession's ps.localPeers map. 
+ // Clone the pending peers map, only including peers + // with PRs currently accepted to mixpool. Adding + // additional waiting local peers must not add more to + // the map in use by pairSession, and deleting peers + // in a formed session from the pending map must not + // inadvertently remove from pairSession's peers map. localPeers := make(map[identity]*peer) for id, peer := range p.localPeers { if _, ok := prsMap[id]; ok { - localPeers[id] = peer + localPeers[id] = peer.cloneLocalPeer(true) } } + + // Even when we know there are not enough total peers + // to meet the minimum peer requirement, a run is + // still formed so that KEs can be broadcast; this + // must be done so peers can fetch missing PRs. + c.logf("Have %d compatible/%d local PRs waiting for pairing %x", len(prs), len(localPeers), p.pairing) - ps := *p - ps.localPeers = localPeers - for id, peer := range p.localPeers { - ps.localPeers[id] = peer - } // pairSession calls Done once the session is formed // and the selected peers have been removed from then // pending pairing. c.pairingWG.Add(1) - go c.pairSession(ctx, &ps, prs, epoch) + ps := c.newPairedSessions(p.pairing, localPeers) + go c.pairSession(ctx, ps, prs, epoch) } c.mu.Unlock() } @@ -742,20 +851,21 @@ func (c *Client) Dicemix(ctx context.Context, cj *CoinJoin) error { c.logf("Created local peer id=%x PR=%s", p.id[:], p.pr.Hash()) c.mu.Lock() - pairing := c.pairings[string(pairingID)] - if pairing == nil { - pairing = c.newPairings(pairingID, nil) - c.pairings[string(pairingID)] = pairing + pending := c.pendingPairings[string(pairingID)] + if pending == nil { + pending = c.newPendingPairing(pairingID) + c.pendingPairings[string(pairingID)] = pending } - pairing.localPeers[*p.id] = p + pending.localPeers[*p.id] = p c.mu.Unlock() - err = p.submit(pr) + deadline := time.Now().Add(timeoutDuration) + err = p.submit(deadline, pr) if err != nil { c.mu.Lock() - delete(pairing.localPeers, *p.id) - if len(pairing.localPeers) == 0 { - delete(c.pairings, string(pairingID)) + delete(pending.localPeers, *p.id) + if len(pending.localPeers) == 0 { + delete(c.pendingPairings, string(pairingID)) } c.mu.Unlock() return err @@ -786,28 +896,37 @@ func (c *Client) ExpireMessages(height uint32) { func (c *Client) expireMessages() { c.mixpool.ExpireMessages(c.height) - for pairID, ps := range c.pairings { - for id, p := range ps.localPeers { - prHash := p.pr.Hash() + for pairID, p := range c.pendingPairings { + for id, peer := range p.localPeers { + prHash := peer.pr.Hash() if !c.mixpool.HaveMessage(&prHash) { - delete(ps.localPeers, id) + delete(p.localPeers, id) // p.res is buffered. If the write is // blocked, we have already served this peer // or sent another error. select { - case p.res <- expiredPRErr(p.pr): + case peer.res <- expiredPRErr(peer.pr): default: } } } - if len(ps.localPeers) == 0 { - delete(c.pairings, pairID) + if len(p.localPeers) == 0 { + delete(c.pendingPairings, pairID) } } } +// donePairing decrements the client pairing waitgroup for the paired +// sessions. This is protected by a sync.Once and safe to call multiple +// times. +func (c *Client) donePairing(ps *pairedSessions) { + ps.donePairingOnce.Do(c.pairingWG.Done) +} + func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wire.MsgMixPairReq, epoch time.Time) { + defer c.donePairing(ps) + // This session pairing attempt, and calling pairSession again with // fresh PRs, must end before the next call to pairSession for this // pairing type. 
@@ -824,10 +943,9 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir defer func() { c.removeUnresponsiveDuringEpoch(unresponsive, unixEpoch) - unmixedPeers := ps.localPeers if mixedSession != nil && mixedSession.cj != nil { for _, pr := range mixedSession.prs { - delete(unmixedPeers, pr.Identity) + delete(ps.localPeers, pr.Identity) } // XXX: Removing these later in the background is a hack to @@ -836,218 +954,302 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir // for CM messages, which will increment ban score. go func() { time.Sleep(10 * time.Second) - c.logf("sid=%x removing mixed session completed with transaction %v", - mixedSession.sid[:], mixedSession.cj.txHash) + mixedSession.logf("removing mixed session completed "+ + "with transaction %v", mixedSession.cj.txHash) c.mixpool.RemoveSession(mixedSession.sid) }() } - if len(unmixedPeers) == 0 { + if len(ps.localPeers) == 0 { return } - for _, p := range unmixedPeers { - p.ke = nil - p.ct = nil - p.sr = nil - p.fp = nil - p.dc = nil - p.cm = nil - p.rs = nil - } - c.mu.Lock() - pendingPairing := c.pairings[string(ps.pairing)] + pendingPairing := c.pendingPairings[string(ps.pairing)] if pendingPairing == nil { - pendingPairing = c.newPairings(ps.pairing, unmixedPeers) - c.pairings[string(ps.pairing)] = pendingPairing + pendingPairing = c.newPendingPairing(ps.pairing) + c.pendingPairings[string(ps.pairing)] = pendingPairing } else { - for id, p := range unmixedPeers { + for id, p := range ps.localPeers { prHash := p.pr.Hash() if p.ctx.Err() == nil && c.mixpool.HaveMessage(&prHash) { - pendingPairing.localPeers[id] = p + pendingPairing.localPeers[id] = p.cloneLocalPeer(true) } } } c.mu.Unlock() }() - var madePairing bool - defer func() { - if !madePairing { - c.pairingWG.Done() - } - }() + ps.epoch = epoch + ps.deadlines.start(epoch) - var sesLog *sessionLogger - var currentRun *sessionRun - var rerun *sessionRun - var d deadlines - d.epoch = epoch - d.start(epoch) - for { - if rerun == nil { - sid := mixing.SortPRsForSession(prs, unixEpoch) - sesLog = c.sessionLog(sid) - - sesRun := sessionRun{ - sid: sid, - prs: prs, - freshGen: true, - deadlines: d, - mcounts: make([]uint32, 0, len(prs)), - } - ps.runs = append(ps.runs, sesRun) - currentRun = &ps.runs[len(ps.runs)-1] + sid := mixing.SortPRsForSession(prs, unixEpoch) + ps.runs = append(ps.runs, sessionRun{ + sid: sid, + prs: prs, + freshGen: true, + mcounts: make([]uint32, 0, len(prs)), + }) + newRun := &ps.runs[len(ps.runs)-1] - } else { - ps.runs = append(ps.runs, *rerun) - currentRun = &ps.runs[len(ps.runs)-1] - sesLog = c.sessionLog(currentRun.sid) - // rerun is not assigned nil here to please the - // linter. All code paths that reenter this loop will - // set it again. - } + for { + if newRun != nil { + newRun.idx = len(ps.runs) - 1 + newRun.logger = c.logger + + prs = newRun.prs + prHashes := make([]chainhash.Hash, len(prs)) + newRun.localPeers = make(map[identity]*peer) + var m uint32 + var localPeerCount, localUncancelledCount int + for i, pr := range prs { + prHashes[i] = prs[i].Hash() + + // Peer clones must be made from the previous run's + // local peer objects (if any) to preserve existing + // PRNGs and derived secrets and keys. 
+ peerMap := ps.localPeers + if newRun.idx > 0 { + peerMap = ps.runs[newRun.idx-1].localPeers + } + p := peerMap[pr.Identity] + if p != nil { + p = p.cloneLocalPeer(newRun.freshGen) + localPeerCount++ + if p.ctx.Err() == nil { + localUncancelledCount++ + } + newRun.localPeers[*p.id] = p + } else { + p = newRemotePeer(pr) + } + p.myVk = uint32(i) + p.myStart = m - prs = currentRun.prs - prHashes := make([]chainhash.Hash, len(prs)) - for i := range prs { - prHashes[i] = prs[i].Hash() - } + newRun.peers = append(newRun.peers, p) + newRun.mcounts = append(newRun.mcounts, p.pr.MessageCount) - var m uint32 - var localPeerCount, localUncancelledCount int - for i, pr := range prs { - p := ps.localPeers[pr.Identity] - if p != nil { - localPeerCount++ - if p.ctx.Err() == nil { - localUncancelledCount++ - } - } else { - p = newRemotePeer(pr) + m += p.pr.MessageCount } - p.myVk = uint32(i) - p.myStart = m + newRun.mtot = m - currentRun.peers = append(currentRun.peers, p) - currentRun.mcounts = append(currentRun.mcounts, p.pr.MessageCount) + newRun.logf("created session for pairid=%x from %d total %d local PRs %s", + ps.pairing, len(prHashes), localPeerCount, prHashes) - m += p.pr.MessageCount - } - currentRun.mtot = m - - sesLog.logf("created session for pairid=%x from %d total %d local PRs %s", - ps.pairing, len(prHashes), localPeerCount, prHashes) + if localUncancelledCount == 0 { + newRun.logf("no more active local peers") + return + } - if localUncancelledCount == 0 { - sesLog.logf("no more active local peers") - return + newRun = nil } - sesLog.logf("len(ps.runs)=%d", len(ps.runs)) - - c.testHook(hookBeforeRun, currentRun, nil) - err := c.run(ctx, ps, &madePairing) + c.testHook(hookBeforeRun, ps, &ps.runs[len(ps.runs)-1], nil) + r, err := c.run(ctx, ps) + var rerun *sessionRun + var altses *alternateSession var sizeLimitedErr *sizeLimited - if errors.As(err, &sizeLimitedErr) { - if len(sizeLimitedErr.prs) < MinPeers { - sesLog.logf("Aborting session with too few remaining peers") - return - } - - d.shift() + var blamed blamedIdentities + var revealedSecrets bool + var requirePeerAgreement bool + switch { + case errors.Is(err, errOnlyKEsBroadcasted): + // When only KEs are broadcasted, the session was not viable + // due to lack of peers or a peer capable of solving the + // roots. Return without setting the mixed session. + return - sesLog.logf("Recreating as session %x due to standard tx size limits (pairid=%x)", - sizeLimitedErr.sid[:], ps.pairing) + case errors.As(err, &altses): + // If this errored or has too few peers, keep + // retrying previous attempts until next epoch, + // instead of just going away. + if altses.err != nil { + r.logf("Unable to recreate session: %v", altses.err) + ps.deadlines.start(time.Now()) + continue + } - rerun = &sessionRun{ - sid: sizeLimitedErr.sid, - prs: sizeLimitedErr.prs, - freshGen: false, - deadlines: d, + if r.sid != altses.sid { + r.logf("Recreating as session %x (pairid=%x)", altses.sid, ps.pairing) + unresponsive = append(unresponsive, altses.unresponsive...) } - continue - } - var altses *alternateSession - if errors.As(err, &altses) { - if altses.err != nil { - sesLog.logf("Unable to recreate session: %v", altses.err) - return + // Required minimum run index is not incremented for + // reformed sessions without peer agreement: we must + // consider receiving KEs from peers who reformed into + // the same session we previously attempted. 
+ rerun = &sessionRun{ + sid: altses.sid, + prs: altses.prs, + freshGen: false, } + err = nil - if len(altses.prs) < MinPeers { - sesLog.logf("Aborting session with too few remaining peers") + case errors.As(err, &sizeLimitedErr): + if len(sizeLimitedErr.prs) < MinPeers { + r.logf("Aborting session with too few remaining peers") return } - d.shift() - - if currentRun.sid != altses.sid { - sesLog.logf("Recreating as session %x (pairid=%x)", altses.sid, ps.pairing) - unresponsive = append(unresponsive, altses.unresponsive...) - } + r.logf("Recreating as session %x due to standard tx size limits (pairid=%x)", + sizeLimitedErr.sid[:], ps.pairing) rerun = &sessionRun{ - sid: altses.sid, - prs: altses.prs, - freshGen: false, - deadlines: d, + sid: sizeLimitedErr.sid, + prs: sizeLimitedErr.prs, + freshGen: false, } - continue - } + requirePeerAgreement = true + err = nil - var blamed blamedIdentities - revealedSecrets := false - if errors.Is(err, errTriggeredBlame) || errors.Is(err, mixpool.ErrSecretsRevealed) { + case errors.Is(err, errTriggeredBlame) || errors.Is(err, mixpool.ErrSecretsRevealed): revealedSecrets = true - err := c.blame(ctx, currentRun) + err := c.blame(ctx, r) if !errors.As(err, &blamed) { - sesLog.logf("Aborting session for failed blame assignment: %v", err) + r.logf("Aborting session for failed blame assignment: %v", err) return } + requirePeerAgreement = true + // err = nil would be an ineffectual assignment here; + // blamed is non-nil and the following if block will + // always be entered. } + if blamed != nil || errors.As(err, &blamed) { - sesLog.logf("Identified %d blamed peers %x", len(blamed), []identity(blamed)) + r.logf("Identified %d blamed peers %x", len(blamed), []identity(blamed)) + + if len(r.prs)-len(blamed) < MinPeers { + r.logf("Aborting session with too few remaining peers") + return + } // Blamed peers were identified, either during the run // in a way that all participants could have observed, // or following revealing secrets and blame // assignment. Begin a rerun excluding these peers. - rerun = excludeBlamed(currentRun, blamed, revealedSecrets) + rerun = excludeBlamed(r, unixEpoch, blamed, revealedSecrets) + requirePeerAgreement = true + err = nil + } + + if rerun != nil { + if ps.runs[len(ps.runs)-1].sid == rerun.sid { + r := &ps.runs[len(ps.runs)-1] + r.logf("recreated session matches previous try; " + + "not creating new run states") + if ps.peerAgreementRunIdx >= len(ps.runs) { + // XXX shouldn't happen but fix up anyways + r.logf("reverting incremented peer agreement index") + ps.peerAgreementRunIdx = r.idx + } + } else { + ps.runs = append(ps.runs, *rerun) + newRun = &ps.runs[len(ps.runs)-1] + } + ps.deadlines.start(time.Now()) + if requirePeerAgreement { + ps.peerAgreementRunIdx = len(ps.runs) - 1 + } continue } - // When only KEs are broadcasted, the session was not viable - // due to lack of peers or a peer capable of solving the - // roots. Return without setting the mixed session. - if errors.Is(err, errOnlyKEsBroadcasted) { - return - } - // Any other run error is not actionable. if err != nil { - sesLog.logf("Run error: %v", err) + r.logf("Run error: %v", err) return } - mixedSession = currentRun + mixedSession = r return } } -func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) error { +var errIncompletePairing = errors.New("incomplete pairing") + +// completePairing waits for all KEs to form a completed session. Completed +// pairings are checked for in the order the sessions were attempted. 
+func (c *Client) completePairing(ctx context.Context, ps *pairedSessions) (*sessionRun, error) { + mp := c.mixpool + res := make(chan *sessionRun, len(ps.runs)) + errs := make(chan error, len(ps.runs)) + var errCount int + + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + wrappedCtx, cancel := context.WithCancel(ctx) + defer cancel() + + recvKEs := func(ctx context.Context, sesRun *sessionRun) ([]*wire.MsgMixKeyExchange, error) { + rcv := new(mixpool.Received) + rcv.Sid = sesRun.sid + rcv.KEs = make([]*wire.MsgMixKeyExchange, 0, len(sesRun.prs)) + ctx, cancel := context.WithDeadline(ctx, ps.deadlines.recvKE) + defer cancel() + err := mp.Receive(ctx, rcv) + if len(rcv.KEs) == len(sesRun.prs) { + return rcv.KEs, nil + } + if err == nil { + err = errIncompletePairing + } + return nil, err + } + + var wg sync.WaitGroup + defer func() { + cancel() + wg.Wait() + }() + + for i := ps.peerAgreementRunIdx; i < len(ps.runs); i++ { + sr := &ps.runs[i] + + // Initially check the mempool with the pre-canceled context + // to return immediately with KEs received so far. + kes, err := recvKEs(canceledCtx, sr) + if err == nil { + sr.kes = kes + return sr, nil + } + + wg.Add(1) + go func() { + defer wg.Done() + kes, err := recvKEs(wrappedCtx, sr) + if err == nil { + sr.kes = kes + res <- sr + } else { + errs <- err + } + }() + } + + for { + select { + case sesRun := <-res: + return sesRun, nil + case err := <-errs: + errCount++ + if errCount == len(ps.runs[ps.peerAgreementRunIdx:]) { + return nil, err + } + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +func (c *Client) run(ctx context.Context, ps *pairedSessions) (sesRun *sessionRun, err error) { var blamed blamedIdentities mp := c.wallet.Mixpool() - sesRun := &ps.runs[len(ps.runs)-1] + sesRun = &ps.runs[len(ps.runs)-1] prs := sesRun.prs - d := &sesRun.deadlines - unixEpoch := uint64(d.epoch.Unix()) - - sesLog := c.sessionLog(sesRun.sid) + d := &ps.deadlines + unixEpoch := uint64(ps.epoch.Unix()) // A map of identity public keys to their PR position sort all // messages in the same order as the PRs are ordered. 
@@ -1056,15 +1258,11 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) identityIndices[pr.Identity] = i } - seenPRs := make([]chainhash.Hash, len(prs)) - for i := range prs { - seenPRs[i] = prs[i].Hash() - } - - err := c.forLocalPeers(ctx, sesRun, func(p *peer) error { - p.coinjoin.resetUnmixed(prs) + freshGen := sesRun.freshGen + err = c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if freshGen { + p.ke = nil - if sesRun.freshGen { // Generate a new PRNG seed rand.Read(p.prngSeed[:]) p.prng = chacha20prng.New(p.prngSeed[:], 0) @@ -1102,45 +1300,54 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } - // Perform key exchange - srMsgBytes := make([][]byte, len(p.srMsg)) - for i := range p.srMsg { - srMsgBytes[i] = p.srMsg[i].Bytes() - } - rs := wire.NewMsgMixSecrets(*p.id, sesRun.sid, 0, - p.prngSeed, srMsgBytes, p.dcMsg) - c.blake256HasherMu.Lock() - commitment := rs.Commitment(c.blake256Hasher) - c.blake256HasherMu.Unlock() - ecdhPub := *(*[33]byte)(p.kx.ECDHPublicKey.SerializeCompressed()) - pqPub := *p.kx.PQPublicKey - ke := wire.NewMsgMixKeyExchange(*p.id, sesRun.sid, unixEpoch, 0, - uint32(identityIndices[*p.id]), ecdhPub, pqPub, commitment, - seenPRs) - - p.ke = ke - p.rs = rs + if p.ke == nil { + seenPRs := make([]chainhash.Hash, len(prs)) + for i := range prs { + seenPRs[i] = prs[i].Hash() + } + + // Perform key exchange + srMsgBytes := make([][]byte, len(p.srMsg)) + for i := range p.srMsg { + srMsgBytes[i] = p.srMsg[i].Bytes() + } + rs := wire.NewMsgMixSecrets(*p.id, sesRun.sid, 0, + p.prngSeed, srMsgBytes, p.dcMsg) + c.blake256HasherMu.Lock() + commitment := rs.Commitment(c.blake256Hasher) + c.blake256HasherMu.Unlock() + ecdhPub := *(*[33]byte)(p.kx.ECDHPublicKey.SerializeCompressed()) + pqPub := *p.kx.PQPublicKey + ke := wire.NewMsgMixKeyExchange(*p.id, sesRun.sid, unixEpoch, 0, + uint32(identityIndices[*p.id]), ecdhPub, pqPub, commitment, + seenPRs) + + p.ke = ke + p.ct = nil + p.sr = nil + p.fp = nil + p.dc = nil + p.cm = nil + p.rs = rs + } return nil }) if err != nil { - sesLog.logf("%v", err) + sesRun.logf("%v", err) + } else { + sesRun.freshGen = false } - err = c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - if p.ke == nil { - return nil - } - return p.ke - }) + err = c.sendLocalPeerMsgs(ctx, ps.deadlines.recvKE, sesRun, msgKE) if err != nil { - sesLog.logf("%v", err) + sesRun.logf("%v", err) } // Only continue attempting to form the session if there are minimum // peers available and at least one of them is capable of solving the // roots. if len(prs) < MinPeers { - sesLog.logf("Pairing %x: minimum peer requirement unmet", ps.pairing) - return errOnlyKEsBroadcasted + sesRun.logf("pairing %x: minimum peer requirement unmet", ps.pairing) + return sesRun, errOnlyKEsBroadcasted } haveSolver := false for _, pr := range prs { @@ -1150,8 +1357,8 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if !haveSolver { - sesLog.logf("Pairing %x: no solver available", ps.pairing) - return errOnlyKEsBroadcasted + sesRun.logf("pairing %x: no solver available", ps.pairing) + return sesRun, errOnlyKEsBroadcasted } // Receive key exchange messages. @@ -1168,57 +1375,51 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // All KE messages that match the pairing ID are received, and each // seen PRs slice is checked. PRs that were never followed up by a KE // are immediately excluded. 
- var kes []*wire.MsgMixKeyExchange - recvKEs := func(sesRun *sessionRun) (kes []*wire.MsgMixKeyExchange, err error) { - rcv := new(mixpool.Received) - rcv.Sid = sesRun.sid - rcv.KEs = make([]*wire.MsgMixKeyExchange, 0, len(sesRun.prs)) - ctx, cancel := context.WithDeadline(ctx, d.recvKE) - defer cancel() - err = mp.Receive(ctx, rcv) - if ctx.Err() != nil { - err = fmt.Errorf("session %x KE receive context cancelled: %w", - sesRun.sid[:], ctx.Err()) - } - return rcv.KEs, err - } - - switch { - case !*madePairing: - // Receive KEs for the last attempted session. Local - // peers may have been modified (new keys generated, and myVk - // indexes changed) if this is a recreated session, and we - // cannot continue mix using these messages. - // - // XXX: do we need to keep peer info available for previous - // session attempts? It is possible that previous sessions - // may be operable now if all wallets have come to agree on a - // previous session we also tried to form. - kes, err = recvKEs(sesRun) - if err == nil && len(kes) == len(sesRun.prs) { - break - } - - // Alternate session needs to be attempted. Do not form an + completedSesRun, err := c.completePairing(ctx, ps) + if err != nil { + // Alternate session may need to be attempted. Do not form an // alternate session if we are about to enter into the next // epoch. The session forming will be performed by a new // goroutine started by the epoch ticker, possibly with // additional PRs. - nextEpoch := d.epoch.Add(c.epoch) + nextEpoch := ps.epoch.Add(c.epoch) if time.Now().Add(timeoutDuration).After(nextEpoch) { c.logf("Aborting session %x after %d attempts", sesRun.sid[:], len(ps.runs)) - return errOnlyKEsBroadcasted + return sesRun, errOnlyKEsBroadcasted } - return c.alternateSession(ps.pairing, sesRun.prs, d) + // If peer agreement was never established, alternate sessions + // based on the seen PRs must be formed. + if !ps.peerAgreement { + return sesRun, c.alternateSession(ps, sesRun.prs) + } - default: - kes, err = recvKEs(sesRun) - if err != nil { - return err + return sesRun, err + } + + if completedSesRun != sesRun { + completedSesRun.logf("replacing previous session attempt %x/%d", + sesRun.sid[:], sesRun.idx) + + // Reset variables created from assuming the + // final session run. + sesRun = completedSesRun + prs = sesRun.prs + identityIndices = make(map[identity]int) + for i, pr := range prs { + identityIndices[pr.Identity] = i } } + kes := sesRun.kes + sesRun.logf("received all %d KEs", len(kes)) + + // The coinjoin structure is commonly referenced by all instances of + // the local peers; reset all of them from the initial paired session + // attempt, and not only those in the last session run. + for _, p := range ps.localPeers { + p.coinjoin.resetUnmixed(prs) + } // Before confirming the pairing, check all of the agreed-upon PRs // that they will not result in a coinjoin transaction that exceeds @@ -1227,7 +1428,10 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // PRs are randomly ordered in each epoch based on the session ID, so // they can be iterated in order to discover any PR that would // increase the final coinjoin size above the limits. 
- if !*madePairing { + if !ps.peerAgreement { + ps.peerAgreement = true + ps.peerAgreementRunIdx = sesRun.idx + var sizeExcluded []*wire.MsgMixPairReq var cjSize coinjoinSize for _, pr := range sesRun.prs { @@ -1256,7 +1460,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } sid := mixing.SortPRsForSession(kept, unixEpoch) - return &sizeLimited{ + return sesRun, &sizeLimited{ prs: kept, sid: sid, excluded: sizeExcluded, @@ -1270,22 +1474,23 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } - // Remove paired local peers from waiting pairing. - if !*madePairing { - c.mu.Lock() - if waiting := c.pairings[string(ps.pairing)]; waiting != nil { - for id := range ps.localPeers { - delete(waiting.localPeers, id) - } - if len(waiting.localPeers) == 0 { - delete(c.pairings, string(ps.pairing)) - } + // Remove paired local peers from pending pairings. + // + // XXX might want to keep these instead of racing to add them back if + // this mix doesn't run to completion, and we start next epoch without + // some of our own peers. + c.mu.Lock() + if pending := c.pendingPairings[string(ps.pairing)]; pending != nil { + for id := range sesRun.localPeers { + delete(pending.localPeers, id) + } + if len(pending.localPeers) == 0 { + delete(c.pendingPairings, string(ps.pairing)) } - c.mu.Unlock() - - *madePairing = true - c.pairingWG.Done() } + c.mu.Unlock() + + c.donePairing(ps) sort.Slice(kes, func(i, j int) bool { a := identityIndices[kes[i].Identity] @@ -1299,13 +1504,13 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) involvesLocalPeers := false for _, ke := range kes { - if ps.localPeers[ke.Identity] != nil { + if sesRun.localPeers[ke.Identity] != nil { involvesLocalPeers = true break } } if !involvesLocalPeers { - return errors.New("excluded all local peers") + return sesRun, errors.New("excluded all local peers") } ecdhPublicKeys := make([]*secp256k1.PublicKey, 0, len(prs)) @@ -1320,11 +1525,15 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) pqpk = append(pqpk, &ke.PQPK) } if len(blamed) > 0 { - sesLog.logf("blaming %x during run (invalid ECDH pubkeys)", []identity(blamed)) - return blamed + sesRun.logf("blaming %x during run (invalid ECDH pubkeys)", []identity(blamed)) + return sesRun, blamed } err = c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if p.ct != nil { + return nil + } + // Create shared keys and ciphextexts for each peer pqct, err := p.kx.Encapsulate(p.prng, pqpk, int(p.myVk)) if err != nil { @@ -1334,20 +1543,15 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // Send ciphertext messages ct := wire.NewMsgMixCiphertexts(*p.id, sesRun.sid, 0, pqct, seenKEs) p.ct = ct - c.testHook(hookBeforePeerCTPublish, sesRun, p) + c.testHook(hookBeforePeerCTPublish, ps, sesRun, p) return nil }) if err != nil { - sesLog.logf("%v", err) + sesRun.logf("%v", err) } - err = c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - if p.ct == nil { - return nil - } - return p.ct - }) + err = c.sendLocalPeerMsgs(ctx, d.recvCT, sesRun, msgCT) if err != nil { - sesLog.logf("%v", err) + sesRun.logf("%v", err) } // Receive all ciphertext messages @@ -1365,12 +1569,12 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if err != nil { - return err + return sesRun, err } if len(cts) != len(prs) { // Blame peers - sesLog.logf("Received %d CTs for %d peers; rerunning", len(cts), len(prs)) 
- return blameTimedOut(sesLog, sesRun, ctTimeout) + sesRun.logf("received %d CTs for %d peers; rerunning", len(cts), len(prs)) + return sesRun, blameTimedOut(sesRun, ctTimeout) } sort.Slice(cts, func(i, j int) bool { a := identityIndices[cts[i].Identity] @@ -1385,6 +1589,10 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) blamedMap := make(map[identity]struct{}) var blamedMapMu sync.Mutex err = c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if p.sr != nil { + return nil + } + revealed := &mixing.RevealedKeys{ ECDHPublicKeys: ecdhPublicKeys, Ciphertexts: make([]mixing.PQCiphertext, 0, len(prs)), @@ -1422,28 +1630,23 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // reservations. sr := wire.NewMsgMixSlotReserve(*p.id, sesRun.sid, 0, srMixBytes, seenCTs) p.sr = sr - c.testHook(hookBeforePeerSRPublish, sesRun, p) + c.testHook(hookBeforePeerSRPublish, ps, sesRun, p) return nil }) if len(blamedMap) > 0 { for id := range blamedMap { blamed = append(blamed, id) } - sesLog.logf("blaming %x during run (wrong ciphertext count)", []identity(blamed)) - return blamed + sesRun.logf("blaming %x during run (wrong ciphertext count)", []identity(blamed)) + return sesRun, blamed } - sendErr := c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - if p.sr == nil { - return nil - } - return p.sr - }) + sendErr := c.sendLocalPeerMsgs(ctx, d.recvSR, sesRun, msgSR) if sendErr != nil { - sesLog.logf("%v", sendErr) + sesRun.logf("%v", sendErr) } if err != nil { - sesLog.logf("%v", err) - return err + sesRun.logf("%v", err) + return sesRun, err } // Receive all slot reservation messages @@ -1459,12 +1662,12 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if err != nil { - return err + return sesRun, err } if len(srs) != len(prs) { // Blame peers - sesLog.logf("Received %d SRs for %d peers; rerunning", len(srs), len(prs)) - return blameTimedOut(sesLog, sesRun, srTimeout) + sesRun.logf("received %d SRs for %d peers; rerunning", len(srs), len(prs)) + return sesRun, blameTimedOut(sesRun, srTimeout) } sort.Slice(srs, func(i, j int) bool { a := identityIndices[srs[i].Identity] @@ -1484,16 +1687,18 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) powerSums := mixing.AddVectors(mixing.IntVectorsFromBytes(vs)...) coeffs := mixing.Coefficients(powerSums) rcvCtx, rcvCtxCancel = context.WithDeadline(ctx, d.recvSR) - publishedRootsC := make(chan struct{}) - defer func() { <-publishedRootsC }() - roots, err := c.roots(rcvCtx, seenSRs, sesRun, coeffs, mixing.F, publishedRootsC) + roots, err := c.roots(rcvCtx, seenSRs, sesRun, coeffs, mixing.F) rcvCtxCancel() if err != nil { - return err + return sesRun, err } sesRun.roots = roots err = c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if p.dc != nil { + return nil + } + // Find reserved slots slots := make([]uint32, 0, p.pr.MessageCount) for _, m := range p.srMsg { @@ -1516,21 +1721,16 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // Broadcast XOR DC-net vectors. 
dc := wire.NewMsgMixDCNet(*p.id, sesRun.sid, 0, p.dcNet, seenSRs) p.dc = dc - c.testHook(hookBeforePeerDCPublish, sesRun, p) + c.testHook(hookBeforePeerDCPublish, ps, sesRun, p) return nil }) - sendErr = c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - if p.dc == nil { - return nil - } - return p.dc - }) + sendErr = c.sendLocalPeerMsgs(ctx, d.recvDC, sesRun, msgFP|msgDC) if sendErr != nil { - sesLog.logf("%v", err) + sesRun.logf("%v", err) } if err != nil { - sesLog.logf("DC-net error: %v", err) - return err + sesRun.logf("DC-net error: %v", err) + return sesRun, err } // Receive all DC messages @@ -1546,12 +1746,12 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if err != nil { - return err + return sesRun, err } if len(dcs) != len(prs) { // Blame peers - sesLog.logf("Received %d DCs for %d peers; rerunning", len(dcs), len(prs)) - return blameTimedOut(sesLog, sesRun, dcTimeout) + sesRun.logf("received %d DCs for %d peers; rerunning", len(dcs), len(prs)) + return sesRun, blameTimedOut(sesRun, dcTimeout) } sort.Slice(dcs, func(i, j int) bool { a := identityIndices[dcs[i].Identity] @@ -1575,11 +1775,15 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if len(blamed) > 0 { - sesLog.logf("blaming %x during run (wrong DC-net count)", []identity(blamed)) - return blamed + sesRun.logf("blaming %x during run (wrong DC-net count)", []identity(blamed)) + return sesRun, blamed } mixedMsgs := mixing.XorVectors(dcVecs) err = c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if p.cm != nil { + return nil + } + // Add outputs for each mixed message for i := range mixedMsgs { mixedMsg := mixedMsgs[i][:] @@ -1591,7 +1795,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // provided inputs. 
err := p.coinjoin.confirm(c.wallet) if errors.Is(err, errMissingGen) { - sesLog.logf("Missing message; blaming and rerunning") + sesRun.logf("missing message; blaming and rerunning") p.triggeredBlame = true return errTriggeredBlame } @@ -1606,18 +1810,13 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) p.cm = cm return nil }) - sendErr = c.sendLocalPeerMsgs(ctx, sesRun, true, func(p *peer) mixing.Message { - if p.cm == nil { - return nil - } - return p.cm - }) + sendErr = c.sendLocalPeerMsgs(ctx, d.recvCM, sesRun, msgCM) if sendErr != nil { - sesLog.logf("%v", sendErr) + sesRun.logf("%v", sendErr) } if err != nil { - sesLog.logf("Confirm error: %v", err) - return err + sesRun.logf("confirm error: %v", err) + return sesRun, err } // Receive all CM messages @@ -1633,12 +1832,12 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if err != nil { - return err + return sesRun, err } if len(cms) != len(prs) { // Blame peers - sesLog.logf("Received %d CMs for %d peers; rerunning", len(cms), len(prs)) - return blameTimedOut(sesLog, sesRun, cmTimeout) + sesRun.logf("received %d CMs for %d peers; rerunning", len(cms), len(prs)) + return sesRun, blameTimedOut(sesRun, cmTimeout) } sort.Slice(cms, func(i, j int) bool { a := identityIndices[cms[i].Identity] @@ -1670,18 +1869,18 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } } if len(blamed) > 0 { - sesLog.logf("blaming %x during run (confirmed wrong coinjoin)", []identity(blamed)) - return blamed + sesRun.logf("blaming %x during run (confirmed wrong coinjoin)", []identity(blamed)) + return sesRun, blamed } err = c.validateMergedCoinjoin(cj, prs, utxos) if err != nil { - return err + return sesRun, err } time.Sleep(lowestJitter + rand.Duration(msgJitter)) err = c.wallet.PublishTransaction(context.Background(), cj.tx) if err != nil { - return err + return sesRun, err } c.forLocalPeers(ctx, sesRun, func(p *peer) error { @@ -1694,7 +1893,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) sesRun.cj = cj - return nil + return sesRun, nil } func (c *Client) solvesRoots() bool { @@ -1707,7 +1906,7 @@ func (c *Client) solvesRoots() bool { // the result to all peers. If the client is incapable of solving the roots, // it waits for a solution. func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, - sesRun *sessionRun, a []*big.Int, F *big.Int, publishedRoots chan struct{}) ([]*big.Int, error) { + sesRun *sessionRun, a []*big.Int, F *big.Int) ([]*big.Int, error) { switch { case c.solvesRoots(): @@ -1728,7 +1927,6 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, p.triggeredBlame = true return nil }) - close(publishedRoots) return nil, errTriggeredBlame } sort.Slice(roots, func(i, j int) bool { @@ -1739,25 +1937,16 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, rootBytes[i] = root.Bytes() } c.forLocalPeers(ctx, sesRun, func(p *peer) error { + if p.fp != nil { + return nil + } p.fp = wire.NewMsgMixFactoredPoly(*p.id, sesRun.sid, 0, rootBytes, seenSRs) return nil }) - // Don't wait for these messages to send. - go func() { - err := c.sendLocalPeerMsgs(ctx, sesRun, false, func(p *peer) mixing.Message { - return p.fp - }) - if ctx.Err() == nil && err != nil { - c.logf("sid=%x %v", sesRun.sid[:], err) - } - publishedRoots <- struct{}{} - }() return roots, nil } - close(publishedRoots) - // Clients unable to solve their own roots must wait for solutions. 
// We can return a result as soon as we read any valid factored // polynomial message that provides the solutions for this SR @@ -1780,6 +1969,7 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, if _, ok := checkedFPByIdentity[fp.Identity]; ok { continue } + checkedFPByIdentity[fp.Identity] = struct{}{} roots = roots[:0] duplicateRoots := make(map[string]struct{}) @@ -1814,8 +2004,6 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, if len(roots) == len(a)-1 { return roots, nil } - - checkedFPByIdentity[fp.Identity] = struct{}{} } rcv.FPs = make([]*wire.MsgMixFactoredPoly, 0, len(rcv.FPs)+1) @@ -1928,10 +2116,10 @@ func (e *alternateSession) Unwrap() error { return e.err } -func (c *Client) alternateSession(pairing []byte, prs []*wire.MsgMixPairReq, d *deadlines) *alternateSession { - unixEpoch := uint64(d.epoch.Unix()) +func (c *Client) alternateSession(ps *pairedSessions, prs []*wire.MsgMixPairReq) *alternateSession { + unixEpoch := uint64(ps.epoch.Unix()) - kes := c.mixpool.ReceiveKEsByPairing(pairing, unixEpoch) + kes := c.mixpool.ReceiveKEsByPairing(ps.pairing, unixEpoch) // Sort KEs by identity first (just to group these together) followed // by the total referenced PR counts in increasing order (most recent @@ -2112,7 +2300,7 @@ func (c *Client) alternateSession(pairing []byte, prs []*wire.MsgMixPairReq, d * } -func excludeBlamed(prevRun *sessionRun, blamed blamedIdentities, revealedSecrets bool) *sessionRun { +func excludeBlamed(prevRun *sessionRun, epoch uint64, blamed blamedIdentities, revealedSecrets bool) *sessionRun { blamedMap := make(map[identity]struct{}) for _, id := range blamed { blamedMap[id] = struct{}{} @@ -2132,19 +2320,15 @@ func excludeBlamed(prevRun *sessionRun, blamed blamedIdentities, revealedSecrets prs = append(prs, p.pr) } - d := prevRun.deadlines - d.restart() - - unixEpoch := prevRun.epoch.Unix() - sid := mixing.SortPRsForSession(prs, uint64(unixEpoch)) + sid := mixing.SortPRsForSession(prs, epoch) // mtot, peers, mcounts are all recalculated from the prs before // calling run() nextRun := &sessionRun{ - sid: sid, - freshGen: revealedSecrets, - prs: prs, - deadlines: d, + sid: sid, + idx: prevRun.idx + 1, + freshGen: revealedSecrets, + prs: prs, } return nextRun } diff --git a/mixing/mixclient/client_test.go b/mixing/mixclient/client_test.go index a267b70605..1f4d88d0e0 100644 --- a/mixing/mixclient/client_test.go +++ b/mixing/mixclient/client_test.go @@ -267,8 +267,8 @@ func testDisruption(t *testing.T, misbehavingID *identity, h hook, f hookFunc) { w := newTestWallet(bc) c := newTestClient(w, l) c.testHooks = map[hook]hookFunc{ - hookBeforeRun: func(c *Client, s *sessionRun, _ *peer) { - s.deadlines.recvKE = time.Now().Add(5 * time.Second) + hookBeforeRun: func(c *Client, ps *pairedSessions, _ *sessionRun, _ *peer) { + ps.deadlines.recvKE = time.Now().Add(5 * time.Second) }, h: f, } @@ -397,73 +397,78 @@ func testDisruption(t *testing.T, misbehavingID *identity, h hook, f hookFunc) { func TestCTDisruption(t *testing.T) { var misbehavingID identity - testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) { - if p.myVk != 0 { - return - } - if misbehavingID != [33]byte{} { - return - } - t.Logf("malicious peer %x: flipping CT bit", p.id[:]) - misbehavingID = *p.id - p.ct.Ciphertexts[1][0] ^= 1 - }) + testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, + func(c *Client, ps *pairedSessions, s *sessionRun, p *peer) { + if p.myVk != 0 { + return + } + if 
misbehavingID != [33]byte{} { + return + } + t.Logf("malicious peer %x: flipping CT bit", p.id[:]) + misbehavingID = *p.id + p.ct.Ciphertexts[1][0] ^= 1 + }) } func TestCTLength(t *testing.T) { var misbehavingID identity - testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) { - if p.myVk != 0 { - return - } - if misbehavingID != [33]byte{} { - return - } - t.Logf("malicious peer %x: sending too few ciphertexts", p.id[:]) - misbehavingID = *p.id - p.ct.Ciphertexts = p.ct.Ciphertexts[:len(p.ct.Ciphertexts)-1] - }) + testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, + func(c *Client, ps *pairedSessions, s *sessionRun, p *peer) { + if p.myVk != 0 { + return + } + if misbehavingID != [33]byte{} { + return + } + t.Logf("malicious peer %x: sending too few ciphertexts", p.id[:]) + misbehavingID = *p.id + p.ct.Ciphertexts = p.ct.Ciphertexts[:len(p.ct.Ciphertexts)-1] + }) misbehavingID = identity{} - testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) { - if p.myVk != 0 { - return - } - if misbehavingID != [33]byte{} { - return - } - t.Logf("malicious peer %x: sending too many ciphertexts", p.id[:]) - misbehavingID = *p.id - p.ct.Ciphertexts = append(p.ct.Ciphertexts, p.ct.Ciphertexts[0]) - }) + testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, + func(c *Client, ps *pairedSessions, s *sessionRun, p *peer) { + if p.myVk != 0 { + return + } + if misbehavingID != [33]byte{} { + return + } + t.Logf("malicious peer %x: sending too many ciphertexts", p.id[:]) + misbehavingID = *p.id + p.ct.Ciphertexts = append(p.ct.Ciphertexts, p.ct.Ciphertexts[0]) + }) } func TestSRDisruption(t *testing.T) { var misbehavingID identity - testDisruption(t, &misbehavingID, hookBeforePeerSRPublish, func(c *Client, s *sessionRun, p *peer) { - if p.myVk != 0 { - return - } - if misbehavingID != [33]byte{} { - return - } - t.Logf("malicious peer %x: flipping SR bit", p.id[:]) - misbehavingID = *p.id - p.sr.DCMix[0][1][0] ^= 1 - }) + testDisruption(t, &misbehavingID, hookBeforePeerSRPublish, + func(c *Client, ps *pairedSessions, s *sessionRun, p *peer) { + if p.myVk != 0 { + return + } + if misbehavingID != [33]byte{} { + return + } + t.Logf("malicious peer %x: flipping SR bit", p.id[:]) + misbehavingID = *p.id + p.sr.DCMix[0][1][0] ^= 1 + }) } func TestDCDisruption(t *testing.T) { var misbehavingID identity - testDisruption(t, &misbehavingID, hookBeforePeerDCPublish, func(c *Client, s *sessionRun, p *peer) { - if p.myVk != 0 { - return - } - if misbehavingID != [33]byte{} { - return - } - t.Logf("malicious peer %x: flipping DC bit", p.id[:]) - misbehavingID = *p.id - p.dc.DCNet[0][1][0] ^= 1 - }) + testDisruption(t, &misbehavingID, hookBeforePeerDCPublish, + func(c *Client, ps *pairedSessions, s *sessionRun, p *peer) { + if p.myVk != 0 { + return + } + if misbehavingID != [33]byte{} { + return + } + t.Logf("malicious peer %x: flipping DC bit", p.id[:]) + misbehavingID = *p.id + p.dc.DCNet[0][1][0] ^= 1 + }) } diff --git a/mixing/mixclient/testhooks.go b/mixing/mixclient/testhooks.go index de19db94d3..d564e2e3ce 100644 --- a/mixing/mixclient/testhooks.go +++ b/mixing/mixclient/testhooks.go @@ -6,7 +6,7 @@ package mixclient type hook string -type hookFunc func(*Client, *sessionRun, *peer) +type hookFunc func(*Client, *pairedSessions, *sessionRun, *peer) const ( hookBeforeRun hook = "before run"
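Aside (illustrative, not part of this diff): the reworked
sendLocalPeerMsgs drops the per-peer callback in favor of a deadline
plus a bitmask naming which queued messages to broadcast, e.g.
msgFP|msgDC ahead of the DC receive deadline and msgCM ahead of the
confirmation deadline.  The flag declarations themselves do not appear
in these hunks, so the sketch below only suggests how such a mask might
be declared and consulted inside package mixclient; the names, values,
and queuedMessage helper are assumptions rather than the patch's actual
implementation.

// Assumed message-class flags; the real declarations live elsewhere in
// client.go and may differ in both names and values.
const (
	msgKE uint = 1 << iota
	msgCT
	msgSR
	msgFP
	msgDC
	msgCM
	msgRS
)

// queuedMessage is a hypothetical helper showing how a mask could
// select a peer's pending message of the requested class, if any.
func queuedMessage(p *peer, mask uint) mixing.Message {
	switch {
	case mask&msgKE != 0 && p.ke != nil:
		return p.ke
	case mask&msgCT != 0 && p.ct != nil:
		return p.ct
	case mask&msgSR != 0 && p.sr != nil:
		return p.sr
	case mask&msgFP != 0 && p.fp != nil:
		return p.fp
	case mask&msgDC != 0 && p.dc != nil:
		return p.dc
	case mask&msgCM != 0 && p.cm != nil:
		return p.cm
	case mask&msgRS != 0 && p.rs != nil:
		return p.rs
	}
	return nil
}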