diff --git a/src/cmd/cli/command/version.go b/src/cmd/cli/command/version.go index db6235729..fbca10adc 100644 --- a/src/cmd/cli/command/version.go +++ b/src/cmd/cli/command/version.go @@ -4,14 +4,12 @@ import ( "context" "encoding/json" "errors" - "net/http" "strings" + "github.com/DefangLabs/defang/src/pkg/http" "golang.org/x/mod/semver" ) -var httpClient = http.DefaultClient - func isNewer(current, comparand string) bool { version, ok := normalizeVersion(current) if !ok { @@ -38,16 +36,12 @@ func GetCurrentVersion() string { } func GetLatestVersion(ctx context.Context) (string, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/repos/DefangLabs/defang/releases/latest", nil) - if err != nil { - return "", err - } - resp, err := httpClient.Do(req) + resp, err := http.GetWithContext(ctx, "https://api.github.com/repos/DefangLabs/defang/releases/latest") if err != nil { return "", err } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { + if resp.StatusCode != 200 { // The primary rate limit for unauthenticated requests is 60 requests per hour, per IP. return "", errors.New(resp.Status) } diff --git a/src/cmd/cli/command/version_test.go b/src/cmd/cli/command/version_test.go index f072602a7..628103818 100644 --- a/src/cmd/cli/command/version_test.go +++ b/src/cmd/cli/command/version_test.go @@ -6,6 +6,8 @@ import ( "net/http" "net/http/httptest" "testing" + + ourHttp "github.com/DefangLabs/defang/src/pkg/http" ) func TestIsNewer(t *testing.T) { @@ -75,8 +77,8 @@ func TestGetLatestVersion(t *testing.T) { rec.Header().Add("Content-Type", "application/json") rec.WriteString(fmt.Sprintf(`{"tag_name":"%v"}`, version)) - httpClient = &http.Client{Transport: &mockRoundTripper{ - method: http.MethodGet, + ourHttp.DefaultClient = &http.Client{Transport: &mockRoundTripper{ + method: "GET", url: "https://api.github.com/repos/DefangLabs/defang/releases/latest", resp: rec.Result(), }} diff --git a/src/go.mod b/src/go.mod index 39bd8accf..2a9a2eebf 100644 --- a/src/go.mod +++ b/src/go.mod @@ -21,6 +21,7 @@ require ( github.com/digitalocean/godo v1.111.0 github.com/docker/docker v25.0.5+incompatible github.com/google/uuid v1.6.0 + github.com/hashicorp/go-retryablehttp v0.7.7 github.com/miekg/dns v1.1.59 github.com/moby/patternmatcher v0.6.0 github.com/muesli/termenv v0.15.2 @@ -43,7 +44,6 @@ require ( github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect - github.com/hashicorp/go-retryablehttp v0.7.7 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect github.com/rivo/uniseg v0.2.0 // indirect diff --git a/src/pkg/http/client.go b/src/pkg/http/client.go new file mode 100644 index 000000000..0b5362fb1 --- /dev/null +++ b/src/pkg/http/client.go @@ -0,0 +1,20 @@ +package http + +import ( + "github.com/DefangLabs/defang/src/pkg/term" + "github.com/hashicorp/go-retryablehttp" +) + +var DefaultClient = newClient().StandardClient() + +type termLogger struct{} + +func (termLogger) Printf(format string, args ...interface{}) { + term.Debugf(format, args...) +} + +func newClient() *retryablehttp.Client { + c := retryablehttp.NewClient() // default client retries 4 times: 1+2+4+8 = 15s max + c.Logger = termLogger{} + return c +} diff --git a/src/pkg/http/get.go b/src/pkg/http/get.go index 500e26216..b701066b0 100644 --- a/src/pkg/http/get.go +++ b/src/pkg/http/get.go @@ -8,20 +8,20 @@ import ( type Header = http.Header func GetWithContext(ctx context.Context, url string) (*http.Response, error) { - hreq, err := http.NewRequestWithContext(ctx, "GET", url, nil) + hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } - return http.DefaultClient.Do(hreq) + return DefaultClient.Do(hreq) } func GetWithHeader(ctx context.Context, url string, header http.Header) (*http.Response, error) { - hreq, err := http.NewRequestWithContext(ctx, "GET", url, nil) + hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } hreq.Header = header - return http.DefaultClient.Do(hreq) + return DefaultClient.Do(hreq) } func GetWithAuth(ctx context.Context, url, auth string) (*http.Response, error) { diff --git a/src/pkg/http/post.go b/src/pkg/http/post.go index 0284f822b..d229df4f4 100644 --- a/src/pkg/http/post.go +++ b/src/pkg/http/post.go @@ -3,13 +3,12 @@ package http import ( "fmt" "io" - "net/http" "net/url" ) // PostForValues issues a POST to the specified URL and returns the response body as url.Values. func PostForValues(_url, contentType string, body io.Reader) (url.Values, error) { - resp, err := http.Post(_url, contentType, body) + resp, err := DefaultClient.Post(_url, contentType, body) if err != nil { return nil, err } diff --git a/src/pkg/http/put.go b/src/pkg/http/put.go index 51aa99475..8304b72b4 100644 --- a/src/pkg/http/put.go +++ b/src/pkg/http/put.go @@ -4,7 +4,6 @@ import ( "context" "io" "net/http" - "net/url" ) // Put issues a PUT to the specified URL. @@ -19,19 +18,10 @@ import ( // See the Client.Do method documentation for details on how redirects // are handled. func Put(ctx context.Context, url string, contentType string, body io.Reader) (*http.Response, error) { - req, err := http.NewRequestWithContext(ctx, "PUT", url, body) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, body) if err != nil { return nil, err } req.Header.Set("Content-Type", contentType) - return http.DefaultClient.Do(req) -} - -func RemoveQueryParam(qurl string) string { - u, err := url.Parse(qurl) - if err != nil { - return qurl - } - u.RawQuery = "" - return u.String() + return DefaultClient.Do(req) } diff --git a/src/pkg/http/put_test.go b/src/pkg/http/put_test.go index da56cf021..5787cacf4 100644 --- a/src/pkg/http/put_test.go +++ b/src/pkg/http/put_test.go @@ -1,12 +1,37 @@ package http -import "testing" - -func TestRemoveQueryParam(t *testing.T) { - url := "https://example.com/foo?bar=baz" - expected := "https://example.com/foo" - actual := RemoveQueryParam(url) - if actual != expected { - t.Errorf("expected %q, got %q", expected, actual) +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestPutRetries(t *testing.T) { + const body = "test" + calls := 0 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + if calls < 3 { + http.Error(w, "error", http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + if b, err := io.ReadAll(r.Body); err != nil || string(b) != body { + t.Error("expected body to be read") + } + })) + t.Cleanup(server.Close) + + resp, err := Put(context.Background(), server.URL, "text/plain", strings.NewReader(body)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + if calls != 3 { + t.Errorf("expected 3 calls, got %d", calls) } } diff --git a/src/pkg/http/query.go b/src/pkg/http/query.go new file mode 100644 index 000000000..789c236f3 --- /dev/null +++ b/src/pkg/http/query.go @@ -0,0 +1,12 @@ +package http + +import "net/url" + +func RemoveQueryParam(qurl string) string { + u, err := url.Parse(qurl) + if err != nil { + return qurl + } + u.RawQuery = "" + return u.String() +} diff --git a/src/pkg/http/query_test.go b/src/pkg/http/query_test.go new file mode 100644 index 000000000..da56cf021 --- /dev/null +++ b/src/pkg/http/query_test.go @@ -0,0 +1,12 @@ +package http + +import "testing" + +func TestRemoveQueryParam(t *testing.T) { + url := "https://example.com/foo?bar=baz" + expected := "https://example.com/foo" + actual := RemoveQueryParam(url) + if actual != expected { + t.Errorf("expected %q, got %q", expected, actual) + } +}