Skip to content

Commit

Permalink
techschool#24 middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
Bastien DUMONT committed Nov 6, 2024
1 parent 234b69a commit 63c7336
Show file tree
Hide file tree
Showing 15 changed files with 307 additions and 40 deletions.
2 changes: 2 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Varaiales for local http calls
HTTP_ACCESS_TOKEN=v2.local.Jo_5B_CEfcV4OhjsFHGECiufNE65S2vyKAF0mcLtyUYKNTS24Kb1_DIDqGDHNQAkxgKrUDd3cSCRD8tUKlWSQ6TU7U1tKWpf_SFaP90-MK8HS7T0CIMKesc0WdSodrrmjNpj5xOIx36xeidckV_3qyQwDABVqUovPsrTAZhL6AnueTx2sCKZTE0k3hwQQvamywkuc0IuQ2CnebzqlH7sBJKvq9sqo2GQ3iEGW_XD9mtLsCESiaR-oJsYraa1M_Ch8HRSqZT7GGu0rpHlWVOSIprr4g.bnVsbA
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.env
15 changes: 13 additions & 2 deletions api/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package api

import (
"database/sql"
"errors"
"net/http"

db "github.com/bastiendmt/simplebank/db/sqlc"
"github.com/bastiendmt/simplebank/token"
"github.com/gin-gonic/gin"
"github.com/lib/pq"
)

type createAccountRequest struct {
Owner string `json:"owner" binding:"required"`
Currency string `json:"currency" binding:"required,currency"`
}

Expand All @@ -21,8 +22,9 @@ func (server *Server) createAccount(ctx *gin.Context) {
return
}

authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
arg := db.CreateAccountParams{
Owner: req.Owner,
Owner: authPayload.Username,
Currency: req.Currency,
Balance: 0,
}
Expand Down Expand Up @@ -64,6 +66,13 @@ func (server *Server) getAccount(ctx *gin.Context) {
return
}

authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
if account.Owner != authPayload.Username {
err := errors.New("account doesn't belong to the authenticated user")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}

ctx.JSON(http.StatusOK, account)
}

Expand All @@ -79,7 +88,9 @@ func (server *Server) listAccount(ctx *gin.Context) {
return
}

authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
arg := db.ListAccountsParams{
Owner: authPayload.Username,
Limit: req.PageSize,
Offset: (req.PageID - 1) * req.PageSize,
}
Expand Down
17 changes: 17 additions & 0 deletions api/account.http
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
### create account
POST http://localhost:8080/accounts
Content-Type: application/json
Authorization: Bearer {{$dotenv HTTP_ACCESS_TOKEN}}

{
"owner": "John Doe",
"currency": "EUR"
}

### get account
GET http://localhost:8080/accounts/120
Authorization: Bearer {{$dotenv HTTP_ACCESS_TOKEN}}

### list accounts
GET http://localhost:8080/accounts?page_id=1&page_size=5
Authorization: Bearer {{$dotenv HTTP_ACCESS_TOKEN}}
48 changes: 45 additions & 3 deletions api/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,33 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

mockdb "github.com/bastiendmt/simplebank/db/mock"
db "github.com/bastiendmt/simplebank/db/sqlc"
"github.com/bastiendmt/simplebank/token"
"github.com/bastiendmt/simplebank/util"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)

func TestGetAccountAPI(t *testing.T) {
account := randomAccount()
user, _ := randomUser(t)
account := randomAccount(user.Username)

testCases := []struct {
name string
accountId int64
setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker)
buildStubs func(store *mockdb.MockStore)
checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
name: "OK",
accountId: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, "", time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account.ID)).Times(1).Return(account, nil)
},
Expand All @@ -37,9 +44,37 @@ func TestGetAccountAPI(t *testing.T) {
requireBodyMatchAccount(t, recorder.Body, account)
},
},
{
name: "Unauthorized",
accountId: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "unauthorizedUser", "", time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account.ID)).Times(1).Return(account, nil)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "NoAuthorization",
accountId: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "NotFound",
accountId: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, "", time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account.ID)).Times(1).Return(db.Account{}, sql.ErrNoRows)
},
Expand All @@ -50,6 +85,9 @@ func TestGetAccountAPI(t *testing.T) {
{
name: "InternalError",
accountId: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, "", time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Eq(account.ID)).Times(1).Return(db.Account{}, sql.ErrConnDone)
},
Expand All @@ -60,6 +98,9 @@ func TestGetAccountAPI(t *testing.T) {
{
name: "InvalidID",
accountId: 0,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, "", time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().GetAccount(gomock.Any(), gomock.Any()).Times(0)
},
Expand Down Expand Up @@ -87,16 +128,17 @@ func TestGetAccountAPI(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, url, nil)
require.NoError(t, err)

tc.setupAuth(t, request, server.tokenMaker)
server.router.ServeHTTP(recorder, request)
tc.checkResponse(t, recorder)
})
}
}

