diff --git a/aws/endpoints/v3model.go b/aws/endpoints/v3model.go index eb2ac83c992..773613722f4 100644 --- a/aws/endpoints/v3model.go +++ b/aws/endpoints/v3model.go @@ -7,6 +7,8 @@ import ( "strings" ) +var regionValidationRegex = regexp.MustCompile(`^[[:alnum:]]([[:alnum:]\-]*[[:alnum:]])?$`) + type partitions []partition func (ps partitions) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) { @@ -124,7 +126,7 @@ func (p partition) EndpointFor(service, region string, opts ...func(*Options)) ( defs := []endpoint{p.Defaults, s.Defaults} - return e.resolve(service, p.ID, region, p.DNSSuffix, defs, opt), nil + return e.resolve(service, p.ID, region, p.DNSSuffix, defs, opt) } func serviceList(ss services) []string { @@ -233,7 +235,7 @@ func getByPriority(s []string, p []string, def string) string { return s[0] } -func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs []endpoint, opts Options) ResolvedEndpoint { +func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs []endpoint, opts Options) (ResolvedEndpoint, error) { var merged endpoint for _, def := range defs { merged.mergeIn(def) @@ -260,6 +262,10 @@ func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs [ region = signingRegion } + if !validateInputRegion(region) { + return ResolvedEndpoint{}, fmt.Errorf("invalid region identifier format provided") + } + u := strings.Replace(hostname, "{service}", service, 1) u = strings.Replace(u, "{region}", region, 1) u = strings.Replace(u, "{dnsSuffix}", dnsSuffix, 1) @@ -274,7 +280,7 @@ func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs [ SigningName: signingName, SigningNameDerived: signingNameDerived, SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner), - } + }, nil } func getEndpointScheme(protocols []string, disableSSL bool) string { @@ -339,3 +345,7 @@ const ( boxedFalse boxedTrue ) + +func validateInputRegion(region string) bool { + return regionValidationRegex.MatchString(region) +} diff --git a/aws/endpoints/v3model_test.go b/aws/endpoints/v3model_test.go index 28b27aeef4c..b4a63683f7a 100644 --- a/aws/endpoints/v3model_test.go +++ b/aws/endpoints/v3model_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "reflect" "regexp" + "strconv" "strings" "testing" ) @@ -209,9 +210,12 @@ func TestEndpointResolve(t *testing.T) { SSLCommonName: "new sslCommonName", } - resolved := e.resolve("service", "partitionID", "region", "dnsSuffix", + resolved, err := e.resolve("service", "partitionID", "region", "dnsSuffix", defs, Options{}, ) + if err != nil { + t.Errorf("expected no error, got %v", err) + } if e, a := "https://service.region.dnsSuffix", resolved.URL; e != a { t.Errorf("expect %v, got %v", e, a) @@ -225,6 +229,14 @@ func TestEndpointResolve(t *testing.T) { if e, a := "v4", resolved.SigningMethod; e != a { t.Errorf("expect %v, got %v", e, a) } + + // Check Invalid Region Identifier Format + _, err = e.resolve("service", "partitionID", "notvalid.com", "dnsSuffix", + defs, Options{}, + ) + if err == nil { + t.Errorf("expected err, got nil") + } } func TestEndpointMergeIn(t *testing.T) { @@ -598,3 +610,39 @@ func TestEndpointFor_EmptyRegion(t *testing.T) { }) } } + +func TestRegionValidator(t *testing.T) { + cases := []struct { + Region string + Valid bool + }{ + 0: { + Region: "us-east-1", + Valid: true, + }, + 1: { + Region: "invalid.com", + Valid: false, + }, + 2: { + Region: "@invalid.com/%23", + Valid: false, + }, + 3: { + Region: "local", + Valid: true, + }, + 4: { + Region: "9-west-1", + Valid: true, + }, + } + + for i, tt := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + if e, a := tt.Valid, validateInputRegion(tt.Region); e != a { + t.Errorf("expected %v, got %v", e, a) + } + }) + } +}