Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed unnecessary nested transport #59

Merged
merged 2 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (f *fronted) prepopulateFronts(cacheFile string) {
// update last succeeded status of masquerades based on cached values
for _, fr := range f.fronts {
for _, cf := range cachedFronts {
sameFront := cf.ProviderID == fr.getProviderID() && cf.Domain == fr.getDomain() && cf.IpAddress == fr.getIpAddress()
sameFront := cf.providerID == fr.getProviderID() && cf.Domain == fr.getDomain() && cf.IpAddress == fr.getIpAddress()
cachedValueFresh := now.Sub(fr.lastSucceeded()) < f.maxAllowedCachedAge
if sameFront && cachedValueFresh {
fr.setLastSucceeded(cf.LastSucceeded)
Expand Down
8 changes: 4 additions & 4 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ func TestCaching(t *testing.T) {
}

now := time.Now()
mb := &front{Masquerade: Masquerade{Domain: "b", IpAddress: "2"}, LastSucceeded: now, ProviderID: testProviderID}
mc := &front{Masquerade: Masquerade{Domain: "c", IpAddress: "3"}, LastSucceeded: now, ProviderID: ""} // defaulted
md := &front{Masquerade: Masquerade{Domain: "d", IpAddress: "4"}, LastSucceeded: now, ProviderID: "sadcloud"} // skipped
mb := &front{Masquerade: Masquerade{Domain: "b", IpAddress: "2"}, LastSucceeded: now, providerID: testProviderID}
mc := &front{Masquerade: Masquerade{Domain: "c", IpAddress: "3"}, LastSucceeded: now, providerID: ""} // defaulted
md := &front{Masquerade: Masquerade{Domain: "d", IpAddress: "4"}, LastSucceeded: now, providerID: "sadcloud"} // skipped

f := makeFronted()

Expand Down Expand Up @@ -80,7 +80,7 @@ func TestCaching(t *testing.T) {
for i, expected := range []*front{mb, mc, md} {
require.Equal(t, expected.Domain, masquerades[i].Domain, "Wrong masquerade at position %d", i)
require.Equal(t, expected.IpAddress, masquerades[i].IpAddress, "Masquerade at position %d has wrong IpAddress", 0)
require.Equal(t, expected.ProviderID, masquerades[i].ProviderID, "Masquerade at position %d has wrong ProviderID", 0)
require.Equal(t, expected.providerID, masquerades[i].providerID, "Masquerade at position %d has wrong ProviderID", 0)
require.Equal(t, now.Unix(), masquerades[i].LastSucceeded.Unix(), "Masquerade at position %d has wrong LastSucceeded", 0)
}
f.Close()
Expand Down
66 changes: 15 additions & 51 deletions connected_roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ package fronted

import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -40,17 +38,13 @@ func (crt connectedRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
// so it is returned as good.
crt.Conn.Close()
crt.front.markWithResult(true)
err := fmt.Errorf("no domain fronting mapping for '%s'. Please add it to provider_map.yaml or equivalent for %s",
crt.front.getProviderID(), originHost)
err := fmt.Errorf("no domain fronting mapping for '%s'. Please add it to provider_map.yaml or equivalent", originHost)
op.FailIf(err)
return nil, err
}
log.Debugf("Translated origin %s -> %s for provider %s...", originHost, frontedHost, crt.front.getProviderID())
log.Debugf("Translated origin %s -> %s.", originHost, frontedHost)

reqi, err := withDomainFront(req, frontedHost, req.Body)
if err != nil {
return nil, op.FailIf(log.Errorf("Failed to copy http request with origin translated to %v?: %v", frontedHost, err))
}
reqi := withDomainFront(req, frontedHost)
disableKeepAlives := true
if strings.EqualFold(reqi.Header.Get("Connection"), "upgrade") {
disableKeepAlives = false
Expand Down Expand Up @@ -79,56 +73,26 @@ func (crt connectedRoundTripper) RoundTrip(req *http.Request) (*http.Response, e

// connectedConnHTTPTransport uses a preconnected connection to the CDN to make HTTP requests.
// This uses the pre-established connection to the CDN on the fronting domain.
func connectedConnHTTPTransport(conn net.Conn, disableKeepAlives bool) http.RoundTripper {
return &connectedTransport{
Transport: http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
return conn, nil
},
TLSHandshakeTimeout: 20 * time.Second,
DisableKeepAlives: disableKeepAlives,
IdleConnTimeout: 70 * time.Second,
func connectedConnHTTPTransport(conn net.Conn, disableKeepAlives bool) *http.Transport {
return &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
return conn, nil
},
TLSHandshakeTimeout: 20 * time.Second,
DisableKeepAlives: disableKeepAlives,
IdleConnTimeout: 70 * time.Second,
}
}

// connectedTransport is a wrapper struct enabling us to modify the protocol of outgoing
// requests to make them all HTTP instead of potentially HTTPS, which breaks our particular
// implemenation of direct domain fronting.
type connectedTransport struct {
http.Transport
}

func (ct *connectedTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
defer func(op ops.Op) { op.End() }(ops.Begin("direct_transport_roundtrip"))

func withDomainFront(req *http.Request, frontedHost string) *http.Request {
// The connection is already encrypted by domain fronting. We need to rewrite URLs starting
// with "https://" to "http://", lest we get an error for doubling up on TLS.

// The RoundTrip interface requires that we not modify the memory in the request, so we just
// create a copy.
norm := new(http.Request)
*norm = *req // includes shallow copies of maps, but okay
norm.URL = new(url.URL)
*norm.URL = *req.URL
norm.URL.Scheme = "http"
return ct.Transport.RoundTrip(norm)
}
newReq := req.Clone(req.Context())
newReq.URL.Scheme = "http"
newReq.URL.Host = frontedHost

func withDomainFront(req *http.Request, frontedHost string, body io.ReadCloser) (*http.Request, error) {
urlCopy := *req.URL
urlCopy.Host = frontedHost
r, err := http.NewRequestWithContext(req.Context(), req.Method, urlCopy.String(), body)
if err != nil {
return nil, err
}

for k, vs := range req.Header {
if !strings.EqualFold(k, "Host") {
v := make([]string, len(vs))
copy(v, vs)
r.Header[k] = v
}
}
return r, nil
return newReq
}
20 changes: 11 additions & 9 deletions front.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,15 @@ type front struct {
Masquerade
// lastSucceeded: the most recent time at which this Masquerade succeeded
LastSucceeded time.Time
// id of DirectProvider that this masquerade is provided by
ProviderID string
mx sync.RWMutex
cacheDirty chan interface{}
providerID string
mx sync.RWMutex
cacheDirty chan interface{}
}

func newFront(m *Masquerade, providerID string, cacheDirty chan interface{}) Front {
return &front{
Masquerade: *m,
ProviderID: providerID,
providerID: providerID,
LastSucceeded: time.Time{},
cacheDirty: cacheDirty,
}
Expand All @@ -104,6 +103,9 @@ func (fr *front) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) (
ServerName: fr.Domain,
RootCAs: rootCAs,
}

// Set a fairly aggressive dial timeout based on observed timeouts running
// pinger in censored regions.
dialTimeout := 5 * time.Second
addr := fr.IpAddress
var sendServerNameExtension bool
Expand Down Expand Up @@ -176,19 +178,19 @@ func doCheck(client *http.Client, method string, expectedStatus int, u string) b
return true
}

// getDomain implements MasqueradeInterface.
// getDomain implements Front.
func (fr *front) getDomain() string {
return fr.Domain
}

// getIpAddress implements MasqueradeInterface.
// getIpAddress implements Front.
func (fr *front) getIpAddress() string {
return fr.IpAddress
}

// getProviderID implements MasqueradeInterface.
// getProviderID implements Front.
func (fr *front) getProviderID() string {
return fr.ProviderID
return fr.providerID
}

// MarshalJSON marshals masquerade into json
Expand Down
Loading