Skip to content

Commit

Permalink
Merge pull request #1025 from Checkmarx/feature/saraChen/AddRetryToHt…
Browse files Browse the repository at this point in the history
…tpCall

Add retry mechanism to 502 response (AST-82566)
  • Loading branch information
sarahCx authored Feb 2, 2025
2 parents 38ff94d + 4216470 commit d4ef245
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 24 deletions.
19 changes: 19 additions & 0 deletions internal/wrappers/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,25 @@ var cachedAccessToken string
var cachedAccessTime time.Time
var Domains = make(map[string]struct{})

func retryHTTPRequest(requestFunc func() (*http.Response, error), retries int, baseDelayInMilliSec time.Duration) (*http.Response, error) {

var resp *http.Response
var err error

for attempt := 0; attempt < retries; attempt++ {
resp, err = requestFunc()
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusBadGateway {
return resp, nil
}
_ = resp.Body.Close()
time.Sleep(baseDelayInMilliSec * (1 << attempt))
}
return resp, nil
}

func setAgentName(req *http.Request) {
agentStr := viper.GetString(commonParams.AgentNameKey) + "/" + commonParams.Version
req.Header.Set("User-Agent", agentStr)
Expand Down
80 changes: 80 additions & 0 deletions internal/wrappers/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package wrappers

import (
"errors"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
"time"
)

type mockReadCloser struct{}

func (m *mockReadCloser) Read(p []byte) (n int, err error) {
return 0, nil
}

func (m *mockReadCloser) Close() error {
return nil
}

func TestRetryHTTPRequest_Success(t *testing.T) {
fn := func() (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: &mockReadCloser{},
}, nil
}

resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}

func TestRetryHTTPRequest_RetryOnBadGateway(t *testing.T) {
attempts := 0
fn := func() (*http.Response, error) {
attempts++
if attempts < retryAttempts {
return &http.Response{
StatusCode: http.StatusBadGateway,
Body: &mockReadCloser{},
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Body: &mockReadCloser{},
}, nil
}

resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, retryAttempts, attempts)
}

func TestRetryHTTPRequest_Fail(t *testing.T) {
fn := func() (*http.Response, error) {
return nil, errors.New("network error")
}

resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond)
assert.Error(t, err)
assert.Nil(t, resp)
}

func TestRetryHTTPRequest_EndWithBadGateway(t *testing.T) {
fn := func() (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadGateway,
Body: &mockReadCloser{},
}, nil
}

resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, http.StatusBadGateway, resp.StatusCode)
}
7 changes: 7 additions & 0 deletions internal/wrappers/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package wrappers

const (
limitValue = "10000"
retryAttempts = 4
retryDelay = 500
)
47 changes: 37 additions & 10 deletions internal/wrappers/projects-http.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"bytes"
"encoding/json"
"fmt"
"net/http"

"github.com/pkg/errors"
"github.com/spf13/viper"
"net/http"
"time"

errorConstants "github.com/checkmarx/ast-cli/internal/constants/errors"
commonParams "github.com/checkmarx/ast-cli/internal/params"
Expand All @@ -30,7 +30,10 @@ func (p *ProjectsHTTPWrapper) Create(model *Project) (*ProjectResponseModel, *Er
return nil, nil, err
}

