Skip to content

Commit

Permalink
Add validation to region identifier before setting endpoint (#3330)
Browse files Browse the repository at this point in the history
  • Loading branch information
skmcgrail authored Jun 18, 2020
1 parent 2b1b266 commit ea6cd9c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 4 deletions.
16 changes: 13 additions & 3 deletions aws/endpoints/v3model.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"strings"
)

var regionValidationRegex = regexp.MustCompile(`^[[:alnum:]]([[:alnum:]\-]*[[:alnum:]])?$`)

type partitions []partition

func (ps partitions) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
Expand Down Expand Up @@ -124,7 +126,7 @@ func (p partition) EndpointFor(service, region string, opts ...func(*Options)) (

defs := []endpoint{p.Defaults, s.Defaults}

return e.resolve(service, p.ID, region, p.DNSSuffix, defs, opt), nil
return e.resolve(service, p.ID, region, p.DNSSuffix, defs, opt)
}

func serviceList(ss services) []string {
Expand Down Expand Up @@ -233,7 +235,7 @@ func getByPriority(s []string, p []string, def string) string {
return s[0]
}

func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs []endpoint, opts Options) ResolvedEndpoint {
func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs []endpoint, opts Options) (ResolvedEndpoint, error) {
var merged endpoint
for _, def := range defs {
merged.mergeIn(def)
Expand All @@ -260,6 +262,10 @@ func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs [
region = signingRegion
}

if !validateInputRegion(region) {
return ResolvedEndpoint{}, fmt.Errorf("invalid region identifier format provided")
}

u := strings.Replace(hostname, "{service}", service, 1)
u = strings.Replace(u, "{region}", region, 1)
u = strings.Replace(u, "{dnsSuffix}", dnsSuffix, 1)
Expand All @@ -274,7 +280,7 @@ func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs [
SigningName: signingName,
SigningNameDerived: signingNameDerived,
SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner),
}
}, nil
}

func getEndpointScheme(protocols []string, disableSSL bool) string {
Expand Down Expand Up @@ -339,3 +345,7 @@ const (
boxedFalse
boxedTrue
)

func validateInputRegion(region string) bool {
return regionValidationRegex.MatchString(region)
}
50 changes: 49 additions & 1 deletion aws/endpoints/v3model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"reflect"
"regexp"
"strconv"
"strings"
"testing"
)
Expand Down Expand Up @@ -209,9 +210,12 @@ func TestEndpointResolve(t *testing.T) {
SSLCommonName: "new sslCommonName",
}

resolved := e.resolve("service", "partitionID", "region", "dnsSuffix",
resolved, err := e.resolve("service", "partitionID", "region", "dnsSuffix",
defs, Options{},
)
if err != nil {
t.Errorf("expected no error, got %v", err)
}

if e, a := "https://service.region.dnsSuffix", resolved.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
Expand All @@ -225,6 +229,14 @@ func TestEndpointResolve(t *testing.T) {
if e, a := "v4", resolved.SigningMethod; e != a {
t.Errorf("expect %v, got %v", e, a)
}

// Check Invalid Region Identifier Format
_, err = e.resolve("service", "partitionID", "notvalid.com", "dnsSuffix",
defs, Options{},
)
if err == nil {
t.Errorf("expected err, got nil")
}
}

func TestEndpointMergeIn(t *testing.T) {
Expand Down Expand Up @@ -598,3 +610,39 @@ func TestEndpointFor_EmptyRegion(t *testing.T) {
})
}
}

func TestRegionValidator(t *testing.T) {
cases := []struct {
Region string
Valid bool
}{
0: {
Region: "us-east-1",
Valid: true,
},
1: {
Region: "invalid.com",
Valid: false,
},
2: {
Region: "@invalid.com/%23",
Valid: false,
},
3: {
Region: "local",
Valid: true,
},
4: {
Region: "9-west-1",
Valid: true,
},
}

for i, tt := range cases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
if e, a := tt.Valid, validateInputRegion(tt.Region); e != a {
t.Errorf("expected %v, got %v", e, a)
}
})
}
}

0 comments on commit ea6cd9c

Please sign in to comment.