Skip to content

Commit

Permalink
utils.go: fixUrl(): Check the Location value before return
Browse files Browse the repository at this point in the history
  • Loading branch information
leiless committed Jan 22, 2021
1 parent aa96f1a commit 59bcbe2
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"math/rand"
"net"
"net/http"
"net/url"
"strings"
"sync/atomic"
"time"
Expand Down Expand Up @@ -125,7 +126,7 @@ func isContentType(contentType string, h *http.Header) bool {
// see:
// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/
// https://medium.com/@nate510/don-t-use-go-s-default-http-client-4804cb19f779
func getUrlContent(url, contentType string, bootstrap []string, timeout time.Duration) (string, error) {
func getUrlContent(theUrl, contentType string, bootstrap []string, timeout time.Duration) (string, error) {
var transport http.RoundTripper

if len(bootstrap) != 0 {
Expand Down Expand Up @@ -156,7 +157,7 @@ func getUrlContent(url, contentType string, bootstrap []string, timeout time.Dur
// Fallback to use system default resolvers, which located at /etc/resolv.conf
}

req, err := http.NewRequest(http.MethodGet, url, nil)
req, err := http.NewRequest(http.MethodGet, theUrl, nil)
if err != nil {
return "", err
}
Expand All @@ -178,10 +179,10 @@ func getUrlContent(url, contentType string, bootstrap []string, timeout time.Dur
}

if len(contentType) != 0 && !isContentType(contentType, &resp.Header) {
if url, err := fixUrl(url, resp.Header); err != nil {
if theUrl, err = fixUrl(theUrl, resp.Header); err != nil {
return "", err
} else {
return getUrlContent(url, contentType, bootstrap, timeout)
return getUrlContent(theUrl, contentType, bootstrap, timeout)
}
}

Expand All @@ -193,13 +194,16 @@ func getUrlContent(url, contentType string, bootstrap []string, timeout time.Dur
return string(content), nil
}

func fixUrl(url string, h http.Header) (string, error) {
func fixUrl(theUrl string, h http.Header) (string, error) {
const LocationKey = "Location"
location := h.Get(LocationKey)
if location != "" {
if _, err := url.Parse(theUrl); err != nil {
return "", fmt.Errorf("fixUrl(): url.Parse(): %w", err)
}
return location, nil
}
return "", fmt.Errorf("%q header key not found in %v", LocationKey, url)
return "", fmt.Errorf("%q header key not found in %v", LocationKey, theUrl)
}

func stringHash(str string) uint64 {
Expand Down

0 comments on commit 59bcbe2

Please sign in to comment.