resp, err := SendHTTPRequest(http.MethodPost, p.path, bytes.NewBuffer(jsonBytes), true, clientTimeout)
fn := func() (*http.Response, error) {
return SendHTTPRequest(http.MethodPost, p.path, bytes.NewBuffer(jsonBytes), true, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond)
if err != nil {
return nil, nil, err
}
Expand All @@ -49,7 +52,10 @@ func (p *ProjectsHTTPWrapper) Update(projectID string, model *Project) error {
return nil
}

resp, err := SendHTTPRequest(http.MethodPut, fmt.Sprintf("%s/%s", p.path, projectID), bytes.NewBuffer(jsonBytes), true, clientTimeout)
fn := func() (*http.Response, error) {
return SendHTTPRequest(http.MethodPut, fmt.Sprintf("%s/%s", p.path, projectID), bytes.NewBuffer(jsonBytes), true, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond)
if err != nil {
return err
}
Expand Down Expand Up @@ -79,7 +85,10 @@ func (p *ProjectsHTTPWrapper) UpdateConfiguration(projectID string, configuratio
commonParams.ProjectIDFlag: projectID,
}

resp, err := SendHTTPRequestWithQueryParams(http.MethodPatch, "api/configuration/project", params, bytes.NewBuffer(jsonBytes), clientTimeout)
fn := func() (*http.Response, error) {
return SendHTTPRequestWithQueryParams(http.MethodPatch, "api/configuration/project", params, bytes.NewBuffer(jsonBytes), clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond)
if err != nil {
return nil, err
}
Expand All @@ -100,7 +109,10 @@ func (p *ProjectsHTTPWrapper) Get(params map[string]string) (
params[limit] = limitValue
}

resp, err := SendHTTPRequestWithQueryParams(http.MethodGet, p.path, params, nil, clientTimeout)
fn := func() (*http.Response, error) {
return SendHTTPRequestWithQueryParams(http.MethodGet, p.path, params, nil, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -137,7 +149,11 @@ func (p *ProjectsHTTPWrapper) GetByID(projectID string) (
*ErrorModel,
error) {
clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey)
resp, err := SendHTTPRequest(http.MethodGet, p.path+"/"+projectID, http.NoBody, true, clientTimeout)

fn := func() (*http.Response, error) {
return SendHTTPRequest(http.MethodGet, p.path+"/"+projectID, http.NoBody, true, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -180,7 +196,11 @@ func (p *ProjectsHTTPWrapper) GetBranchesByID(projectID string, params map[strin
var request = "/branches?project-id=" + projectID

params["limit"] = limitValue
resp, err := SendHTTPRequestWithQueryParams(http.MethodGet, p.path+request, params, nil, clientTimeout)

fn := func() (*http.Response, error) {
return SendHTTPRequestWithQueryParams(http.MethodGet, p.path+request, params, nil, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -215,7 +235,10 @@ func (p *ProjectsHTTPWrapper) GetBranchesByID(projectID string, params map[strin

func (p *ProjectsHTTPWrapper) Delete(projectID string) (*ErrorModel, error) {
clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey)
resp, err := SendHTTPRequest(http.MethodDelete, p.path+"/"+projectID, http.NoBody, true, clientTimeout)
fn := func() (*http.Response, error) {
return SendHTTPRequest(http.MethodDelete, p.path+"/"+projectID, http.NoBody, true, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond)
if err != nil {
return nil, err
}
Expand All @@ -232,7 +255,11 @@ func (p *ProjectsHTTPWrapper) Tags() (
*ErrorModel,
error) {
clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey)
resp, err := SendHTTPRequest(http.MethodGet, p.path+"/tags", http.NoBody, true, clientTimeout)

fn := func() (*http.Response, error) {
return SendHTTPRequest(http.MethodGet, p.path+"/tags", http.NoBody, true, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay*time.Millisecond)
if err != nil {
return nil, nil, err
}
Expand Down
44 changes: 35 additions & 9 deletions internal/wrappers/scans-http.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import (
"bytes"
"encoding/json"
"fmt"
"net/http"

commonParams "github.com/checkmarx/ast-cli/internal/params"
"github.com/pkg/errors"
"github.com/spf13/viper"
"net/http"
)

const (
Expand All @@ -35,7 +34,11 @@ func (s *ScansHTTPWrapper) Create(model *Scan) (*ScanResponseModel, *ErrorModel,
if err != nil {
return nil, nil, err
}
resp, err := SendHTTPRequest(http.MethodPost, s.path, bytes.NewBuffer(jsonBytes), true, clientTimeout)

fn := func() (*http.Response, error) {
return SendHTTPRequest(http.MethodPost, s.path, bytes.NewBuffer(jsonBytes), true, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay)
if err != nil {
return nil, nil, err
}
Expand All @@ -49,7 +52,11 @@ func (s *ScansHTTPWrapper) Create(model *Scan) (*ScanResponseModel, *ErrorModel,

func (s *ScansHTTPWrapper) Get(params map[string]string) (*ScansCollectionResponseModel, *ErrorModel, error) {
clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey)
resp, err := SendHTTPRequestWithQueryParams(http.MethodGet, s.path, params, nil, clientTimeout)

fn := func() (*http.Response, error) {
return SendHTTPRequestWithQueryParams(http.MethodGet, s.path, params, nil, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -85,7 +92,11 @@ func (s *ScansHTTPWrapper) Get(params map[string]string) (*ScansCollectionRespon

func (s *ScansHTTPWrapper) GetByID(scanID string) (*ScanResponseModel, *ErrorModel, error) {
clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey)
resp, err := SendHTTPRequest(http.MethodGet, s.path+"/"+scanID, http.NoBody, true, clientTimeout)

fn := func() (*http.Response, error) {
return SendHTTPRequest(http.MethodGet, s.path+"/"+scanID, http.NoBody, true, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay)
if err != nil {
return nil, nil, err
}
Expand All @@ -100,7 +111,11 @@ func (s *ScansHTTPWrapper) GetByID(scanID string) (*ScanResponseModel, *ErrorMod
func (s *ScansHTTPWrapper) GetWorkflowByID(scanID string) ([]*ScanTaskResponseModel, *ErrorModel, error) {
clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey)
path := fmt.Sprintf("%s/%s/workflow", s.path, scanID)
resp, err := SendHTTPRequest(http.MethodGet, path, http.NoBody, true, clientTimeout)

fn := func() (*http.Response, error) {
return SendHTTPRequest(http.MethodGet, path, http.NoBody, true, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -141,7 +156,11 @@ func handleWorkflowResponseWithBody(resp *http.Response, err error) ([]*ScanTask

func (s *ScansHTTPWrapper) Delete(scanID string) (*ErrorModel, error) {
clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey)
resp, err := SendHTTPRequest(http.MethodDelete, s.path+"/"+scanID, http.NoBody, true, clientTimeout)

fn := func() (*http.Response, error) {
return SendHTTPRequest(http.MethodDelete, s.path+"/"+scanID, http.NoBody, true, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay)
if err != nil {
return nil, err
}
Expand All @@ -162,7 +181,10 @@ func (s *ScansHTTPWrapper) Cancel(scanID string) (*ErrorModel, error) {
return nil, err
}

resp, err := SendHTTPRequest(http.MethodPatch, s.path+"/"+scanID, bytes.NewBuffer(b), true, clientTimeout)
fn := func() (*http.Response, error) {
return SendHTTPRequest(http.MethodPatch, s.path+"/"+scanID, bytes.NewBuffer(b), true, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay)
if err != nil {
return nil, err
}
Expand All @@ -176,7 +198,11 @@ func (s *ScansHTTPWrapper) Cancel(scanID string) (*ErrorModel, error) {

func (s *ScansHTTPWrapper) Tags() (map[string][]string, *ErrorModel, error) {
clientTimeout := viper.GetUint(commonParams.ClientTimeoutKey)
resp, err := SendHTTPRequest(http.MethodGet, s.path+"/tags", http.NoBody, true, clientTimeout)

fn := func() (*http.Response, error) {
return SendHTTPRequest(http.MethodGet, s.path+"/tags", http.NoBody, true, clientTimeout)
}
resp, err := retryHTTPRequest(fn, retryAttempts, retryDelay)
if err != nil {
return nil, nil, err
}
Expand Down
5 changes: 0 additions & 5 deletions internal/wrappers/wrapper-constants.go

This file was deleted.

0 comments on commit d4ef245

Please sign in to comment.