Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor OAuth2 Authorization Request Validation Logic for Improved Readability, Reusability, and Error Handling #273

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
15 changes: 15 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,18 @@ func (ccm CodeChallengeMethod) Validate(cc, ver string) bool {
return false
}
}

// AuthorizeRequestMethod the type of authorization request method
type AuthorizeRequestMethod string

const (
AuthorizeRequestGet AuthorizeRequestMethod = "GET"
AuthorizeRequestPost AuthorizeRequestMethod = "POST"
)

func (ar AuthorizeRequestMethod) String() string {
if ar == AuthorizeRequestGet || ar == AuthorizeRequestPost {
return string(ar)
}
return ""
}
17 changes: 13 additions & 4 deletions errors/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,50 +35,59 @@ func (r *Response) SetHeader(key, value string) {
// https://tools.ietf.org/html/rfc6749#section-5.2
var (
ErrInvalidRequest = errors.New("invalid_request")
ErrMissingClientID = errors.New("missing_client_id")
ErrInvalidRequestMethod = errors.New("invalid_request_method")
ErrUnauthorizedClient = errors.New("unauthorized_client")
ErrAccessDenied = errors.New("access_denied")
ErrUnsupportedResponseType = errors.New("unsupported_response_type")
ErrMissingResponseType = errors.New("missing_response_type")
ErrInvalidScope = errors.New("invalid_scope")
ErrServerError = errors.New("server_error")
ErrTemporarilyUnavailable = errors.New("temporarily_unavailable")
ErrInvalidClient = errors.New("invalid_client")
ErrInvalidGrant = errors.New("invalid_grant")
ErrUnsupportedGrantType = errors.New("unsupported_grant_type")
ErrCodeChallengeRquired = errors.New("invalid_request")
ErrCodeChallengeRequired = errors.New("invalid_request")
ErrUnsupportedCodeChallengeMethod = errors.New("invalid_request")
ErrInvalidCodeChallengeLen = errors.New("invalid_request")
)

// Descriptions error description
var Descriptions = map[error]string{
ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed",
ErrMissingClientID: "The request is missing client_id",
ErrInvalidRequestMethod: "The request method is invalid, unknown, or malformed",
ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method",
ErrAccessDenied: "The resource owner or authorization server denied the request",
ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method",
ErrMissingResponseType: "The requested response type is empty",
ErrInvalidScope: "The requested scope is invalid, unknown, or malformed",
ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request",
ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server",
ErrInvalidClient: "Client authentication failed",
ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client",
ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server",
ErrCodeChallengeRquired: "PKCE is required. code_challenge is missing",
ErrCodeChallengeRequired: "PKCE is required. code_challenge is missing",
ErrUnsupportedCodeChallengeMethod: "Selected code_challenge_method not supported",
ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 charachters long",
ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 characters long",
}

// StatusCodes response error HTTP status code
var StatusCodes = map[error]int{
ErrInvalidRequest: 400,
ErrMissingClientID: 400,
ErrInvalidRequestMethod: 400,
ErrUnauthorizedClient: 401,
ErrAccessDenied: 403,
ErrUnsupportedResponseType: 401,
ErrMissingResponseType: 400,
ErrInvalidScope: 400,
ErrServerError: 500,
ErrTemporarilyUnavailable: 503,
ErrInvalidClient: 401,
ErrInvalidGrant: 401,
ErrUnsupportedGrantType: 401,
ErrCodeChallengeRquired: 400,
ErrCodeChallengeRequired: 400,
ErrUnsupportedCodeChallengeMethod: 400,
ErrInvalidCodeChallengeLen: 400,
}
17 changes: 11 additions & 6 deletions server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import (

// Config configuration parameters
type Config struct {
TokenType string // token type
AllowGetAccessRequest bool // to allow GET requests for the token
AllowedResponseTypes []oauth2.ResponseType // allow the authorization type
AllowedGrantTypes []oauth2.GrantType // allow the grant type
AllowedCodeChallengeMethods []oauth2.CodeChallengeMethod
ForcePKCE bool
TokenType string // token type
AllowGetAccessRequest bool // to allow GET requests for the token
AllowedResponseTypes []oauth2.ResponseType // allow the authorization type
AllowedGrantTypes []oauth2.GrantType // allow the grant type
AllowedCodeChallengeMethods []oauth2.CodeChallengeMethod
AllowedAuthorizeRequestMethods []oauth2.AuthorizeRequestMethod //allowed `authorize request methods`
ForcePKCE bool
}

// NewConfig create to configuration instance
Expand All @@ -32,6 +33,10 @@ func NewConfig() *Config {
oauth2.CodeChallengePlain,
oauth2.CodeChallengeS256,
},
AllowedAuthorizeRequestMethods: []oauth2.AuthorizeRequestMethod{
oauth2.AuthorizeRequestGet,
oauth2.AuthorizeRequestPost,
},
}
}

