Skip to content

Commit

Permalink
fix: use http client instead of transport
Browse files Browse the repository at this point in the history
  • Loading branch information
smrz2001 committed Nov 27, 2024
1 parent 1bfa322 commit eaf4b12
Showing 1 changed file with 78 additions and 28 deletions.
106 changes: 78 additions & 28 deletions controllers/proxy_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
Expand All @@ -27,13 +26,13 @@ type ProxyController interface {
}

type proxyController struct {
ctx context.Context
cfg *config.Config
logger logging.Logger
metrics metric.MetricService
target *url.URL
mirror *url.URL
transport *http.Transport
ctx context.Context
cfg *config.Config
logger logging.Logger
metrics metric.MetricService
target *url.URL
mirror *url.URL
client *http.Client
}

type requestType string
Expand Down Expand Up @@ -70,27 +69,62 @@ func NewProxyController(
logger.Fatalf("invalid mirror URL: %v", err)
}
}
// Configure transport with timeouts
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: cfg.Proxy.DialTimeout,
}).DialContext,
ForceAttemptHTTP2: true,
IdleConnTimeout: cfg.Proxy.IdleTimeout,

client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return fmt.Errorf("stopped after 10 redirects")
}
return nil
},
}

return &proxyController{
ctx: ctx,
cfg: cfg,
logger: logger,
metrics: metrics,
target: target,
mirror: mirror,
transport: transport,
ctx: ctx,
cfg: cfg,
logger: logger,
metrics: metrics,
target: target,
mirror: mirror,
client: client,
}
}

func (_this proxyController) createRequest(
ctx context.Context,
originalReq *http.Request,
bodyBytes []byte,
) (*http.Request, error) {
// Create new request with appropriate context
newReq, err := http.NewRequestWithContext(
ctx,
originalReq.Method,
originalReq.URL.String(),
bytes.NewBuffer(bodyBytes),
)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

// Copy headers from original request
for k, vv := range originalReq.Header {
for _, v := range vv {
newReq.Header.Add(k, v)
}
}

// Add Content-Length header if body exists
if len(bodyBytes) > 0 {
newReq.Header.Add("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
// Ensure content type is preserved
if contentType := originalReq.Header.Get("Content-Type"); contentType != "" {
newReq.Header.Set("Content-Type", contentType)
}
}

return newReq, nil
}

func (_this proxyController) proxyRequest(c *gin.Context) {
// Read the original request body
bodyBytes, err := io.ReadAll(c.Request.Body)
Expand All @@ -99,14 +133,24 @@ func (_this proxyController) proxyRequest(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read request"})
return
}
// Ignore error since we are closing the body anyway
_ = c.Request.Body.Close()

// Restore the request body for downstream middleware/handlers
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))

// Create proxy request
proxyReq, err := _this.createRequest(c.Request.Context(), c.Request, bodyBytes)
if err != nil {
_this.logger.Errorw("failed to create proxy request", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"})
return
}

// Handle proxy request
_this.handleRequest(requestContext{
reqType: proxyRequest,
ginContext: c,
request: c.Request,
request: proxyReq,
bodyBytes: bodyBytes,
startTime: time.Now(),
targetURL: _this.target,
Expand All @@ -118,9 +162,15 @@ func (_this proxyController) proxyRequest(c *gin.Context) {
ctx, cancel := context.WithTimeout(_this.ctx, _this.cfg.Proxy.MirrorTimeout)
defer cancel()

mirrorReq, err := _this.createRequest(ctx, c.Request, bodyBytes)
if err != nil {
_this.logger.Errorw("failed to create mirror request", "error", err)
return
}

_this.handleRequest(requestContext{
reqType: mirrorRequest,
request: c.Request.Clone(ctx),
request: mirrorReq,
bodyBytes: bodyBytes,
startTime: time.Now(),
targetURL: _this.mirror,
Expand All @@ -131,7 +181,7 @@ func (_this proxyController) proxyRequest(c *gin.Context) {

func (_this proxyController) handleRequest(reqCtx requestContext) {
// Prepare the request
req := reqCtx.request.Clone(reqCtx.request.Context())
req := reqCtx.request
req.URL.Scheme = reqCtx.targetURL.Scheme
req.URL.Host = reqCtx.targetURL.Host
req.Host = reqCtx.targetURL.Host
Expand All @@ -155,7 +205,7 @@ func (_this proxyController) handleRequest(reqCtx requestContext) {
}

// Make the request
resp, err := _this.transport.RoundTrip(req)
resp, err := _this.client.Do(req)
if err != nil {
_this.logger.Errorw(fmt.Sprintf("%s error", reqCtx.reqType), "error", err)
if reqCtx.reqType == proxyRequest {
Expand Down

0 comments on commit eaf4b12

Please sign in to comment.