From 0a77ce4d64b395c4a942f64aa56cb942a314ef87 Mon Sep 17 00:00:00 2001 From: Toby Date: Sat, 18 Nov 2023 16:19:08 -0800 Subject: [PATCH 1/2] feat: client handshake info --- app/cmd/client.go | 11 +- app/cmd/ping.go | 7 +- core/client/client.go | 32 ++++-- core/client/reconnect.go | 9 +- core/internal/integration_tests/close_test.go | 8 +- core/internal/integration_tests/smoke_test.go | 107 +++++++++++++++++- .../internal/integration_tests/stress_test.go | 4 +- .../integration_tests/trafficlogger_test.go | 4 +- 8 files changed, 148 insertions(+), 34 deletions(-) diff --git a/app/cmd/client.go b/app/cmd/client.go index 823249a6bb..8c593c082a 100644 --- a/app/cmd/client.go +++ b/app/cmd/client.go @@ -402,8 +402,8 @@ func runClient(cmd *cobra.Command, args []string) { logger.Fatal("failed to load client config", zap.Error(err)) } - c, err := client.NewReconnectableClient(hyConfig, func(c client.Client, count int) { - connectLog(count) + c, err := client.NewReconnectableClient(hyConfig, func(c client.Client, info *client.HandshakeInfo, count int) { + connectLog(info, count) // On the client side, we start checking for updates after we successfully connect // to the server, which, depending on whether lazy mode is enabled, may or may not // be immediately after the client starts. We don't want the update check request @@ -699,8 +699,11 @@ func (f *adaptiveConnFactory) New(addr net.Addr) (net.PacketConn, error) { } } -func connectLog(count int) { - logger.Info("connected to server", zap.Int("count", count)) +func connectLog(info *client.HandshakeInfo, count int) { + logger.Info("connected to server", + zap.Bool("udpEnabled", info.UDPEnabled), + zap.Uint64("tx", info.Tx), + zap.Int("count", count)) } type socks5Logger struct{} diff --git a/app/cmd/ping.go b/app/cmd/ping.go index ccaf870f81..856595b920 100644 --- a/app/cmd/ping.go +++ b/app/cmd/ping.go @@ -42,13 +42,16 @@ func runPing(cmd *cobra.Command, args []string) { logger.Fatal("failed to load client config", zap.Error(err)) } - c, err := client.NewClient(hyConfig) + c, info, err := client.NewClient(hyConfig) if err != nil { logger.Fatal("failed to initialize client", zap.Error(err)) } defer c.Close() + logger.Info("connected to server", + zap.Bool("udpEnabled", info.UDPEnabled), + zap.Uint64("tx", info.Tx)) - logger.Info("connecting", zap.String("address", addr)) + logger.Info("connecting", zap.String("addr", addr)) start := time.Now() conn, err := c.TCP(addr) if err != nil { diff --git a/core/client/client.go b/core/client/client.go index 535a83b84a..9f4d001278 100644 --- a/core/client/client.go +++ b/core/client/client.go @@ -34,17 +34,23 @@ type HyUDPConn interface { Close() error } -func NewClient(config *Config) (Client, error) { +type HandshakeInfo struct { + UDPEnabled bool + Tx uint64 // 0 if using BBR +} + +func NewClient(config *Config) (Client, *HandshakeInfo, error) { if err := config.verifyAndFill(); err != nil { - return nil, err + return nil, nil, err } c := &clientImpl{ config: config, } - if err := c.connect(); err != nil { - return nil, err + info, err := c.connect() + if err != nil { + return nil, nil, err } - return c, nil + return c, info, nil } type clientImpl struct { @@ -56,10 +62,10 @@ type clientImpl struct { udpSM *udpSessionManager } -func (c *clientImpl) connect() error { +func (c *clientImpl) connect() (*HandshakeInfo, error) { pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr) if err != nil { - return err + return nil, err } // Convert config to TLS config & QUIC config tlsConfig := &tls.Config{ @@ -113,22 +119,23 @@ func (c *clientImpl) connect() error { _ = conn.CloseWithError(closeErrCodeProtocolError, "") } _ = pktConn.Close() - return coreErrs.ConnectError{Err: err} + return nil, coreErrs.ConnectError{Err: err} } if resp.StatusCode != protocol.StatusAuthOK { _ = conn.CloseWithError(closeErrCodeProtocolError, "") _ = pktConn.Close() - return coreErrs.AuthError{StatusCode: resp.StatusCode} + return nil, coreErrs.AuthError{StatusCode: resp.StatusCode} } // Auth OK authResp := protocol.AuthResponseFromHeader(resp.Header) + var actualTx uint64 if authResp.RxAuto { // Server asks client to use bandwidth detection, // ignore local bandwidth config and use BBR congestion.UseBBR(conn) } else { // actualTx = min(serverRx, clientTx) - actualTx := authResp.Rx + actualTx = authResp.Rx if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx { // Server doesn't have a limit, or our clientTx is smaller than serverRx actualTx = c.config.BandwidthConfig.MaxTx @@ -147,7 +154,10 @@ func (c *clientImpl) connect() error { if authResp.UDPEnabled { c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn}) } - return nil + return &HandshakeInfo{ + UDPEnabled: authResp.UDPEnabled, + Tx: actualTx, + }, nil } // openStream wraps the stream with QStream, which handles Close() properly diff --git a/core/client/reconnect.go b/core/client/reconnect.go index 61f314af52..c9f5957271 100644 --- a/core/client/reconnect.go +++ b/core/client/reconnect.go @@ -12,13 +12,14 @@ import ( type reconnectableClientImpl struct { config *Config client Client + info *HandshakeInfo count int - connectedFunc func(Client, int) // called when successfully connected + connectedFunc func(Client, *HandshakeInfo, int) // called when successfully connected m sync.Mutex closed bool // permanent close } -func NewReconnectableClient(config *Config, connectedFunc func(Client, int), lazy bool) (Client, error) { +func NewReconnectableClient(config *Config, connectedFunc func(Client, *HandshakeInfo, int), lazy bool) (Client, error) { // Make sure we capture any error in config and return it here, // so that the caller doesn't have to wait until the first call // to TCP() or UDP() to get the error (when lazy is true). @@ -42,13 +43,13 @@ func (rc *reconnectableClientImpl) reconnect() error { _ = rc.client.Close() } var err error - rc.client, err = NewClient(rc.config) + rc.client, rc.info, err = NewClient(rc.config) if err != nil { return err } else { rc.count++ if rc.connectedFunc != nil { - rc.connectedFunc(rc, rc.count) + rc.connectedFunc(rc, rc.info, rc.count) } return nil } diff --git a/core/internal/integration_tests/close_test.go b/core/internal/integration_tests/close_test.go index 57b8ce0431..4160b3cd5b 100644 --- a/core/internal/integration_tests/close_test.go +++ b/core/internal/integration_tests/close_test.go @@ -34,7 +34,7 @@ func TestClientServerTCPClose(t *testing.T) { go s.Serve() // Create client - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, }) @@ -116,7 +116,7 @@ func TestClientServerUDPIdleTimeout(t *testing.T) { go s.Serve() // Create client - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, }) @@ -194,7 +194,7 @@ func TestClientServerClientShutdown(t *testing.T) { go s.Serve() // Create client - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, }) @@ -223,7 +223,7 @@ func TestClientServerServerShutdown(t *testing.T) { go s.Serve() // Create client - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, QUICConfig: client.QUICConfig{ diff --git a/core/internal/integration_tests/smoke_test.go b/core/internal/integration_tests/smoke_test.go index 146a8dd4ed..ab204cbf18 100644 --- a/core/internal/integration_tests/smoke_test.go +++ b/core/internal/integration_tests/smoke_test.go @@ -19,7 +19,7 @@ import ( // TestClientNoServer tests how the client handles a server address it cannot connect to. // NewClient should return a ConnectError. func TestClientNoServer(t *testing.T) { - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 55666}, }) assert.Nil(t, c) @@ -46,7 +46,7 @@ func TestClientServerBadAuth(t *testing.T) { go s.Serve() // Create client - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, Auth: "badpassword", TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, @@ -75,7 +75,7 @@ func TestClientServerUDPDisabled(t *testing.T) { go s.Serve() // Create client - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, }) @@ -113,7 +113,7 @@ func TestClientServerTCPEcho(t *testing.T) { go echoServer.Serve() // Create client - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, }) @@ -160,7 +160,7 @@ func TestClientServerUDPEcho(t *testing.T) { go echoServer.Serve() // Create client - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, }) @@ -181,3 +181,100 @@ func TestClientServerUDPEcho(t *testing.T) { assert.Equal(t, sData, rData) assert.Equal(t, echoAddr, rAddr) } + +// TestClientServerHandshakeInfo tests that the client returns the correct handshake info. +func TestClientServerHandshakeInfo(t *testing.T) { + // Create server 1, UDP enabled, unlimited bandwidth + udpConn, udpAddr, err := serverConn() + assert.NoError(t, err) + auth := mocks.NewMockAuthenticator(t) + auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody") + s, err := server.NewServer(&server.Config{ + TLSConfig: serverTLSConfig(), + Conn: udpConn, + Authenticator: auth, + }) + assert.NoError(t, err) + go s.Serve() + + // Create client 1, with specified tx bandwidth + c, info, err := client.NewClient(&client.Config{ + ServerAddr: udpAddr, + TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, + BandwidthConfig: client.BandwidthConfig{ + MaxTx: 123456, + }, + }) + assert.NoError(t, err) + assert.Equal(t, &client.HandshakeInfo{ + UDPEnabled: true, + Tx: 123456, + }, info) + + // Close server 1 and client 1 + _ = s.Close() + _ = c.Close() + + // Create server 2, UDP disabled, limited rx bandwidth + udpConn, udpAddr, err = serverConn() + assert.NoError(t, err) + s, err = server.NewServer(&server.Config{ + TLSConfig: serverTLSConfig(), + Conn: udpConn, + BandwidthConfig: server.BandwidthConfig{ + MaxRx: 100000, + }, + DisableUDP: true, + Authenticator: auth, + }) + assert.NoError(t, err) + go s.Serve() + + // Create client 2, with specified tx bandwidth + c, info, err = client.NewClient(&client.Config{ + ServerAddr: udpAddr, + TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, + BandwidthConfig: client.BandwidthConfig{ + MaxTx: 123456, + }, + }) + assert.NoError(t, err) + assert.Equal(t, &client.HandshakeInfo{ + UDPEnabled: false, + Tx: 100000, + }, info) + + // Close server 2 and client 2 + _ = s.Close() + _ = c.Close() + + // Create server 3, UDP enabled, ignore client bandwidth + udpConn, udpAddr, err = serverConn() + assert.NoError(t, err) + s, err = server.NewServer(&server.Config{ + TLSConfig: serverTLSConfig(), + Conn: udpConn, + IgnoreClientBandwidth: true, + Authenticator: auth, + }) + assert.NoError(t, err) + go s.Serve() + + // Create client 3, with specified tx bandwidth + c, info, err = client.NewClient(&client.Config{ + ServerAddr: udpAddr, + TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, + BandwidthConfig: client.BandwidthConfig{ + MaxTx: 123456, + }, + }) + assert.NoError(t, err) + assert.Equal(t, &client.HandshakeInfo{ + UDPEnabled: true, + Tx: 0, + }, info) + + // Close server 3 and client 3 + _ = s.Close() + _ = c.Close() +} diff --git a/core/internal/integration_tests/stress_test.go b/core/internal/integration_tests/stress_test.go index 2324a9e45b..f10ac3adfc 100644 --- a/core/internal/integration_tests/stress_test.go +++ b/core/internal/integration_tests/stress_test.go @@ -148,7 +148,7 @@ func TestClientServerTCPStress(t *testing.T) { go echoServer.Serve() // Create client - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, }) @@ -192,7 +192,7 @@ func TestClientServerUDPStress(t *testing.T) { go echoServer.Serve() // Create client - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, }) diff --git a/core/internal/integration_tests/trafficlogger_test.go b/core/internal/integration_tests/trafficlogger_test.go index 3aa41194cd..ff1d66ee81 100644 --- a/core/internal/integration_tests/trafficlogger_test.go +++ b/core/internal/integration_tests/trafficlogger_test.go @@ -35,7 +35,7 @@ func TestClientServerTrafficLoggerTCP(t *testing.T) { go s.Serve() // Create client - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, }) @@ -114,7 +114,7 @@ func TestClientServerTrafficLoggerUDP(t *testing.T) { go s.Serve() // Create client - c, err := client.NewClient(&client.Config{ + c, _, err := client.NewClient(&client.Config{ ServerAddr: udpAddr, TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, }) From faeef50fc00ac0512bd75942839875f89341db8a Mon Sep 17 00:00:00 2001 From: Toby Date: Sat, 18 Nov 2023 21:02:21 -0800 Subject: [PATCH 2/2] chore: use local var for info --- core/client/reconnect.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/client/reconnect.go b/core/client/reconnect.go index c9f5957271..9a94bd3d83 100644 --- a/core/client/reconnect.go +++ b/core/client/reconnect.go @@ -12,7 +12,6 @@ import ( type reconnectableClientImpl struct { config *Config client Client - info *HandshakeInfo count int connectedFunc func(Client, *HandshakeInfo, int) // called when successfully connected m sync.Mutex @@ -43,13 +42,14 @@ func (rc *reconnectableClientImpl) reconnect() error { _ = rc.client.Close() } var err error - rc.client, rc.info, err = NewClient(rc.config) + var info *HandshakeInfo + rc.client, info, err = NewClient(rc.config) if err != nil { return err } else { rc.count++ if rc.connectedFunc != nil { - rc.connectedFunc(rc, rc.info, rc.count) + rc.connectedFunc(rc, info, rc.count) } return nil }