From 48427deafb440e8c56032d9d78f2f10c0bee3f02 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Wed, 5 Feb 2025 14:16:45 -0700 Subject: [PATCH] Added root CA config --- kindling.go | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/kindling.go b/kindling.go index c391a34..fedbf63 100644 --- a/kindling.go +++ b/kindling.go @@ -2,7 +2,9 @@ package kindling import ( "context" + "crypto/x509" "embed" + "encoding/pem" "fmt" "io" "log/slog" @@ -31,6 +33,7 @@ type httpDialer func(ctx context.Context, addr string) (http.RoundTripper, error type kindling struct { httpDialers []httpDialer logWriter io.Writer + rootCA string } // Make sure that kindling implements the Kindling interface. @@ -75,10 +78,10 @@ func WithDomainFronting(configURL, countryCode string) Option { } } -// WithDoHTunnel is a functional option that enables DNS over HTTPS (DoH) tunneling for the Kindling. -func WithDoHTunnel() Option { +// WithRootCA pins the root CA to use for TLS. +func WithRootCA(rootCA string) Option { return func(k *kindling) { - + k.rootCA = rootCA } } @@ -149,7 +152,7 @@ func (k *kindling) newSmartHTTPDialer(domains ...string) (httpDialer, error) { } return k.newTransportWithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) { return streamConn, nil - }), nil + }) }, nil } @@ -165,11 +168,11 @@ func (k *kindling) newSmartHTTPTransport(domains ...string) (*http.Transport, er return nil, fmt.Errorf("failed to dial stream: %v", err) } return streamConn, nil - }), nil + }) } -func (k *kindling) newTransportWithDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) *http.Transport { - return &http.Transport{ +func (k *kindling) newTransportWithDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) (*http.Transport, error) { + tr := &http.Transport{ DialContext: dialContext, ForceAttemptHTTP2: true, MaxIdleConns: 100, @@ -177,6 +180,19 @@ func (k *kindling) newTransportWithDialContext(dialContext func(ctx context.Cont TLSHandshakeTimeout: 20 * time.Second, ExpectContinueTimeout: 4 * time.Second, } + if k.rootCA != "" { + block, _ := pem.Decode([]byte(k.rootCA)) + if block == nil { + return nil, fmt.Errorf("failed to decode root CA PEM block") + } + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(block.Bytes) { + log.Error("Failed to append root CA to pool") + return nil, fmt.Errorf("failed to append root CA to pool") + } + tr.TLSClientConfig.RootCAs = certPool + } + return tr, nil } //go:embed smart_dialer_config.yml