Skip to content
This repository has been archived by the owner on Jul 12, 2023. It is now read-only.

Commit

Permalink
Fix rate limiting (#221)
Browse files Browse the repository at this point in the history
* Fix rate limiting

- Download go modules in dedicated step for faster local builds
- Update go-limiter to fix a redis TTL issue
- Update adminapi server config to use shared ratelimit config
- Rate limit by X-Forwarded-For, ensure limiting is applied in correct order

* Tidy mod
  • Loading branch information
sethvargo committed Aug 12, 2020
1 parent 6a8b2c6 commit 3e74671
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 53 deletions.
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ ENV GOOS=linux
ENV GOARCH=amd64

WORKDIR /src
COPY . .
COPY go.mod .
COPY go.sum .
RUN go mod download

COPY . .
RUN go build \
-trimpath \
-ldflags "-s -w -extldflags '-static'" \
Expand Down
51 changes: 33 additions & 18 deletions cmd/adminapi/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ import (
"fmt"
"net/http"
"os"
"time"
"strings"

"github.com/google/exposure-notifications-verification-server/pkg/config"
"github.com/google/exposure-notifications-verification-server/pkg/controller"
"github.com/google/exposure-notifications-verification-server/pkg/controller/issueapi"
"github.com/google/exposure-notifications-verification-server/pkg/controller/middleware"
"github.com/google/exposure-notifications-verification-server/pkg/database"
"github.com/google/exposure-notifications-verification-server/pkg/ratelimit"
"github.com/google/exposure-notifications-verification-server/pkg/render"

"github.com/google/exposure-notifications-server/pkg/cache"
Expand All @@ -39,7 +40,6 @@ import (
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"github.com/sethvargo/go-limiter/httplimit"
"github.com/sethvargo/go-limiter/memorystore"
"github.com/sethvargo/go-signalcontext"
)

Expand Down Expand Up @@ -89,21 +89,18 @@ func realMain(ctx context.Context) error {
// Create the router
r := mux.NewRouter()

// Setup rate limiter
store, err := memorystore.New(&memorystore.Config{
Tokens: config.RateLimit,
Interval: 1 * time.Minute,
})
// Rate limiting
limiterStore, err := ratelimit.RateLimiterFor(ctx, &config.RateLimit)
if err != nil {
return fmt.Errorf("failed to create limiter: %w", err)
}
defer store.Close()
defer limiterStore.Close()

httplimiter, err := httplimit.NewMiddleware(store, apiKeyFunc())
httplimiter, err := httplimit.NewMiddleware(limiterStore, limiterFunc(ctx))
if err != nil {
return fmt.Errorf("failed to create limiter middleware: %w", err)
}
r.Use(httplimiter.Handle)
rateLimit := httplimiter.Handle

// Create the renderer
h, err := render.New(ctx, "", config.DevMode)
Expand All @@ -124,6 +121,7 @@ func realMain(ctx context.Context) error {

// Install the APIKey Auth Middleware
r.Use(requireAPIKey)
r.Use(rateLimit)

issueapiController := issueapi.New(ctx, config, db, h)
r.Handle("/api/issue", issueapiController.HandleIssue()).Methods("POST")
Expand All @@ -136,17 +134,34 @@ func realMain(ctx context.Context) error {
return srv.ServeHTTPHandler(ctx, handlers.CombinedLoggingHandler(os.Stdout, r))
}

func apiKeyFunc() httplimit.KeyFunc {
ipKeyFunc := httplimit.IPKeyFunc("X-Forwarded-For")
// limiterFunc is a custom rate limiter function. It limits by realm (by API
// key, if one exists, then by IP.
func limiterFunc(ctx context.Context) httplimit.KeyFunc {
logger := logging.FromContext(ctx).Named("ratelimit")

return func(r *http.Request) (string, error) {
v := r.Header.Get("X-API-Key")
if v != "" {
dig := sha1.Sum([]byte(v))
return fmt.Sprintf("%x", dig), nil
ctx := r.Context()

// See if a user exists on the context
authApp := controller.AuthorizedAppFromContext(ctx)
if authApp != nil && authApp.RealmID != 0 {
logger.Debugw("limiting by authApp realm", "authApp", authApp.ID)
dig := sha1.Sum([]byte(fmt.Sprintf("%d", authApp.RealmID)))
return fmt.Sprintf("adminapi:realm:%x", dig), nil
}

// Get the remote addr
ip := r.RemoteAddr

// Check if x-forwarded-for exists, the load balancer sets this, and the
// first entry is the real client IP
xff := r.Header.Get("x-forwarded-for")
if xff != "" {
ip = strings.Split(xff, ",")[0]
}

// If no API key was provided, default to limiting by IP.
return ipKeyFunc(r)
logger.Debugw("limiting by ip", "ip", ip)
dig := sha1.Sum([]byte(ip))
return fmt.Sprintf("adminapi:ip:%x", dig), nil
}
}
47 changes: 35 additions & 12 deletions cmd/apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ package main

import (
"context"
"crypto/sha1"
"fmt"
"net/http"
"os"
"strconv"
"strings"

"github.com/google/exposure-notifications-verification-server/pkg/config"
"github.com/google/exposure-notifications-verification-server/pkg/controller"
Expand Down Expand Up @@ -97,18 +98,18 @@ func realMain(ctx context.Context) error {
// Create the router
r := mux.NewRouter()

// Setup rate limiter
store, err := ratelimit.RateLimiterFor(ctx, &config.RateLimit)
// Rate limiting
limiterStore, err := ratelimit.RateLimiterFor(ctx, &config.RateLimit)
if err != nil {
return fmt.Errorf("failed to create limiter: %w", err)
}
defer store.Close()
defer limiterStore.Close()

httplimiter, err := httplimit.NewMiddleware(store, apiKeyFunc(db))
httplimiter, err := httplimit.NewMiddleware(limiterStore, limiterFunc(ctx, db))
if err != nil {
return fmt.Errorf("failed to create limiter middleware: %w", err)
}
r.Use(httplimiter.Handle)
rateLimit := httplimiter.Handle

// Create the renderer
h, err := render.New(ctx, "", config.DevMode)
Expand All @@ -127,6 +128,10 @@ func realMain(ctx context.Context) error {
database.APIUserTypeDevice,
})

// Install the rate limiting first. In this case, we want to limit by key
// first to reduce the chance of a database lookup.
r.Use(rateLimit)

// Install the APIKey Auth Middleware
r.Use(requireAPIKey)

Expand Down Expand Up @@ -155,27 +160,45 @@ func realMain(ctx context.Context) error {
return srv.ServeHTTPHandler(ctx, handlers.CombinedLoggingHandler(os.Stdout, r))
}

func apiKeyFunc(db *database.Database) httplimit.KeyFunc {
ipKeyFunc := httplimit.IPKeyFunc("X-Forwarded-For")
// limiterFunc is a custom rate limiter function. It limits by API key realm, if
// one exists, then by IP.
func limiterFunc(ctx context.Context, db *database.Database) httplimit.KeyFunc {
logger := logging.FromContext(ctx).Named("ratelimit")

return func(r *http.Request) (string, error) {
// Procss the API key
v := r.Header.Get("X-API-Key")
if v != "" {
// v2 API keys encode the realm
_, realmID, err := db.VerifyAPIKeySignature(v)
if err == nil {
return strconv.FormatUint(uint64(realmID), 10), nil
logger.Debugw("limiting by api key v2 realm", "realm", realmID)
dig := sha1.Sum([]byte(fmt.Sprintf("%d", realmID)))
return fmt.Sprintf("apiserver:realm:%x", dig), nil
}

// v1 API keys do not, fallback to the database
app, err := db.FindAuthorizedAppByAPIKey(v)
if err == nil && app != nil {
return strconv.FormatUint(uint64(app.RealmID), 10), nil
logger.Debugw("limiting by api key v1 realm", "realm", app.RealmID)
dig := sha1.Sum([]byte(fmt.Sprintf("%d", app.RealmID)))
return fmt.Sprintf("apiserver:realm:%x", dig), nil
}
}

// If no API key was provided, default to limiting by IP.
return ipKeyFunc(r)
// Get the remote addr
ip := r.RemoteAddr

// Check if x-forwarded-for exists, the load balancer sets this, and the
// first entry is the real client IP
xff := r.Header.Get("x-forwarded-for")
if xff != "" {
ip = strings.Split(xff, ",")[0]
}

logger.Debugw("limiting by ip", "ip", ip)
dig := sha1.Sum([]byte(ip))
return fmt.Sprintf("apiserver:ip:%x", dig), nil
}
}

Expand Down
64 changes: 46 additions & 18 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"net/http"
"os"
"strings"

"github.com/google/exposure-notifications-verification-server/pkg/config"
"github.com/google/exposure-notifications-verification-server/pkg/controller"
Expand Down Expand Up @@ -122,18 +123,17 @@ func realMain(ctx context.Context) error {
return fmt.Errorf("failed to create renderer: %w", err)
}

// Setup rate limiting
store, err := ratelimit.RateLimiterFor(ctx, &config.RateLimit)
// Rate limiting
limiterStore, err := ratelimit.RateLimiterFor(ctx, &config.RateLimit)
if err != nil {
return fmt.Errorf("failed to create limiter: %w", err)
}
defer store.Close()
defer limiterStore.Close()

httplimiter, err := httplimit.NewMiddleware(store, userEmailKeyFunc())
httplimiter, err := httplimit.NewMiddleware(limiterStore, limiterFunc(ctx))
if err != nil {
return fmt.Errorf("failed to create limiter middleware: %w", err)
}
r.Use(httplimiter.Handle)

// Install the CSRF protection middleware.
configureCSRF := middleware.ConfigureCSRF(ctx, config, h)
Expand All @@ -147,20 +147,26 @@ func realMain(ctx context.Context) error {
requireAuth := middleware.RequireAuth(ctx, auth, db, h, config.SessionDuration)
requireAdmin := middleware.RequireRealmAdmin(ctx, h)
requireRealm := middleware.RequireRealm(ctx, db, h)
rateLimit := httplimiter.Handle

// Install the handlers that don't require authentication first on the main router.
indexController := index.New(ctx, config, h)
r.Handle("/", indexController.HandleIndex()).Methods("GET")
r.Handle("/healthz", controller.HandleHealthz(ctx, h, &config.Database)).Methods("GET")
{
sub := r.PathPrefix("").Subrouter()
sub.Use(rateLimit)

indexController := index.New(ctx, config, h)
sub.Handle("/", indexController.HandleIndex()).Methods("GET")
sub.Handle("/healthz", controller.HandleHealthz(ctx, h, &config.Database)).Methods("GET")

// Session handling
sessionController := session.New(ctx, auth, config, db, h)
r.Handle("/signout", sessionController.HandleDelete()).Methods("GET")
r.Handle("/session", sessionController.HandleCreate()).Methods("POST")
// Session handling
sessionController := session.New(ctx, auth, config, db, h)
sub.Handle("/signout", sessionController.HandleDelete()).Methods("GET")
sub.Handle("/session", sessionController.HandleCreate()).Methods("POST")
}

{
sub := r.PathPrefix("/realm").Subrouter()
sub.Use(requireAuth)
sub.Use(rateLimit)

// Realms - list and select.
realmController := realm.New(ctx, config, db, h)
Expand All @@ -172,6 +178,7 @@ func realMain(ctx context.Context) error {
sub := r.PathPrefix("/home").Subrouter()
sub.Use(requireAuth)
sub.Use(requireRealm)
sub.Use(rateLimit)

homeController := home.New(ctx, config, db, h)
sub.Handle("", homeController.HandleHome()).Methods("GET")
Expand All @@ -187,6 +194,7 @@ func realMain(ctx context.Context) error {
sub.Use(requireAuth)
sub.Use(requireRealm)
sub.Use(requireAdmin)
sub.Use(rateLimit)

apikeyController := apikey.New(ctx, config, db, h)
sub.Handle("", apikeyController.HandleIndex()).Methods("GET")
Expand All @@ -205,6 +213,7 @@ func realMain(ctx context.Context) error {
userSub.Use(requireAuth)
userSub.Use(requireRealm)
userSub.Use(requireAdmin)
userSub.Use(rateLimit)

userController := user.New(ctx, config, db, h)
userSub.Handle("", userController.HandleIndex()).Methods("GET")
Expand All @@ -218,6 +227,7 @@ func realMain(ctx context.Context) error {
realmSub.Use(requireAuth)
realmSub.Use(requireRealm)
realmSub.Use(requireAdmin)
realmSub.Use(rateLimit)

realmadminController := realmadmin.New(ctx, config, db, h)
realmSub.Handle("", realmadminController.HandleIndex()).Methods("GET")
Expand All @@ -238,16 +248,34 @@ func realMain(ctx context.Context) error {
return srv.ServeHTTPHandler(ctx, handlers.CombinedLoggingHandler(os.Stdout, mux))
}

func userEmailKeyFunc() httplimit.KeyFunc {
ipKeyFunc := httplimit.IPKeyFunc("X-Forwarded-For")
// limiterFunc is a custom rate limiter function. It limits by user, if one
// exists, then by IP.
func limiterFunc(ctx context.Context) httplimit.KeyFunc {
logger := logging.FromContext(ctx).Named("ratelimit")

return func(r *http.Request) (string, error) {
user := controller.UserFromContext(r.Context())
ctx := r.Context()

// See if a user exists on the context
user := controller.UserFromContext(ctx)
if user != nil && user.Email != "" {
logger.Debugw("limiting by user", "user", user.ID)
dig := sha1.Sum([]byte(user.Email))
return fmt.Sprintf("%x", dig), nil
return fmt.Sprintf("server:user:%x", dig), nil
}

// Get the remote addr
ip := r.RemoteAddr

// Check if x-forwarded-for exists, the load balancer sets this, and the
// first entry is the real client IP
xff := r.Header.Get("x-forwarded-for")
if xff != "" {
ip = strings.Split(xff, ",")[0]
}

return ipKeyFunc(r)
logger.Debugw("limiting by ip", "ip", ip)
dig := sha1.Sum([]byte(ip))
return fmt.Sprintf("server:ip:%x", dig), nil
}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ require (
github.com/prometheus/statsd_exporter v0.17.0 // indirect
github.com/sethvargo/go-envconfig v0.3.0
github.com/sethvargo/go-gcpkms v0.1.0
github.com/sethvargo/go-limiter v0.3.0
github.com/sethvargo/go-limiter v0.3.1
github.com/sethvargo/go-retry v0.1.0
github.com/sethvargo/go-signalcontext v0.1.0
github.com/stretchr/objx v0.3.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1198,8 +1198,8 @@ github.com/sethvargo/go-envconfig v0.3.0 h1:9xW3N/jvX6TkJzY99pW4WPq8tMYQElwWZinf
github.com/sethvargo/go-envconfig v0.3.0/go.mod h1:XZ2JRR7vhlBEO5zMmOpLgUhgYltqYqq4d4tKagtPUv0=
github.com/sethvargo/go-gcpkms v0.1.0 h1:pyjDLqLwpk9pMjDSTilPpaUjgP1AfSjX9WGzitZwGUY=
github.com/sethvargo/go-gcpkms v0.1.0/go.mod h1:33BuvqUjsYk0bpMgn+WCclCYtMLOyaqtn5j0fCo4vvk=
github.com/sethvargo/go-limiter v0.3.0 h1:yRMc+Qs2yqw6YJp6UxrO2iUs6DOSq4zcnljbB7/rMns=
github.com/sethvargo/go-limiter v0.3.0/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU=
github.com/sethvargo/go-limiter v0.3.1 h1:/FFoChDmuu+bN9DCs4k5SpW+edBcM/eIELpgNftbI4E=
github.com/sethvargo/go-limiter v0.3.1/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU=
github.com/sethvargo/go-retry v0.1.0 h1:8sPqlWannzcReEcYjHSNw9becsiYudcwTD7CasGjQaI=
github.com/sethvargo/go-retry v0.1.0/go.mod h1:JzIOdZqQDNpPkQDmcqgtteAcxFLtYpNF/zJCM1ysDg8=
github.com/sethvargo/go-signalcontext v0.1.0 h1:3IU7HOlmRXF0PSDf85C4nJ/zjYDjF+DS+LufcKfLvyk=
Expand Down
5 changes: 4 additions & 1 deletion pkg/config/admin_server_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"time"

"github.com/google/exposure-notifications-verification-server/pkg/database"
"github.com/google/exposure-notifications-verification-server/pkg/ratelimit"

"github.com/google/exposure-notifications-server/pkg/observability"

Expand All @@ -36,8 +37,10 @@ type AdminAPIServerConfig struct {
// production environments.
DevMode bool `env:"DEV_MODE"`

// Rate limiting configuration
RateLimit ratelimit.Config

Port string `env:"PORT,default=8080"`
RateLimit uint64 `env:"RATE_LIMIT,default=60"`
APIKeyCacheDuration time.Duration `env:"API_KEY_CACHE_DURATION,default=5m"`

CodeDuration time.Duration `env:"CODE_DURATION,default=1h"`
Expand Down

0 comments on commit 3e74671

Please sign in to comment.