Expand Down
88 changes: 58 additions & 30 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,19 @@ func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface
return u.String(), nil
}

// CheckResponseType check allows response type
func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool {
for _, art := range s.Config.AllowedResponseTypes {
if art == rt {
return true
// CheckResponseType checks for an allowed response type
func (s *Server) CheckResponseType(responseType oauth2.ResponseType) error {
if responseType.String() == "" {
return errors.ErrMissingResponseType
}

for _, rType := range s.Config.AllowedResponseTypes {
if rType == responseType {
return nil
}
}
return false

return errors.ErrUnsupportedResponseType
}

// CheckCodeChallengeMethod checks for allowed code challenge method
Expand All @@ -163,50 +168,73 @@ func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool {
return false
}

// CheckAuthorizeRequestMethod checks for allowed code challenge method
func (s *Server) CheckAuthorizeRequestMethod(requestMethod oauth2.AuthorizeRequestMethod) bool {
for _, method := range s.Config.AllowedAuthorizeRequestMethods {
if method == requestMethod {
return true
}
}
return false
}

// CheckCodeChallenge checks if the Code Challenge is valid
func (s *Server) CheckCodeChallenge(codeChallenge string, isForcePKCE bool) error {
if isForcePKCE && codeChallenge == "" {
return errors.ErrCodeChallengeRequired
}
if len(codeChallenge) > 0 && len(codeChallenge) < 43 || len(codeChallenge) > 128 {
return errors.ErrInvalidCodeChallengeLen
}
return nil
}

// ValidationAuthorizeRequest the authorization request validation
func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) {
if r == nil {
return nil, errors.ErrInvalidRequest
}

redirectURI := r.FormValue("redirect_uri")

clientID := r.FormValue("client_id")
if !(r.Method == "GET" || r.Method == "POST") ||
clientID == "" {
return nil, errors.ErrInvalidRequest
if clientID == "" {
return nil, errors.ErrMissingClientID
}

resType := oauth2.ResponseType(r.FormValue("response_type"))
if resType.String() == "" {
return nil, errors.ErrUnsupportedResponseType
} else if allowed := s.CheckResponseType(resType); !allowed {
return nil, errors.ErrUnauthorizedClient
if isMethodAllowed := s.CheckAuthorizeRequestMethod(oauth2.AuthorizeRequestMethod(r.Method)); !isMethodAllowed {
return nil, errors.ErrInvalidRequestMethod
}

cc := r.FormValue("code_challenge")
if cc == "" && s.Config.ForcePKCE {
return nil, errors.ErrCodeChallengeRquired
responseType := oauth2.ResponseType(r.FormValue("response_type"))
if err := s.CheckResponseType(responseType); err != nil {
return nil, err
}
if cc != "" && (len(cc) < 43 || len(cc) > 128) {
return nil, errors.ErrInvalidCodeChallengeLen

codeChallenge := r.FormValue("code_challenge")
if err := s.CheckCodeChallenge(codeChallenge, s.Config.ForcePKCE); err != nil {
return nil, err
}

ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method"))
// set default
if ccm == "" {
ccm = oauth2.CodeChallengePlain
codeChallengeMethod := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method"))
// Default to plain method if not specified
if codeChallengeMethod == "" {
codeChallengeMethod = oauth2.CodeChallengePlain
}
if ccm != "" && !s.CheckCodeChallengeMethod(ccm) {
if !s.CheckCodeChallengeMethod(codeChallengeMethod) {
return nil, errors.ErrUnsupportedCodeChallengeMethod
}

req := &AuthorizeRequest{
return &AuthorizeRequest{
RedirectURI: redirectURI,
ResponseType: resType,
ResponseType: responseType,
ClientID: clientID,
State: r.FormValue("state"),
Scope: r.FormValue("scope"),
Request: r,
CodeChallenge: cc,
CodeChallengeMethod: ccm,
}
return req, nil
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
}, nil
}

// GetAuthorizeToken get authorization token(code)
Expand Down