From 76e8446a26e23c85101c5cbf2355bf907265f010 Mon Sep 17 00:00:00 2001 From: Andrew Dunstall Date: Sat, 8 Mar 2025 11:57:20 +0000 Subject: [PATCH] agent: add connect.proxy-url flag Adds support for the agent connecting to the Piko server via a HTTP proxy with '--connect.proxy-url'. --- agent/config/config.go | 19 ++++++++++++++++++- cli/agent/command.go | 10 ++++++++++ client/upstream.go | 5 +++++ pkg/websocket/conn.go | 23 +++++++++++++++++++++-- 4 files changed, 54 insertions(+), 3 deletions(-) diff --git a/agent/config/config.go b/agent/config/config.go index 675e2481..c56fe217 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -246,6 +246,10 @@ type ConnectConfig struct { Timeout time.Duration `json:"timeout" yaml:"timeout"` TLS TLSConfig `json:"tls" yaml:"tls"` + + // ProxyURL is the proxy URL to proxy the request from the agent to the + // Piko server (optional). + ProxyURL string `json:"proxy_url" yaml:"proxy_url"` } func (c *ConnectConfig) Validate() error { @@ -255,6 +259,11 @@ func (c *ConnectConfig) Validate() error { if _, err := url.Parse(c.URL); err != nil { return fmt.Errorf("invalid url: %w", err) } + if c.ProxyURL != "" { + if _, err := url.Parse(c.ProxyURL); err != nil { + return fmt.Errorf("invalid proxy url: %w", err) + } + } if c.Timeout == 0 { return fmt.Errorf("missing timeout") } @@ -287,7 +296,7 @@ Token is a token to authenticate with the Piko server.`, "connect.tenant-id", c.TenantID, ` -Tenant ID of the agent. +Tenant ID of the agent (optional). Tenants can be used to configure different authentication mechanisms and keys for different upstream services.`, @@ -304,6 +313,14 @@ reconnect.`, ) c.TLS.RegisterFlags(fs, "connect") + + fs.StringVar( + &c.ProxyURL, + "connect.proxy-url", + c.ProxyURL, + ` +The proxy URL to proxy the request from the agent to the Piko server (optional).`, + ) } type ServerConfig struct { diff --git a/cli/agent/command.go b/cli/agent/command.go index d61055ba..63988c14 100644 --- a/cli/agent/command.go +++ b/cli/agent/command.go @@ -116,11 +116,21 @@ func runAgent(conf *config.Config, logger log.Logger) error { // Already verified in conf.Validate() so this shouldn't happen. return fmt.Errorf("connect url: %w", err) } + var proxyURL *url.URL + if conf.Connect.ProxyURL != "" { + connectProxyURL, err := url.Parse(conf.Connect.ProxyURL) + if err != nil { + // Already verified in conf.Validate() so this shouldn't happen. + return fmt.Errorf("connect proxy url: %w", err) + } + proxyURL = connectProxyURL + } upstream := &client.Upstream{ URL: connectURL, Token: conf.Connect.Token, TenantID: conf.Connect.TenantID, TLSConfig: connectTLSConfig, + ProxyURL: proxyURL, Logger: logger.WithSubsystem("client"), } diff --git a/client/upstream.go b/client/upstream.go index d6e685b3..1014f1da 100644 --- a/client/upstream.go +++ b/client/upstream.go @@ -54,6 +54,10 @@ type Upstream struct { // If nil, the default configuration is used. TLSConfig *tls.Config + // ProxyURL is the URL to proxy the request from the client to the Piko + // server (optional). + ProxyURL *url.URL + // MinReconnectBackoff is the minimum backoff when reconnecting. // // Defaults to 100ms. @@ -120,6 +124,7 @@ func (u *Upstream) connect(ctx context.Context, endpointID string) (*yamux.Sessi websocket.WithToken(u.Token), websocket.WithTenantID(u.TenantID), websocket.WithTLSConfig(u.TLSConfig), + websocket.WithProxyURL(u.ProxyURL), ) if err == nil { u.logger().Debug( diff --git a/pkg/websocket/conn.go b/pkg/websocket/conn.go index a053318f..e1f25ad1 100644 --- a/pkg/websocket/conn.go +++ b/pkg/websocket/conn.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/http" + "net/url" "strings" "time" @@ -51,6 +52,7 @@ type dialOptions struct { token string tenantID string tlsConfig *tls.Config + proxyURL *url.URL } type DialOption interface { @@ -77,6 +79,18 @@ func WithTenantID(tenantID string) DialOption { return tenantIDOption(tenantID) } +type proxyURLOption struct { + url *url.URL +} + +func (o proxyURLOption) apply(opts *dialOptions) { + opts.proxyURL = o.url +} + +func WithProxyURL(url *url.URL) DialOption { + return &proxyURLOption{url: url} +} + type tlsConfigOption struct { TLSConfig *tls.Config } @@ -106,7 +120,7 @@ func New(wsConn *websocket.Conn) *Conn { } } -func Dial(ctx context.Context, url string, opts ...DialOption) (*Conn, error) { +func Dial(ctx context.Context, u string, opts ...DialOption) (*Conn, error) { options := dialOptions{} for _, o := range opts { o.apply(&options) @@ -115,6 +129,11 @@ func Dial(ctx context.Context, url string, opts ...DialOption) (*Conn, error) { dialer := &websocket.Dialer{ HandshakeTimeout: 60 * time.Second, } + if options.proxyURL != nil { + dialer.Proxy = func(*http.Request) (*url.URL, error) { + return options.proxyURL, nil + } + } if options.tlsConfig != nil { dialer.TLSClientConfig = options.tlsConfig @@ -129,7 +148,7 @@ func Dial(ctx context.Context, url string, opts ...DialOption) (*Conn, error) { } wsConn, resp, err := dialer.DialContext( - ctx, url, header, + ctx, u, header, ) if err == nil { return New(wsConn), nil