Skip to content

Commit

Permalink
Migrate to jwt-go v5 (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
majst01 authored Nov 6, 2024
1 parent de563b1 commit 1f475dd
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 36 deletions.
21 changes: 15 additions & 6 deletions dex.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"strings"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/golang-jwt/jwt/v5"
"github.com/lestrrat-go/jwx/v2/jwk"
)

Expand All @@ -30,6 +30,8 @@ type Dex struct {
algorithmWhitelist []string

userExtractor UserExtractorFn

jwtParserOptions []jwt.ParserOption
}

type keyRsp struct {
Expand Down Expand Up @@ -70,7 +72,7 @@ func (dx *Dex) With(opts ...Option) *Dex {
// is not an array.
type Claims struct {
jwt.RegisteredClaims
Audience interface{} `json:"aud,omitempty"`
Audience any `json:"aud,omitempty"`
Groups []string `json:"groups"`
EMail string `json:"email"`
Name string `json:"name"`
Expand Down Expand Up @@ -99,6 +101,13 @@ func AlgorithmsWhitelist(algNames []string) Option {
}
}

func JWTParserOptions(opt jwt.ParserOption) Option {
return func(dex *Dex) *Dex {
dex.jwtParserOptions = append(dex.jwtParserOptions, opt)
return dex
}
}

func (dx *Dex) algorithmSupported(alg string) bool {
for _, a := range dx.algorithmWhitelist {
if a == alg {
Expand Down Expand Up @@ -166,7 +175,7 @@ func (dx *Dex) updateKeys(old jwk.Set) (jwk.Set, error) {

// searchKey searches the given key in the set loaded from dex. If
// there is a key it will be returned otherwise an error is returned
func (dx *Dex) searchKey(kid string) (interface{}, error) {
func (dx *Dex) searchKey(kid string) (any, error) {
for i := 0; i < 2; i++ {
keys, err := dx.fetchKeys()
if err != nil {
Expand All @@ -177,7 +186,7 @@ func (dx *Dex) searchKey(kid string) (interface{}, error) {
dx.forceUpdate()
continue
}
var key interface{}
var key any
err = jwtkey.Raw(&key)
return key, err
}
Expand All @@ -197,7 +206,7 @@ func (dx *Dex) User(rq *http.Request) (*User, error) {
}
bearerToken := strings.TrimSpace(splitToken[1])

token, err := jwt.ParseWithClaims(bearerToken, &Claims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(bearerToken, &Claims{}, func(token *jwt.Token) (any, error) {
alg, ok := token.Header["alg"].(string)
if !ok {
return nil, errors.New("invalid token")
Expand All @@ -210,7 +219,7 @@ func (dx *Dex) User(rq *http.Request) (*User, error) {
return nil, errors.New("invalid token")
}
return dx.searchKey(kid)
})
}, dx.jwtParserOptions...)
if err != nil {
return nil, err
}
Expand Down
47 changes: 25 additions & 22 deletions dex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (

"time"

"github.com/golang-jwt/jwt/v4"
"github.com/golang-jwt/jwt/v5"
)

var (
Expand All @@ -35,31 +35,31 @@ var (
//nolint:gosec
authtokenAlgHS256Kid = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsICJraWQiOiIwMWMwMmQ1ZC1jYWEwLTQ0NTAtYWQ3Ni1iZjk1MDQyMGI4YjYifQ.eyJhdWQiOlsidGhlQXVkaWVuY2UiXSwiZW1haWwiOiJhY2hpbS5hZG1pbkB0ZW5hbnQuZGUiLCJleHAiOjE1ODc3NTMxODYsImZlZGVyYXRlZF9jbGFpbXMiOnsiY29ubmVjdG9yX2lkIjoidG50X2xkYXBfb3BlbmxkYXAiLCJ1c2VyX2lkIjoiY249YWNoaW0uYWRtaW4sb3U9UGVvcGxlLGRjPXRlbmFudCxkYz1kZSJ9LCJncm91cHMiOlsiZ3JwYSIsImdycGIiXSwiaWF0IjoxNTg3NzM4Nzg2LCJpc3MiOiJodHRwczovL2RleC50ZXN0Lm1ldGFsLXN0YWNrLmlvL2RleCIsIm5hbWUiOiJhY2hpbSIsInN1YiI6ImFjaGltIn0.vHRBpA1Jvb6kPLI56xCdIh42or96N5sOHg3cHs-is-o"

dk1 = map[string]interface{}{
dk1 = map[string]any{
"use": "sig",
"kty": "RSA",
"kid": "01c02d5d-caa0-4450-ad76-bf950420b8b6",
"alg": "RS256",
"n": "zFSDsEpZ-EegnJpYFTmaUVz2OvtCQty1gYFxLECICU2lrFCxoAFnkARjbyuvT68sIbhdSZ981YoY_oVohhLOMZjNV3KUhRPlMaSZsEDfnZLOGjfRzjOLNGwtcfu7uLvSVOhaF1bqNUtQHN1ljEmcHWJbJzPFLOBD5uK5tZ-zT0q8NyDRnIB3yNPppk1OpMgmAvxpXaIjsTUfOaOz4vbG6opWg4wz-cLgtyvA1YMSQ24EVnHPC4b2fJOJf9DXf1qkVNjiY9BqO19afv8pM1cliYu66wN4D_eAXQnhA_8j6AQyNkHusaOG1TCzxyPQDtcQYjNZfhQBxXLZE_JM_XdCSAdtwPcQTsySHQHIxsFG3M90DiuukCc7iusAcmCupY5jXTH70_ZykvvaTqxgjavj5zsSndiCwSicrJSoh4YwhqUsZMKivphhyZIb0VpzWRhTYhlN1snC184caa_kPgyRRZux40RxCjluo9Taftm2MUji4BZ_TovUG2IBsOJdp9OdmKT9zuw_feNUL5o3ImCmP4ifI03I3kCATS-KnvNnILQQXYpwP-6hNEZJAXcBtXUnoqMbOdqOjKjNc8ZIBwINe8WVuCZhH2bZc0RK8kh6EgZupMrxAPmmvzfr8RgfU8LKOkQ6Y41UhE7qCkLWARLgQpaRu5HmE_YrZvodqSmY3fU",
"e": "AQAB",
}
dk2 = map[string]interface{}{
dk2 = map[string]any{
"use": "sig",
"kty": "RSA",
"kid": "0f3abdf4e05337b02fa0e36291b9147379dbb686",
"alg": "RS256",
"n": "sUGZtErd2hymWcdHcjkm5bNqVlvMEkVxIabEgWUwWW0mWc2g5QHKysXDS6Oi1Oyzumjx-dmbZ6nz3C_bJMqEbIwRSyxGnUDziraUIs8WAp0bGv440llxhmT26UifOF9TL8iUvRAVKDzCv3YttyxmLojls3c-L9P-71Uc2NmskeBe9lwE5E-1SX2lx01fjhVRrp2TeujqeY7VR4sdKPXyECn7-W7nuOUAQt4ziiGX-gNrt--SX2oG_2TLw_Urv8O9epw8VjB9zWXKsmjkCUVxPAdSHdnlyRQf7TAhiygcK11Fl2ABIv_DwP0Ei5sd-E6FPqfzrNVA81L16mFHaZLciQ",
"e": "AQAB",
}
dk3 = map[string]interface{}{
dk3 = map[string]any{
"use": "sig",
"kty": "RSA",
"kid": "d3519837ce558fa66192a82f925c1169de358d63",
"alg": "RS256",
"n": "qMu9ak2GZVy7mgSG2nqDJAlYBqXCTTbtSTEtAVpYKcCZKRkDY7kWkPrE8rdhuZV0sVN1-5SQivaDtfXSMBBaLpZFbhA0l98fH3ExOpVbdlHNNWd3mSJEcEFc1QGhc755shyFIliOW59JMNzETIF8eq-MXMt8dKxtnUVZWJk8EYOQSxYK7E9cl4HtACIoGHchRrUctIUJBFgSRbKx1u-_Qnf9cnJeSNdXKL8l7bvLtm5UZWPQrUo229pQ687jUKZu-k2Xag3bAsRGJ6ScbWuLBIJdOxNbvnA3XyARxvqIeZAoEFxDn3q6rhyG024MeRhn4Rd_RzeEq2Y0hsa68M7pkw",
"e": "AQAB",
}
dk123 = map[string]interface{}{
dk123 = map[string]any{
"use": "sig",
"kty": "RSA",
"kid": "123",
Expand All @@ -68,20 +68,20 @@ var (
"e": "AQAB",
}

firstkeys = []map[string]interface{}{
firstkeys = []map[string]any{
dk1,
dk2,
}
firstkeydata = map[string]interface{}{
firstkeydata = map[string]any{
"keys": firstkeys,
}
secondkeys = []map[string]interface{}{
secondkeys = []map[string]any{
dk1,
dk2,
dk3,
dk123,
}
secondkeydata = map[string]interface{}{
secondkeydata = map[string]any{
"keys": secondkeys,
}
)
Expand Down Expand Up @@ -115,7 +115,7 @@ func TestDex_keyfetcher(t *testing.T) {
t.Errorf("the keys were not fetched")
return
}
data := [][]map[string]interface{}{firstkeys, secondkeys}
data := [][]map[string]any{firstkeys, secondkeys}
searchkey := dk3["kid"].(string)
// the server will return first "firstkeys" and on the secondcall "secondkeys"
// only the secondkeys contains "dk3", so the following tests if the dex
Expand Down Expand Up @@ -190,53 +190,53 @@ func TestDex_User(t *testing.T) {
name: "token is expired",
token: authtokenAlgRS256,
t: time.Date(2019, time.May, 10, 6, 6, 0, 0, time.UTC),
err: "token is expired by 15h59m21s",
err: "token has invalid claims: token is expired",
},
{
name: "token invalid default whitelist - signature algorithm 'none' no kid",
token: authtokenAlgNone,
t: time.Date(2019, time.May, 10, 6, 6, 0, 0, time.UTC),
err: "invalid token",
err: "token is unverifiable: error while executing keyfunc: invalid token",
},
{
name: "token invalid default whitelist - signature algorithm 'none' with kid",
token: authtokenAlgNoneKid,
t: time.Date(2019, time.May, 10, 6, 6, 0, 0, time.UTC),
err: "invalid token",
err: "token is unverifiable: error while executing keyfunc: invalid token",
},
{
name: "token invalid default whitelist - signature algorithm 'HS256' no kid",
token: authtokenAlgHS256,
t: time.Date(2019, time.May, 10, 6, 6, 0, 0, time.UTC),
err: "invalid token",
err: "token is unverifiable: error while executing keyfunc: invalid token",
},
{
name: "token invalid default whitelist - signature algorithm 'HS256' with kid",
token: authtokenAlgHS256Kid,
t: time.Date(2019, time.May, 10, 6, 6, 0, 0, time.UTC),
err: "invalid token",
err: "token is unverifiable: error while executing keyfunc: invalid token",
},
{
name: "algorithm not in whitelist 'RS256' - signature algorithm 'HS256' with kid",
token: authtokenAlgHS256Kid,
opt: AlgorithmsWhitelist([]string{"RS256"}),
t: time.Date(2019, time.May, 10, 6, 6, 0, 0, time.UTC),
err: "invalid token",
err: "token is unverifiable: error while executing keyfunc: invalid token",
},
{
name: "algorithm not in whitelist - empty whitelist",
token: authtokenAlgRS256,
opt: AlgorithmsWhitelist([]string{}),
t: time.Date(2019, time.May, 10, 6, 6, 0, 0, time.UTC),
err: "invalid token",
err: "token is unverifiable: error while executing keyfunc: invalid token",
},
}
for _, tt := range test {
tt := tt
t.Run(tt.name, func(t *testing.T) {
jwt.TimeFunc = func() time.Time {
jwtParserOpt := JWTParserOptions(jwt.WithTimeFunc(func() time.Time {
return tt.t
}
}))

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, rq *http.Request) {
err := json.NewEncoder(w).Encode(secondkeydata)
Expand All @@ -253,6 +253,8 @@ func TestDex_User(t *testing.T) {
if tt.opt != nil {
tt.opt(dx)
}
jwtParserOpt(dx)

rq := httptest.NewRequest(http.MethodGet, srv.URL, nil)
rq.Header.Add("Authorization", "Bearer "+tt.token)
usr, err := dx.User(rq)
Expand Down Expand Up @@ -291,15 +293,15 @@ func TestDex_UserWithOptions(t *testing.T) {
{
name: "token is expired",
t: time.Date(2019, time.May, 10, 6, 6, 0, 0, time.UTC),
err: "token is expired by 15h59m21s",
err: "token has invalid claims: token is expired",
},
}
for _, tt := range test {
tt := tt
t.Run(tt.name, func(t *testing.T) {
jwt.TimeFunc = func() time.Time {
jwtParserOpt := JWTParserOptions(jwt.WithTimeFunc(func() time.Time {
return tt.t
}
}))

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, rq *http.Request) {
err := json.NewEncoder(w).Encode(secondkeydata)
Expand All @@ -313,6 +315,7 @@ func TestDex_UserWithOptions(t *testing.T) {
t.Errorf("NewDex() error = %v", err)
return
}
jwtParserOpt(dx)

// change Name to akim and de-prefix groups - just for this test
dx.With(UserExtractor(func(claims *Claims) (user *User, e error) {
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ require (
github.com/coreos/go-oidc/v3 v3.11.0
github.com/go-jose/go-jose/v4 v4.0.4
github.com/go-openapi/runtime v0.28.0
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1
github.com/lestrrat-go/jwx/v2 v2.1.1
github.com/lestrrat-go/jwx/v2 v2.1.2
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.28.0
golang.org/x/net v0.30.0
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3Bum
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
Expand All @@ -54,8 +54,8 @@ github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCG
github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
github.com/lestrrat-go/jwx/v2 v2.1.1 h1:Y2ltVl8J6izLYFs54BVcpXLv5msSW4o8eXwnzZLI32E=
github.com/lestrrat-go/jwx/v2 v2.1.1/go.mod h1:4LvZg7oxu6Q5VJwn7Mk/UwooNRnTHUpXBj2C4j3HNx0=
github.com/lestrrat-go/jwx/v2 v2.1.2 h1:6poete4MPsO8+LAEVhpdrNI4Xp2xdiafgl2RD89moBc=
github.com/lestrrat-go/jwx/v2 v2.1.2/go.mod h1:pO+Gz9whn7MPdbsqSJzG8TlEpMZCwQDXnFJ+zsUVh8Y=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
Expand Down
4 changes: 2 additions & 2 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ type providerJSON struct {
}

// CreateToken creates a jwt token with the given claims
func CreateToken(signer jose.Signer, cl interface{}, privateClaims ...interface{}) (string, error) {
func CreateToken(signer jose.Signer, cl any, privateClaims ...any) (string, error) {
builder := jwt.Signed(signer).Claims(cl)
for i := range privateClaims {
builder = builder.Claims(privateClaims[i])
Expand All @@ -204,7 +204,7 @@ func CreateToken(signer jose.Signer, cl interface{}, privateClaims ...interface{
}

// MustMakeSigner creates a Signer and panics if an error occurs
func MustMakeSigner(alg jose.SignatureAlgorithm, k interface{}) jose.Signer {
func MustMakeSigner(alg jose.SignatureAlgorithm, k any) jose.Signer {
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: alg, Key: k}, nil)
if err != nil {
panic("failed to create signer:" + err.Error())
Expand Down

0 comments on commit 1f475dd

Please sign in to comment.