Skip to content

Commit

Permalink
agent: add connect.proxy-url flag
Browse files Browse the repository at this point in the history
Adds support for the agent connecting to the Piko server via a HTTP
proxy with '--connect.proxy-url'.
  • Loading branch information
andydunstall committed Mar 8, 2025
1 parent c4d56d2 commit 76e8446
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 3 deletions.
19 changes: 18 additions & 1 deletion agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ type ConnectConfig struct {
Timeout time.Duration `json:"timeout" yaml:"timeout"`

TLS TLSConfig `json:"tls" yaml:"tls"`

// ProxyURL is the proxy URL to proxy the request from the agent to the
// Piko server (optional).
ProxyURL string `json:"proxy_url" yaml:"proxy_url"`
}

func (c *ConnectConfig) Validate() error {
Expand All @@ -255,6 +259,11 @@ func (c *ConnectConfig) Validate() error {
if _, err := url.Parse(c.URL); err != nil {
return fmt.Errorf("invalid url: %w", err)
}
if c.ProxyURL != "" {
if _, err := url.Parse(c.ProxyURL); err != nil {
return fmt.Errorf("invalid proxy url: %w", err)
}
}
if c.Timeout == 0 {
return fmt.Errorf("missing timeout")
}
Expand Down Expand Up @@ -287,7 +296,7 @@ Token is a token to authenticate with the Piko server.`,
"connect.tenant-id",
c.TenantID,
`
Tenant ID of the agent.
Tenant ID of the agent (optional).
Tenants can be used to configure different authentication mechanisms and keys
for different upstream services.`,
Expand All @@ -304,6 +313,14 @@ reconnect.`,
)

c.TLS.RegisterFlags(fs, "connect")

fs.StringVar(
&c.ProxyURL,
"connect.proxy-url",
c.ProxyURL,
`
The proxy URL to proxy the request from the agent to the Piko server (optional).`,
)
}

type ServerConfig struct {
Expand Down
10 changes: 10 additions & 0 deletions cli/agent/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,21 @@ func runAgent(conf *config.Config, logger log.Logger) error {
// Already verified in conf.Validate() so this shouldn't happen.
return fmt.Errorf("connect url: %w", err)
}
var proxyURL *url.URL
if conf.Connect.ProxyURL != "" {
connectProxyURL, err := url.Parse(conf.Connect.ProxyURL)
if err != nil {
// Already verified in conf.Validate() so this shouldn't happen.
return fmt.Errorf("connect proxy url: %w", err)
}
proxyURL = connectProxyURL
}
upstream := &client.Upstream{
URL: connectURL,
Token: conf.Connect.Token,
TenantID: conf.Connect.TenantID,
TLSConfig: connectTLSConfig,
ProxyURL: proxyURL,
Logger: logger.WithSubsystem("client"),
}

Expand Down
5 changes: 5 additions & 0 deletions client/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ type Upstream struct {
// If nil, the default configuration is used.
TLSConfig *tls.Config

// ProxyURL is the URL to proxy the request from the client to the Piko
// server (optional).
ProxyURL *url.URL

// MinReconnectBackoff is the minimum backoff when reconnecting.
//
// Defaults to 100ms.
Expand Down Expand Up @@ -120,6 +124,7 @@ func (u *Upstream) connect(ctx context.Context, endpointID string) (*yamux.Sessi
websocket.WithToken(u.Token),
websocket.WithTenantID(u.TenantID),
websocket.WithTLSConfig(u.TLSConfig),
websocket.WithProxyURL(u.ProxyURL),
)
if err == nil {
u.logger().Debug(
Expand Down
23 changes: 21 additions & 2 deletions pkg/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"net"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -51,6 +52,7 @@ type dialOptions struct {
token string
tenantID string
tlsConfig *tls.Config
proxyURL *url.URL
}

type DialOption interface {
Expand All @@ -77,6 +79,18 @@ func WithTenantID(tenantID string) DialOption {
return tenantIDOption(tenantID)
}

type proxyURLOption struct {
url *url.URL
}

func (o proxyURLOption) apply(opts *dialOptions) {
opts.proxyURL = o.url
}

func WithProxyURL(url *url.URL) DialOption {
return &proxyURLOption{url: url}
}

type tlsConfigOption struct {
TLSConfig *tls.Config
}
Expand Down Expand Up @@ -106,7 +120,7 @@ func New(wsConn *websocket.Conn) *Conn {
}
}

func Dial(ctx context.Context, url string, opts ...DialOption) (*Conn, error) {
func Dial(ctx context.Context, u string, opts ...DialOption) (*Conn, error) {
options := dialOptions{}
for _, o := range opts {
o.apply(&options)
Expand All @@ -115,6 +129,11 @@ func Dial(ctx context.Context, url string, opts ...DialOption) (*Conn, error) {
dialer := &websocket.Dialer{
HandshakeTimeout: 60 * time.Second,
}
if options.proxyURL != nil {
dialer.Proxy = func(*http.Request) (*url.URL, error) {
return options.proxyURL, nil
}
}

if options.tlsConfig != nil {
dialer.TLSClientConfig = options.tlsConfig
Expand All @@ -129,7 +148,7 @@ func Dial(ctx context.Context, url string, opts ...DialOption) (*Conn, error) {
}

wsConn, resp, err := dialer.DialContext(
ctx, url, header,
ctx, u, header,
)
if err == nil {
return New(wsConn), nil
Expand Down

0 comments on commit 76e8446

Please sign in to comment.