Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

agent: add connect.proxy-url flag #237

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ type ConnectConfig struct {
Timeout time.Duration `json:"timeout" yaml:"timeout"`

TLS TLSConfig `json:"tls" yaml:"tls"`

// ProxyURL is the proxy URL to proxy the request from the agent to the
// Piko server (optional).
ProxyURL string `json:"proxy_url" yaml:"proxy_url"`
}

func (c *ConnectConfig) Validate() error {
Expand All @@ -255,6 +259,11 @@ func (c *ConnectConfig) Validate() error {
if _, err := url.Parse(c.URL); err != nil {
return fmt.Errorf("invalid url: %w", err)
}
if c.ProxyURL != "" {
if _, err := url.Parse(c.ProxyURL); err != nil {
return fmt.Errorf("invalid proxy url: %w", err)
}
}
if c.Timeout == 0 {
return fmt.Errorf("missing timeout")
}
Expand Down Expand Up @@ -287,7 +296,7 @@ Token is a token to authenticate with the Piko server.`,
"connect.tenant-id",
c.TenantID,
`
Tenant ID of the agent.
Tenant ID of the agent (optional).
Tenants can be used to configure different authentication mechanisms and keys
for different upstream services.`,
Expand All @@ -304,6 +313,14 @@ reconnect.`,
)

c.TLS.RegisterFlags(fs, "connect")

fs.StringVar(
&c.ProxyURL,
"connect.proxy-url",
c.ProxyURL,
`
The proxy URL to proxy the request from the agent to the Piko server (optional).`,
)
}

type ServerConfig struct {
Expand Down
10 changes: 10 additions & 0 deletions cli/agent/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,21 @@ func runAgent(conf *config.Config, logger log.Logger) error {
// Already verified in conf.Validate() so this shouldn't happen.
return fmt.Errorf("connect url: %w", err)
}
var proxyURL *url.URL
if conf.Connect.ProxyURL != "" {
connectProxyURL, err := url.Parse(conf.Connect.ProxyURL)
if err != nil {
// Already verified in conf.Validate() so this shouldn't happen.
return fmt.Errorf("connect proxy url: %w", err)
}
proxyURL = connectProxyURL
}
upstream := &client.Upstream{
URL: connectURL,
Token: conf.Connect.Token,
TenantID: conf.Connect.TenantID,
TLSConfig: connectTLSConfig,
ProxyURL: proxyURL,
Logger: logger.WithSubsystem("client"),
}

Expand Down
5 changes: 5 additions & 0 deletions client/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ type Upstream struct {
// If nil, the default configuration is used.
TLSConfig *tls.Config

// ProxyURL is the URL to proxy the request from the client to the Piko
// server (optional).
ProxyURL *url.URL

// MinReconnectBackoff is the minimum backoff when reconnecting.
//
// Defaults to 100ms.
Expand Down Expand Up @@ -120,6 +124,7 @@ func (u *Upstream) connect(ctx context.Context, endpointID string) (*yamux.Sessi
websocket.WithToken(u.Token),
websocket.WithTenantID(u.TenantID),
websocket.WithTLSConfig(u.TLSConfig),
websocket.WithProxyURL(u.ProxyURL),
)
if err == nil {
u.logger().Debug(
Expand Down
23 changes: 21 additions & 2 deletions pkg/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"net"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -51,6 +52,7 @@ type dialOptions struct {
token string
tenantID string
tlsConfig *tls.Config
proxyURL *url.URL
}

type DialOption interface {
Expand All @@ -77,6 +79,18 @@ func WithTenantID(tenantID string) DialOption {
return tenantIDOption(tenantID)
}

type proxyURLOption struct {
url *url.URL
}

func (o proxyURLOption) apply(opts *dialOptions) {
opts.proxyURL = o.url
}

func WithProxyURL(url *url.URL) DialOption {
return &proxyURLOption{url: url}
}

type tlsConfigOption struct {
TLSConfig *tls.Config
}
Expand Down Expand Up @@ -106,7 +120,7 @@ func New(wsConn *websocket.Conn) *Conn {
}
}

func Dial(ctx context.Context, url string, opts ...DialOption) (*Conn, error) {
func Dial(ctx context.Context, u string, opts ...DialOption) (*Conn, error) {
options := dialOptions{}
for _, o := range opts {
o.apply(&options)
Expand All @@ -115,6 +129,11 @@ func Dial(ctx context.Context, url string, opts ...DialOption) (*Conn, error) {
dialer := &websocket.Dialer{
HandshakeTimeout: 60 * time.Second,
}
if options.proxyURL != nil {
dialer.Proxy = func(*http.Request) (*url.URL, error) {
return options.proxyURL, nil
}
}

if options.tlsConfig != nil {
dialer.TLSClientConfig = options.tlsConfig
Expand All @@ -129,7 +148,7 @@ func Dial(ctx context.Context, url string, opts ...DialOption) (*Conn, error) {
}

wsConn, resp, err := dialer.DialContext(
ctx, url, header,
ctx, u, header,
)
if err == nil {
return New(wsConn), nil
Expand Down
Loading