func randomAccount() db.Account {
func randomAccount(owner string) db.Account {
return db.Account{
ID: util.RandomInt(1, 1000),
Owner: util.RandomOwner(),
Owner: owner,
Balance: util.RandomMoney(),
Currency: util.RandomCurrency(),
}
Expand Down
54 changes: 54 additions & 0 deletions api/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package api

import (
"errors"
"fmt"
"net/http"
"strings"

"github.com/bastiendmt/simplebank/token"
"github.com/gin-gonic/gin"
)

const (
authorizationHeaderKey = "authorization"
authorizationTypeBearer = "bearer"
authorizationPayloadKey = "access_payload"
)

// AuthMiddleware creates a gin middleware for authorization
func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc {
return func(ctx *gin.Context) {
authorizationHeader := ctx.GetHeader(authorizationHeaderKey)

if len(authorizationHeader) == 0 {
err := errors.New("authorization header is not provided")
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

fields := strings.Fields(authorizationHeader)
if len(fields) < 2 {
err := errors.New("invalid authorization header format")
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

authorizationType := strings.ToLower(fields[0])
if authorizationType != authorizationTypeBearer {
err := fmt.Errorf("unsupported authorization type %s", authorizationType)
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

accessToken := fields[1]
payload, err := tokenMaker.VerifyToken(accessToken)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

ctx.Set(authorizationPayloadKey, payload)
ctx.Next()
}
}
110 changes: 110 additions & 0 deletions api/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package api

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/bastiendmt/simplebank/token"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)

func addAuthorization(
t *testing.T,
request *http.Request,
tokenMaker token.Maker,
authorizationType string,
username string,
role string,
duration time.Duration,
) {
token, payload, err := tokenMaker.CreateToken(username, duration)
require.NoError(t, err)
require.NotEmpty(t, payload)

authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token)
request.Header.Set(authorizationHeaderKey, authorizationHeader)
}

func TestAuthMiddleware(t *testing.T) {
username := "username"
role := ""

testCases := []struct {
name string
setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker)
checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
name: "OK",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, username, role, time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusOK, recorder.Code)
},
},
{
name: "NoAuthorization",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "UnsupportedAuthorization",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, "unsupported", username, role, time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "InvalidAuthorizationFormat",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, "", username, role, time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "ExpiredToken",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, username, role, -time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
}

for i := range testCases {
tc := testCases[i]

t.Run(tc.name, func(t *testing.T) {
server := newTestServer(t, nil)
authPath := "/auth"
server.router.GET(
authPath,
authMiddleware(server.tokenMaker),
func(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{})
},
)

recorder := httptest.NewRecorder()
request, err := http.NewRequest(http.MethodGet, authPath, nil)
require.NoError(t, err)

tc.setupAuth(t, request, server.tokenMaker)
server.router.ServeHTTP(recorder, request)
tc.checkResponse(t, recorder)
})
}
}
10 changes: 6 additions & 4 deletions api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ func (server *Server) setupRouter() {
router.POST("/users", server.createUser)
router.POST("/users/login", server.loginUser)

router.POST("/accounts", server.createAccount)
router.GET("/accounts/:id", server.getAccount)
router.GET("/accounts", server.listAccount)
authRoutes := router.Group("/").Use(authMiddleware(server.tokenMaker))

router.POST("/transfers", server.createTransfer)
authRoutes.POST("/accounts", server.createAccount)
authRoutes.GET("/accounts/:id", server.getAccount)
authRoutes.GET("/accounts", server.listAccount)

authRoutes.POST("/transfers", server.createTransfer)

server.router = router
}
Expand Down
Loading

0 comments on commit 63c7336

Please sign in to comment.