diff --git a/internal/wrappers/client.go b/internal/wrappers/client.go index 65675211f..c1e7ade53 100644 --- a/internal/wrappers/client.go +++ b/internal/wrappers/client.go @@ -70,6 +70,25 @@ var cachedAccessToken string var cachedAccessTime time.Time var Domains = make(map[string]struct{}) +func retryHTTPRequest(requestFunc func() (*http.Response, error), retries int, baseDelayInMilliSec time.Duration) (*http.Response, error) { + + var resp *http.Response + var err error + + for attempt := 0; attempt < retries; attempt++ { + resp, err = requestFunc() + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusBadGateway { + return resp, nil + } + _ = resp.Body.Close() + time.Sleep(baseDelayInMilliSec * (1 << attempt)) + } + return resp, nil +} + func setAgentName(req *http.Request) { agentStr := viper.GetString(commonParams.AgentNameKey) + "/" + commonParams.Version req.Header.Set("User-Agent", agentStr) diff --git a/internal/wrappers/client_test.go b/internal/wrappers/client_test.go new file mode 100644 index 000000000..f79577ce3 --- /dev/null +++ b/internal/wrappers/client_test.go @@ -0,0 +1,80 @@ +package wrappers + +import ( + "errors" + "github.com/stretchr/testify/assert" + "net/http" + "testing" + "time" +) + +type mockReadCloser struct{} + +func (m *mockReadCloser) Read(p []byte) (n int, err error) { + return 0, nil +} + +func (m *mockReadCloser) Close() error { + return nil +} + +func TestRetryHTTPRequest_Success(t *testing.T) { + fn := func() (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: &mockReadCloser{}, + }, nil + } + + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestRetryHTTPRequest_RetryOnBadGateway(t *testing.T) { + attempts := 0 + fn := func() (*http.Response, error) { + attempts++ + if attempts < retryAttempts { + return &http.Response{ + StatusCode: http.StatusBadGateway, + Body: &mockReadCloser{}, + }, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: &mockReadCloser{}, + }, nil + } + + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, retryAttempts, attempts) +} + +func TestRetryHTTPRequest_Fail(t *testing.T) { + fn := func() (*http.Response, error) { + return nil, errors.New("network error") + } + + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond) + assert.Error(t, err) + assert.Nil(t, resp) +} + +func TestRetryHTTPRequest_EndWithBadGateway(t *testing.T) { + fn := func() (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadGateway, + Body: &mockReadCloser{}, + }, nil + } + + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) +} diff --git a/internal/wrappers/constants.go b/internal/wrappers/constants.go new file mode 100644 index 000000000..7c491e84a --- /dev/null +++ b/internal/wrappers/constants.go @@ -0,0 +1,7 @@ +package wrappers + +const ( + limitValue = "10000" + retryAttempts = 4 + retryDelay = 500 +) diff --git a/internal/wrappers/projects-http.go b/internal/wrappers/projects-http.go index 8ed7bbd60..556444b6f 100644 --- a/internal/wrappers/projects-http.go +++ b/internal/wrappers/projects-http.go @@ -4,10 +4,10 @@ import ( "bytes" "encoding/json" "fmt" - "net/http" - "github.com/pkg/errors" "github.com/spf13/viper" + "net/http" + "time" errorConstants "github.com/checkmarx/ast-cli/internal/constants/errors" commonParams "github.com/checkmarx/ast-cli/internal/params" @@ -30,7 +30,10 @@ func (p *ProjectsHTTPWrapper) Create(model *Project) (*ProjectResponseModel, *Er return nil, nil, err } - resp, err := SendHTTPRequest(http.MethodPost, p.path, bytes.NewBuffer(jsonBytes), true, clientTimeout) + fn := func() (*http.Response, error) { + return SendHTTPRequest(http.MethodPost, p.path, bytes.NewBuffer(jsonBytes), true, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond) if err != nil { return nil, nil, err } @@ -49,7 +52,10 @@ func (p *ProjectsHTTPWrapper) Update(projectID string, model *Project) error { return nil } - resp, err := SendHTTPRequest(http.MethodPut, fmt.Sprintf("%s/%s", p.path, projectID), bytes.NewBuffer(jsonBytes), true, clientTimeout) + fn := func() (*http.Response, error) { + return SendHTTPRequest(http.MethodPut, fmt.Sprintf("%s/%s", p.path, projectID), bytes.NewBuffer(jsonBytes), true, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond) if err != nil { return err } @@ -79,7 +85,10 @@ func (p *ProjectsHTTPWrapper) UpdateConfiguration(projectID string, configuratio commonParams.ProjectIDFlag: projectID, } - resp, err := SendHTTPRequestWithQueryParams(http.MethodPatch, "api/configuration/project", params, bytes.NewBuffer(jsonBytes), clientTimeout) + fn := func() (*http.Response, error) { + return SendHTTPRequestWithQueryParams(http.MethodPatch, "api/configuration/project", params, bytes.NewBuffer(jsonBytes), clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond) if err != nil { return nil, err } @@ -100,7 +109,10 @@ func (p *ProjectsHTTPWrapper) Get(params map[string]string) ( params[limit] = limitValue } - resp, err := SendHTTPRequestWithQueryParams(http.MethodGet, p.path, params, nil, clientTimeout) + fn := func() (*http.Response, error) { + return SendHTTPRequestWithQueryParams(http.MethodGet, p.path, params, nil, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond) if err != nil { return nil, nil, err } @@ -137,7 +149,11 @@ func (p *ProjectsHTTPWrapper) GetByID(projectID string) ( *ErrorModel, error) { clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey) - resp, err := SendHTTPRequest(http.MethodGet, p.path+"/"+projectID, http.NoBody, true, clientTimeout) + + fn := func() (*http.Response, error) { + return SendHTTPRequest(http.MethodGet, p.path+"/"+projectID, http.NoBody, true, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond) if err != nil { return nil, nil, err } @@ -180,7 +196,11 @@ func (p *ProjectsHTTPWrapper) GetBranchesByID(projectID string, params map[strin var request = "/branches?project-id=" + projectID params["limit"] = limitValue - resp, err := SendHTTPRequestWithQueryParams(http.MethodGet, p.path+request, params, nil, clientTimeout) + + fn := func() (*http.Response, error) { + return SendHTTPRequestWithQueryParams(http.MethodGet, p.path+request, params, nil, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond) if err != nil { return nil, nil, err } @@ -215,7 +235,10 @@ func (p *ProjectsHTTPWrapper) GetBranchesByID(projectID string, params map[strin func (p *ProjectsHTTPWrapper) Delete(projectID string) (*ErrorModel, error) { clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey) - resp, err := SendHTTPRequest(http.MethodDelete, p.path+"/"+projectID, http.NoBody, true, clientTimeout) + fn := func() (*http.Response, error) { + return SendHTTPRequest(http.MethodDelete, p.path+"/"+projectID, http.NoBody, true, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond) if err != nil { return nil, err } @@ -232,7 +255,11 @@ func (p *ProjectsHTTPWrapper) Tags() ( *ErrorModel, error) { clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey) - resp, err := SendHTTPRequest(http.MethodGet, p.path+"/tags", http.NoBody, true, clientTimeout) + + fn := func() (*http.Response, error) { + return SendHTTPRequest(http.MethodGet, p.path+"/tags", http.NoBody, true, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond) if err != nil { return nil, nil, err } diff --git a/internal/wrappers/scans-http.go b/internal/wrappers/scans-http.go index fc65bedb1..c16ec0a07 100644 --- a/internal/wrappers/scans-http.go +++ b/internal/wrappers/scans-http.go @@ -4,11 +4,10 @@ import ( "bytes" "encoding/json" "fmt" - "net/http" - commonParams "github.com/checkmarx/ast-cli/internal/params" "github.com/pkg/errors" "github.com/spf13/viper" + "net/http" ) const ( @@ -35,7 +34,11 @@ func (s *ScansHTTPWrapper) Create(model *Scan) (*ScanResponseModel, *ErrorModel, if err != nil { return nil, nil, err } - resp, err := SendHTTPRequest(http.MethodPost, s.path, bytes.NewBuffer(jsonBytes), true, clientTimeout) + + fn := func() (*http.Response, error) { + return SendHTTPRequest(http.MethodPost, s.path, bytes.NewBuffer(jsonBytes), true, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay) if err != nil { return nil, nil, err } @@ -49,7 +52,11 @@ func (s *ScansHTTPWrapper) Create(model *Scan) (*ScanResponseModel, *ErrorModel, func (s *ScansHTTPWrapper) Get(params map[string]string) (*ScansCollectionResponseModel, *ErrorModel, error) { clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey) - resp, err := SendHTTPRequestWithQueryParams(http.MethodGet, s.path, params, nil, clientTimeout) + + fn := func() (*http.Response, error) { + return SendHTTPRequestWithQueryParams(http.MethodGet, s.path, params, nil, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay) if err != nil { return nil, nil, err } @@ -85,7 +92,11 @@ func (s *ScansHTTPWrapper) Get(params map[string]string) (*ScansCollectionRespon func (s *ScansHTTPWrapper) GetByID(scanID string) (*ScanResponseModel, *ErrorModel, error) { clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey) - resp, err := SendHTTPRequest(http.MethodGet, s.path+"/"+scanID, http.NoBody, true, clientTimeout) + + fn := func() (*http.Response, error) { + return SendHTTPRequest(http.MethodGet, s.path+"/"+scanID, http.NoBody, true, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay) if err != nil { return nil, nil, err } @@ -100,7 +111,11 @@ func (s *ScansHTTPWrapper) GetByID(scanID string) (*ScanResponseModel, *ErrorMod func (s *ScansHTTPWrapper) GetWorkflowByID(scanID string) ([]*ScanTaskResponseModel, *ErrorModel, error) { clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey) path := fmt.Sprintf("%s/%s/workflow", s.path, scanID) - resp, err := SendHTTPRequest(http.MethodGet, path, http.NoBody, true, clientTimeout) + + fn := func() (*http.Response, error) { + return SendHTTPRequest(http.MethodGet, path, http.NoBody, true, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay) if err != nil { return nil, nil, err } @@ -141,7 +156,11 @@ func handleWorkflowResponseWithBody(resp *http.Response, err error) ([]*ScanTask func (s *ScansHTTPWrapper) Delete(scanID string) (*ErrorModel, error) { clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey) - resp, err := SendHTTPRequest(http.MethodDelete, s.path+"/"+scanID, http.NoBody, true, clientTimeout) + + fn := func() (*http.Response, error) { + return SendHTTPRequest(http.MethodDelete, s.path+"/"+scanID, http.NoBody, true, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay) if err != nil { return nil, err } @@ -162,7 +181,10 @@ func (s *ScansHTTPWrapper) Cancel(scanID string) (*ErrorModel, error) { return nil, err } - resp, err := SendHTTPRequest(http.MethodPatch, s.path+"/"+scanID, bytes.NewBuffer(b), true, clientTimeout) + fn := func() (*http.Response, error) { + return SendHTTPRequest(http.MethodPatch, s.path+"/"+scanID, bytes.NewBuffer(b), true, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay) if err != nil { return nil, err } @@ -176,7 +198,11 @@ func (s *ScansHTTPWrapper) Cancel(scanID string) (*ErrorModel, error) { func (s *ScansHTTPWrapper) Tags() (map[string][]string, *ErrorModel, error) { clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey) - resp, err := SendHTTPRequest(http.MethodGet, s.path+"/tags", http.NoBody, true, clientTimeout) + + fn := func() (*http.Response, error) { + return SendHTTPRequest(http.MethodGet, s.path+"/tags", http.NoBody, true, clientTimeout) + } + resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay) if err != nil { return nil, nil, err } diff --git a/internal/wrappers/wrapper-constants.go b/internal/wrappers/wrapper-constants.go deleted file mode 100644 index 242bebd6c..000000000 --- a/internal/wrappers/wrapper-constants.go +++ /dev/null @@ -1,5 +0,0 @@ -package wrappers - -const ( - limitValue = "10000" -)