diff --git a/cache.go b/cache.go index 49f03e2..a1b6881 100644 --- a/cache.go +++ b/cache.go @@ -44,7 +44,7 @@ func (f *fronted) prepopulateFronts(cacheFile string) { // update last succeeded status of masquerades based on cached values for _, fr := range f.fronts { for _, cf := range cachedFronts { - sameFront := cf.ProviderID == fr.getProviderID() && cf.Domain == fr.getDomain() && cf.IpAddress == fr.getIpAddress() + sameFront := cf.providerID == fr.getProviderID() && cf.Domain == fr.getDomain() && cf.IpAddress == fr.getIpAddress() cachedValueFresh := now.Sub(fr.lastSucceeded()) < f.maxAllowedCachedAge if sameFront && cachedValueFresh { fr.setLastSucceeded(cf.LastSucceeded) diff --git a/cache_test.go b/cache_test.go index 635c65f..a37bdb5 100644 --- a/cache_test.go +++ b/cache_test.go @@ -44,9 +44,9 @@ func TestCaching(t *testing.T) { } now := time.Now() - mb := &front{Masquerade: Masquerade{Domain: "b", IpAddress: "2"}, LastSucceeded: now, ProviderID: testProviderID} - mc := &front{Masquerade: Masquerade{Domain: "c", IpAddress: "3"}, LastSucceeded: now, ProviderID: ""} // defaulted - md := &front{Masquerade: Masquerade{Domain: "d", IpAddress: "4"}, LastSucceeded: now, ProviderID: "sadcloud"} // skipped + mb := &front{Masquerade: Masquerade{Domain: "b", IpAddress: "2"}, LastSucceeded: now, providerID: testProviderID} + mc := &front{Masquerade: Masquerade{Domain: "c", IpAddress: "3"}, LastSucceeded: now, providerID: ""} // defaulted + md := &front{Masquerade: Masquerade{Domain: "d", IpAddress: "4"}, LastSucceeded: now, providerID: "sadcloud"} // skipped f := makeFronted() @@ -80,7 +80,7 @@ func TestCaching(t *testing.T) { for i, expected := range []*front{mb, mc, md} { require.Equal(t, expected.Domain, masquerades[i].Domain, "Wrong masquerade at position %d", i) require.Equal(t, expected.IpAddress, masquerades[i].IpAddress, "Masquerade at position %d has wrong IpAddress", 0) - require.Equal(t, expected.ProviderID, masquerades[i].ProviderID, "Masquerade at position %d has wrong ProviderID", 0) + require.Equal(t, expected.providerID, masquerades[i].providerID, "Masquerade at position %d has wrong ProviderID", 0) require.Equal(t, now.Unix(), masquerades[i].LastSucceeded.Unix(), "Masquerade at position %d has wrong LastSucceeded", 0) } f.Close() diff --git a/connected_roundtripper.go b/connected_roundtripper.go index d831fb1..3cb80d4 100644 --- a/connected_roundtripper.go +++ b/connected_roundtripper.go @@ -2,10 +2,8 @@ package fronted import ( "fmt" - "io" "net" "net/http" - "net/url" "strings" "time" @@ -40,17 +38,13 @@ func (crt connectedRoundTripper) RoundTrip(req *http.Request) (*http.Response, e // so it is returned as good. crt.Conn.Close() crt.front.markWithResult(true) - err := fmt.Errorf("no domain fronting mapping for '%s'. Please add it to provider_map.yaml or equivalent for %s", - crt.front.getProviderID(), originHost) + err := fmt.Errorf("no domain fronting mapping for '%s'. Please add it to provider_map.yaml or equivalent", originHost) op.FailIf(err) return nil, err } - log.Debugf("Translated origin %s -> %s for provider %s...", originHost, frontedHost, crt.front.getProviderID()) + log.Debugf("Translated origin %s -> %s.", originHost, frontedHost) - reqi, err := withDomainFront(req, frontedHost, req.Body) - if err != nil { - return nil, op.FailIf(log.Errorf("Failed to copy http request with origin translated to %v?: %v", frontedHost, err)) - } + reqi := withDomainFront(req, frontedHost) disableKeepAlives := true if strings.EqualFold(reqi.Header.Get("Connection"), "upgrade") { disableKeepAlives = false @@ -79,56 +73,26 @@ func (crt connectedRoundTripper) RoundTrip(req *http.Request) (*http.Response, e // connectedConnHTTPTransport uses a preconnected connection to the CDN to make HTTP requests. // This uses the pre-established connection to the CDN on the fronting domain. -func connectedConnHTTPTransport(conn net.Conn, disableKeepAlives bool) http.RoundTripper { - return &connectedTransport{ - Transport: http.Transport{ - Dial: func(network, addr string) (net.Conn, error) { - return conn, nil - }, - TLSHandshakeTimeout: 20 * time.Second, - DisableKeepAlives: disableKeepAlives, - IdleConnTimeout: 70 * time.Second, +func connectedConnHTTPTransport(conn net.Conn, disableKeepAlives bool) *http.Transport { + return &http.Transport{ + Dial: func(network, addr string) (net.Conn, error) { + return conn, nil }, + TLSHandshakeTimeout: 20 * time.Second, + DisableKeepAlives: disableKeepAlives, + IdleConnTimeout: 70 * time.Second, } } -// connectedTransport is a wrapper struct enabling us to modify the protocol of outgoing -// requests to make them all HTTP instead of potentially HTTPS, which breaks our particular -// implemenation of direct domain fronting. -type connectedTransport struct { - http.Transport -} - -func (ct *connectedTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { - defer func(op ops.Op) { op.End() }(ops.Begin("direct_transport_roundtrip")) - +func withDomainFront(req *http.Request, frontedHost string) *http.Request { // The connection is already encrypted by domain fronting. We need to rewrite URLs starting // with "https://" to "http://", lest we get an error for doubling up on TLS. // The RoundTrip interface requires that we not modify the memory in the request, so we just // create a copy. - norm := new(http.Request) - *norm = *req // includes shallow copies of maps, but okay - norm.URL = new(url.URL) - *norm.URL = *req.URL - norm.URL.Scheme = "http" - return ct.Transport.RoundTrip(norm) -} + newReq := req.Clone(req.Context()) + newReq.URL.Scheme = "http" + newReq.URL.Host = frontedHost -func withDomainFront(req *http.Request, frontedHost string, body io.ReadCloser) (*http.Request, error) { - urlCopy := *req.URL - urlCopy.Host = frontedHost - r, err := http.NewRequestWithContext(req.Context(), req.Method, urlCopy.String(), body) - if err != nil { - return nil, err - } - - for k, vs := range req.Header { - if !strings.EqualFold(k, "Host") { - v := make([]string, len(vs)) - copy(v, vs) - r.Header[k] = v - } - } - return r, nil + return newReq } diff --git a/front.go b/front.go index 564e882..a9d5161 100644 --- a/front.go +++ b/front.go @@ -85,16 +85,15 @@ type front struct { Masquerade // lastSucceeded: the most recent time at which this Masquerade succeeded LastSucceeded time.Time - // id of DirectProvider that this masquerade is provided by - ProviderID string - mx sync.RWMutex - cacheDirty chan interface{} + providerID string + mx sync.RWMutex + cacheDirty chan interface{} } func newFront(m *Masquerade, providerID string, cacheDirty chan interface{}) Front { return &front{ Masquerade: *m, - ProviderID: providerID, + providerID: providerID, LastSucceeded: time.Time{}, cacheDirty: cacheDirty, } @@ -104,6 +103,9 @@ func (fr *front) dial(rootCAs *x509.CertPool, clientHelloID tls.ClientHelloID) ( ServerName: fr.Domain, RootCAs: rootCAs, } + + // Set a fairly aggressive dial timeout based on observed timeouts running + // pinger in censored regions. dialTimeout := 5 * time.Second addr := fr.IpAddress var sendServerNameExtension bool @@ -176,19 +178,19 @@ func doCheck(client *http.Client, method string, expectedStatus int, u string) b return true } -// getDomain implements MasqueradeInterface. +// getDomain implements Front. func (fr *front) getDomain() string { return fr.Domain } -// getIpAddress implements MasqueradeInterface. +// getIpAddress implements Front. func (fr *front) getIpAddress() string { return fr.IpAddress } -// getProviderID implements MasqueradeInterface. +// getProviderID implements Front. func (fr *front) getProviderID() string { - return fr.ProviderID + return fr.providerID } // MarshalJSON marshals masquerade into json