Skip to content

Commit

Permalink
Pulled out peer networking
Browse files Browse the repository at this point in the history
Signed-off-by: Jimmy Moore <[email protected]>
  • Loading branch information
jimmyaxod committed Jan 20, 2025
1 parent a37a210 commit e050198
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 96 deletions.
1 change: 1 addition & 0 deletions cmd/drafter-peer/devices.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func decodeDevices(data string) ([]CompositeDevices, error) {
err := json.Unmarshal([]byte(data), &devices)
return devices, err
}

func getDefaultDevices() string {
defaultDevices, err := json.Marshal([]CompositeDevices{
{
Expand Down
123 changes: 27 additions & 96 deletions cmd/drafter-peer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@ package main

import (
"context"
"errors"
"flag"
"io"
"net"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strings"
"sync"
"time"

"github.com/loopholelabs/drafter/pkg/common"
Expand All @@ -31,10 +28,6 @@ func main() {

// General flags
rawDevices := flag.String("devices", getDefaultDevices(), "Devices configuration")
devices, err := decodeDevices(*rawDevices)
if err != nil {
panic(err)
}

raddr := flag.String("raddr", "localhost:1337", "Remote address to connect to (leave empty to disable)")
laddr := flag.String("laddr", "localhost:1337", "Local address to listen on (leave empty to disable)")
Expand All @@ -61,9 +54,14 @@ func main() {

flag.Parse()

devices, err := decodeDevices(*rawDevices)
if err != nil {
panic(err)
}

// FIXME: Allow tweak from cmdline
log := logging.New(logging.Zerolog, "drafter", os.Stderr)
log.SetLevel(types.TraceLevel)
log.SetLevel(types.DebugLevel)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -111,23 +109,6 @@ func main() {
panic(err)
}

var (
readers []io.Reader
writers []io.Writer
)
if strings.TrimSpace(*raddr) != "" {
conn, err := (&net.Dialer{}).DialContext(goroutineManager.Context(), "tcp", *raddr)
if err != nil {
panic(err)
}
defer conn.Close()

log.Info().Str("remote", conn.RemoteAddr().String()).Msg("Migrating from")

readers = []io.Reader{conn}
writers = []io.Writer{conn}
}

p, err := peer.StartPeer[struct{}, ipc.AgentServerRemote[struct{}]](
goroutineManager.Context(),
context.Background(), // Never give up on rescue operations
Expand Down Expand Up @@ -186,6 +167,21 @@ func main() {
})
}

var readers []io.Reader
var writers []io.Writer
var closer io.Closer
var remoteAddr string

if *raddr != "" {
closer, readers, writers, remoteAddr, err = connectAddr(goroutineManager.Context(), *raddr)
if err != nil {
panic(err)
}
defer closer.Close()

log.Info().Str("remote", remoteAddr).Msg("Migrating from")
}

migratedPeer, err := p.MigrateFrom(
goroutineManager.Context(),
migrateFromDevices,
Expand Down Expand Up @@ -303,76 +299,11 @@ func main() {
}
}

var (
closeLock sync.Mutex
closed bool
)
lis, err := net.Listen("tcp", *laddr)
if err != nil {
panic(err)
}
defer func() {
defer goroutineManager.CreateForegroundPanicCollector()()

closeLock.Lock()
closed = true
closeLock.Unlock()

if err := lis.Close(); err != nil {
panic(err)
}
}()

log.Info().Str("addr", lis.Addr().String()).Msg("Listening for connections")

var (
ready = make(chan struct{})
signalReady = sync.OnceFunc(func() {
close(ready) // We can safely close() this channel since the caller only runs once/is `sync.OnceFunc`d
})
)

var conn net.Conn
goroutineManager.StartForegroundGoroutine(func(_ context.Context) {
conn, err = lis.Accept()
if err != nil {
closeLock.Lock()
defer closeLock.Unlock()

if closed && errors.Is(err, net.ErrClosed) { // Don't treat closed errors as errors if we closed the connection
if err := goroutineManager.Context().Err(); err != nil {
panic(err)
}

return
}
panic(err)
}
signalReady()
})

bubbleSignals = true

select {
case <-goroutineManager.Context().Done():
return

case <-done:
before = time.Now()
if err := resumedPeer.SuspendAndCloseAgentServer(goroutineManager.Context(), *resumeTimeout); err != nil {
panic(err)
}

log.Info().Int64("ms", time.Since(before).Milliseconds()).Msg("Suspend. Shutting down.")
return

case <-ready:
break
}

defer conn.Close()
log.Info().Str("addr", *laddr).Msg("Listening for connections")
closer, readers, writers, remoteAddr, err = listenAddr(goroutineManager.Context(), *laddr)
defer closer.Close()

log.Info().Str("addr", conn.RemoteAddr().String()).Msg("Migrating to")
log.Info().Str("addr", remoteAddr).Msg("Migrating to")

migrateToDevices := []common.MigrateToDevice{}
for _, device := range devices {
Expand All @@ -395,8 +326,8 @@ func main() {
migrateToDevices,
*resumeTimeout,
*concurrency,
[]io.Reader{conn},
[]io.Writer{conn},
readers,
writers,
peer.MigrateToHooks{
OnBeforeSuspend: func() {
before = time.Now()
Expand Down
34 changes: 34 additions & 0 deletions cmd/drafter-peer/net.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package main

import (
"context"
"io"
"net"
)

func connectAddr(ctx context.Context, addr string) (io.Closer, []io.Reader, []io.Writer, string, error) {
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr)
if err != nil {
return nil, nil, nil, "", err
}
readers := []io.Reader{conn}
writers := []io.Writer{conn}
return conn, readers, writers, conn.RemoteAddr().String(), nil
}

func listenAddr(ctx context.Context, addr string) (io.Closer, []io.Reader, []io.Writer, string, error) {
lis, err := net.Listen("tcp", addr)
if err != nil {
return nil, nil, nil, "", err
}
defer lis.Close()

conn, err := lis.Accept()
if err != nil {
return nil, nil, nil, "", err
}

readers := []io.Reader{conn}
writers := []io.Writer{conn}
return conn, readers, writers, conn.RemoteAddr().String(), nil
}

0 comments on commit e050198

Please sign in to comment.