From 836cb38a84a2229e34e2ca65f8d88c3c5d5ecd24 Mon Sep 17 00:00:00 2001 From: Yusuke Hata Date: Tue, 19 Jul 2022 19:06:43 +0900 Subject: [PATCH] retrieve behind data when reconnect --- repli/stream.go | 81 +++++++++++++++++++++++++++++++++++------ repli/stream_test.go | 3 -- repli_test.go | 85 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 156 insertions(+), 13 deletions(-) diff --git a/repli/stream.go b/repli/stream.go index 52b2767..0e5bdd5 100644 --- a/repli/stream.go +++ b/repli/stream.go @@ -152,6 +152,7 @@ type streamEmitter struct { releaseTTL time.Duration emitApplyCh chan emitApply done chan struct{} + closed bool src Source server *server.Server emitConn *nats.Conn @@ -320,10 +321,23 @@ func (e *streamEmitter) applyCurrentFileID(fileID datafile.FileID) { e.emitConn.Flush() } +func (e *streamEmitter) reconnectEmitter(conn *nats.Conn) { + e.logger.Printf("warn: reconnected emitter: %s", conn.ConnectedUrl()) +} + +func (e *streamEmitter) reconnectReply(conn *nats.Conn) { + e.logger.Printf("warn: reconnected replier: %s", conn.ConnectedUrl()) +} + func (e *streamEmitter) Start(src Source, bindIP string, bindPort int) error { e.mutex.Lock() defer e.mutex.Unlock() + if e.closed { + // maybe restart + e.done = make(chan struct{}) + } + go e.runReleaseLoop() go e.emitLoop() @@ -350,12 +364,12 @@ func (e *streamEmitter) Start(src Source, bindIP string, bindPort int) error { } natsUrl := fmt.Sprintf("nats://%s", svr.Addr().String()) - emitConn, err := conn(natsUrl, "emitter") + emitConn, err := conn(natsUrl, "emitter", e.reconnectEmitter) if err != nil { return errors.Wrapf(err, "nats emitter connect: %s", natsUrl) } - replyConn, err := conn(natsUrl, "reply") + replyConn, err := conn(natsUrl, "reply", e.reconnectReply) if err != nil { return errors.Wrapf(err, "nats reply connect: %s", natsUrl) } @@ -396,10 +410,18 @@ func (e *streamEmitter) Start(src Source, bindIP string, bindPort int) error { subFetchSize, subFetchData, } + e.closed = false return nil } func (e *streamEmitter) Stop() error { + e.mutex.Lock() + defer e.mutex.Unlock() + + if e.closed { + return nil + } + if e.subs != nil { for _, sub := range e.subs { sub.Unsubscribe() @@ -418,6 +440,7 @@ func (e *streamEmitter) Stop() error { if e.server != nil { e.server.Shutdown() } + e.closed = true return nil } @@ -638,11 +661,16 @@ func (e *streamEmitter) replyFetchData(conn *nats.Conn, src Source) nats.MsgHand func (e *streamEmitter) EmitInsert(filer indexer.Filer) error { e.mutex.RLock() src := e.src + closed := e.closed e.mutex.RUnlock() if src == nil { return errors.Errorf("maybe not Start") } + if closed { + // drop when closed, no emit to transmit differences when resuming server + return nil + } e.emitApplyCh <- emitApply{emit: emitInsert, insertFiler: filer} return nil @@ -650,11 +678,17 @@ func (e *streamEmitter) EmitInsert(filer indexer.Filer) error { func (e *streamEmitter) EmitDelete(key []byte) error { e.mutex.RLock() + src := e.src + closed := e.closed defer e.mutex.RUnlock() - if e.src == nil { + if src == nil { return errors.Errorf("maybe not Start") } + if closed { + // drop when closed, no emit to transmit differences when resuming server + return nil + } e.emitApplyCh <- emitApply{emit: emitDelete, deleteKey: key} return nil @@ -827,6 +861,7 @@ func NewStreamEmitter(ctx runtime.Context, logger *log.Logger, tempDir string, m releaseTTL: defaultReleaseTTL, emitApplyCh: make(chan emitApply, 1024), done: make(chan struct{}), + closed: false, src: nil, server: nil, emitConn: nil, @@ -856,34 +891,58 @@ type streamFetchDataEntry struct { release releaseFunc } +func (r *streamReciver) reconnect(conn *nats.Conn) { + r.logger.Printf("info: reconnected: %s", conn.ConnectedUrl()) + + r.mutex.Lock() + for _, sub := range r.subs { + sub.Unsubscribe() + } + r.subs = nil + r.doneBehind = false + r.mutex.Unlock() + + if err := r.recvStart(r.dst, conn); err != nil { + r.logger.Printf("error: reconnect recvStart() failure: %+v", err) + } +} + func (r *streamReciver) Start(dst Destination, serverIP string, serverPort int) error { natsUrl := fmt.Sprintf("nats://%s:%d", serverIP, serverPort) - client, err := conn(natsUrl, "client") + client, err := conn(natsUrl, "client", r.reconnect) if err != nil { return errors.Wrapf(err, "nats client connect: %s", natsUrl) } + if err := r.recvStart(dst, client); err != nil { + return errors.WithStack(err) + } + return nil +} + +func (r *streamReciver) recvStart(dst Destination, conn *nats.Conn) error { repliTemp, err := openTemporaryRepliData(r.ctx, r.tempDir) if err != nil { return errors.Wrap(err, "temporary repli data open") } defer repliTemp.Remove() - subRepli, err := client.Subscribe(SubjectRepli, r.recvRepliData(client, dst, repliTemp)) + subRepli, err := conn.Subscribe(SubjectRepli, r.recvRepliData(conn, dst, repliTemp)) if err != nil { return errors.Wrapf(err, "failed to subscribe %s", SubjectRepli) } - client.Flush() + conn.Flush() r.mergeWait.Add(1) - if err := r.requestBehindData(client, dst, repliTemp); err != nil { + if err := r.requestBehindData(conn, dst, repliTemp); err != nil { return errors.Wrap(err, "failed to get behind reqests") } r.mutex.Lock() defer r.mutex.Unlock() + r.dst = dst - r.client = client + r.client = conn r.subs = []*nats.Subscription{ subRepli, } @@ -906,6 +965,7 @@ func (r *streamReciver) Stop() error { r.client.Drain() r.client.Close() } + r.doneBehind = false return nil } @@ -1345,15 +1405,16 @@ func openTemporaryRepliData(ctx runtime.Context, tempDir string) (*temporaryRepl }, nil } -func conn(url string, name string) (*nats.Conn, error) { +func conn(url string, name string, reconnect nats.ConnHandler) (*nats.Conn, error) { return nats.Connect( url, nats.NoEcho(), nats.DontRandomize(), nats.Name(name), - nats.ReconnectJitter(100*time.Millisecond, 1000*time.Millisecond), + nats.ReconnectJitter(100*time.Millisecond, 300*time.Millisecond), nats.ReconnectWait(100*time.Millisecond), nats.MaxReconnects(-1), nats.PingInterval(10*time.Second), + nats.ReconnectHandler(reconnect), ) } diff --git a/repli/stream_test.go b/repli/stream_test.go index b554cec..e92e07c 100644 --- a/repli/stream_test.go +++ b/repli/stream_test.go @@ -2324,6 +2324,3 @@ func TestRepliTemporaryRepliData(t *testing.T) { } }) } - -func TestRepliStreamReconnect(t *testing.T) { -} diff --git a/repli_test.go b/repli_test.go index 9873dba..6653973 100644 --- a/repli_test.go +++ b/repli_test.go @@ -2,6 +2,7 @@ package bitcaskdb import ( "bytes" + "io" "os" "testing" "time" @@ -138,3 +139,87 @@ func TestBitcaskRepli(t *testing.T) { } }) } + +func TestBitcaskRepliReconnect(t *testing.T) { + srcdir, err := os.MkdirTemp("", "bitcask_src") + if err != nil { + t.Fatalf("no error %+v", err) + } + dstdir, err := os.MkdirTemp("", "bitcask_dst") + if err != nil { + t.Fatalf("no error %+v", err) + } + + srcdb, err := Open(srcdir, WithRepli("127.0.0.1", 4220)) + if err != nil { + t.Fatalf("no error %+v", err) + } + t.Cleanup(func() { + srcdb.Close() + }) + + dstdb, err := Open(dstdir, WithRepliClient("127.0.0.1", 4220)) + if err != nil { + t.Fatalf("no error %+v", err) + } + t.Cleanup(func() { + dstdb.Close() + }) + + t.Run("check_repli", func(tt *testing.T) { + if err := srcdb.Put([]byte("test1"), bytes.NewReader([]byte("value1"))); err != nil { + tt.Fatalf("no error %+v", err) + } + + time.Sleep(100 * time.Millisecond) + + r, err := dstdb.Get([]byte("test1")) + if err != nil { + tt.Fatalf("no error %+v", err) + } + defer r.Close() + + data, err := io.ReadAll(r) + if err != nil { + tt.Fatalf("no error %+v", err) + } + if bytes.Equal([]byte("value1"), data) != true { + tt.Errorf("expect:value1 actual:%s", data) + } + }) + + t.Run("stop_repli", func(tt *testing.T) { + // When stopped for some reason + if err := srcdb.repliEmit.Stop(); err != nil { + tt.Fatalf("no error %+v", err) + } + + if err := srcdb.Put([]byte("test2"), bytes.NewReader([]byte("value2"))); err != nil { + tt.Fatalf("no error %+v", err) + } + }) + + t.Run("restart_repli", func(tt *testing.T) { + // When recovery + if err := srcdb.repliEmit.Start(srcdb.repliSource(), srcdb.opt.RepliBindIP, srcdb.opt.RepliBindPort); err != nil { + tt.Fatalf("no error %+v", err) + } + + // reconnect wait + time.Sleep(1000 * time.Millisecond) + + r, err := dstdb.Get([]byte("test2")) + if err != nil { + tt.Fatalf("no error %+v", err) + } + defer r.Close() + + data, err := io.ReadAll(r) + if err != nil { + tt.Fatalf("no error %+v", err) + } + if bytes.Equal([]byte("value2"), data) != true { + tt.Errorf("expect:value2 actual:%s", data) + } + }) +}