Skip to content

Commit

Permalink
refactor: separate validator from rest of server code (#806)
Browse files Browse the repository at this point in the history
Co-authored-by: Zulkhair Abdullah Daim <[email protected]>
Co-authored-by: Ryan Martin <[email protected]>
Co-authored-by: Ryan Martin <[email protected]>
  • Loading branch information
4 people authored Nov 12, 2024
1 parent e17bab7 commit 25aee78
Show file tree
Hide file tree
Showing 10 changed files with 678 additions and 158 deletions.
11 changes: 10 additions & 1 deletion cardinal/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,21 @@ func WithDisableSignatureVerification() WorldOption {
// This setting is ignored if the DisableSignatureVerification option is used
// NOTE: this means that the real time clock for the sender and receiver
// must be synchronized
func WithMessageExpiration(seconds int) WorldOption {
func WithMessageExpiration(seconds uint) WorldOption {
return WorldOption{
serverOption: server.WithMessageExpiration(seconds),
}
}

// WithHashCacheSize how big the cache of hashes used for replay protection
// is allowed to be. Default is 1MB.
// This setting is ignored if the DisableSignatureVerification option is used
func WithHashCacheSize(sizeKB uint) WorldOption {
return WorldOption{
serverOption: server.WithHashCacheSize(sizeKB),
}
}

// WithTickChannel sets the channel that will be used to decide when world.doTick is executed. If unset, a loop interval
// of 1 second will be set. To set some other time, use: WithTickChannel(time.Tick(<some-duration>)). Tests can pass
// in a channel controlled by the test for fine-grained control over when ticks are executed.
Expand Down
177 changes: 45 additions & 132 deletions cardinal/server/handler/tx.go
Original file line number Diff line number Diff line change
@@ -1,49 +1,23 @@
package handler

import (
"errors"
"fmt"
"time"

"github.com/coocood/freecache"
"github.com/ethereum/go-ethereum/common"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/rotisserie/eris"

personaMsg "pkg.world.dev/world-engine/cardinal/persona/msg"
servertypes "pkg.world.dev/world-engine/cardinal/server/types"
"pkg.world.dev/world-engine/cardinal/server/validator"
"pkg.world.dev/world-engine/cardinal/types"
"pkg.world.dev/world-engine/sign"
)

const cacheRetentionExtraSeconds = 10 // this is how many seconds past normal expiration a hash is left in the cache.
// we want to ensure it's long enough that any message that's not expired but
// still has its hash in the cache for replay protection. Setting it too long
// would cause the cache to be bigger than necessary

var (
ErrNoPersonaTag = errors.New("persona tag is required")
ErrWrongNamespace = errors.New("incorrect namespace")
ErrSystemTransactionRequired = errors.New("system transaction required")
ErrSystemTransactionForbidden = errors.New("system transaction forbidden")
)

// PostTransactionResponse is the HTTP response for a successful transaction submission
type PostTransactionResponse struct {
TxHash string
Tick uint64
}

type SignatureVerification struct {
IsDisabled bool
MessageExpirationSeconds int
HashCacheSizeKB int
Cache *freecache.Cache
}

type Transaction = sign.Transaction

// PostTransaction godoc
//
// @Summary Submits a transaction
Expand All @@ -52,14 +26,14 @@ type Transaction = sign.Transaction
// @Produce application/json
// @Param txGroup path string true "Message group"
// @Param txName path string true "Name of a registered message"
// @Param txBody body Transaction true "Transaction details & message to be submitted"
// @Param txBody body sign.Transaction true "Transaction details & message to be submitted"
// @Success 200 {object} PostTransactionResponse "Transaction hash and tick"
// @Failure 400 {string} string "Invalid request parameter"
// @Failure 403 {string} string "Forbidden"
// @Failure 408 {string} string "Request Timeout - message expired"
// @Router /tx/{txGroup}/{txName} [post]
func PostTransaction(
world servertypes.ProviderWorld, msgs map[string]map[string]types.Message, verify SignatureVerification,
world servertypes.ProviderWorld, msgs map[string]map[string]types.Message, validator *validator.SignatureValidator,
) func(*fiber.Ctx) error {
return func(ctx *fiber.Ctx) error {
msgType, ok := msgs[ctx.Params("group")][ctx.Params("name")]
Expand All @@ -69,15 +43,14 @@ func PostTransaction(
}

// extract the transaction from the fiber context
tx, fiberErr := extractTx(ctx, verify)
if fiberErr != nil {
return fiberErr
tx, err := extractTx(ctx, validator)
if err != nil {
return err
}

// Validate the transaction
if err := validateTx(tx); err != nil {
log.Errorf("message %s has invalid transaction payload: %v", tx.Hash.String(), err)
return fiber.NewError(fiber.StatusBadRequest, "Bad Request - invalid payload")
// make sure the transaction hasn't expired
if err = validator.ValidateTransactionTTL(tx); err != nil {
return httpResultFromError(err, false)
}

// Decode the message from the transaction
Expand All @@ -87,31 +60,19 @@ func PostTransaction(
return fiber.NewError(fiber.StatusBadRequest, "Bad Request - failed to decode tx message")
}

// check the signature
if !verify.IsDisabled {
var signerAddress string
if msgType.Name() == personaMsg.CreatePersonaMessageName {
// don't need to check the cast bc we already validated this above
createPersonaMsg, _ := msg.(personaMsg.CreatePersona)
signerAddress = createPersonaMsg.SignerAddress
}

if err = lookupSignerAndValidateSignature(world, signerAddress, tx); err != nil {
log.Errorf("Signature validation failed for message %s: %v", tx.Hash.String(), err)
return fiber.NewError(fiber.StatusUnauthorized, "Unauthorized - invalid signature")
// there's a special case for the CreatePersona message
var signerAddress string
if msgType.Name() == personaMsg.CreatePersonaMessageName {
createPersonaMsg, ok := msg.(personaMsg.CreatePersona)
if !ok {
return fiber.NewError(fiber.StatusInternalServerError, "Internal Server Error - bad message type")
}
signerAddress = createPersonaMsg.SignerAddress
}

// the message was valid, so add its hash to the cache
// we don't do this until we have verified the signature to prevent an attack where someone sends
// large numbers of hashes with unsigned/invalid messages and thus blocks legit messages from
// being handled
err = verify.Cache.Set(tx.Hash.Bytes(), nil, verify.MessageExpirationSeconds+cacheRetentionExtraSeconds)
if err != nil {
// if we couldn't store the hash in the cache, don't process the transaction, since that
// would open us up to replay attacks
log.Errorf("unexpected cache store error %v. message %s ignored", err, tx.Hash.String())
return fiber.NewError(fiber.StatusInternalServerError, "Internal Server Error - cache store")
}
// Validate the transaction's signature
if err = validator.ValidateTransactionSignature(tx, signerAddress); err != nil {
return httpResultFromError(err, true)
}

// Add the transaction to the engine
Expand All @@ -133,16 +94,16 @@ func PostTransaction(
// @Accept application/json
// @Produce application/json
// @Param txName path string true "Name of a registered message"
// @Param txBody body Transaction true "Transaction details & message to be submitted"
// @Param txBody body sign.Transaction true "Transaction details & message to be submitted"
// @Success 200 {object} PostTransactionResponse "Transaction hash and tick"
// @Failure 400 {string} string "Invalid request parameter"
// @Failure 403 {string} string "Forbidden"
// @Failure 408 {string} string "Request Timeout - message expired"
// @Router /tx/game/{txName} [post]
func PostGameTransaction(
world servertypes.ProviderWorld, msgs map[string]map[string]types.Message, verify SignatureVerification,
world servertypes.ProviderWorld, msgs map[string]map[string]types.Message, validator *validator.SignatureValidator,
) func(*fiber.Ctx) error {
return PostTransaction(world, msgs, verify)
return PostTransaction(world, msgs, validator)
}

// NOTE: duplication for cleaner swagger docs
Expand All @@ -152,7 +113,7 @@ func PostGameTransaction(
// @Description Creates a persona
// @Accept application/json
// @Produce application/json
// @Param txBody body Transaction true "Transaction details & message to be submitted"
// @Param txBody body sign.Transaction true "Transaction details & message to be submitted"
// @Success 200 {object} PostTransactionResponse "Transaction hash and tick"
// @Failure 400 {string} string "Invalid request parameter"
// @Failure 401 {string} string "Unauthorized - signature was invalid"
Expand All @@ -161,31 +122,17 @@ func PostGameTransaction(
// @Failure 500 {string} string "Internal Server Error - unexpected cache errors"
// @Router /tx/persona/create-persona [post]
func PostPersonaTransaction(
world servertypes.ProviderWorld, msgs map[string]map[string]types.Message, verify SignatureVerification,
world servertypes.ProviderWorld, msgs map[string]map[string]types.Message, validator *validator.SignatureValidator,
) func(*fiber.Ctx) error {
return PostTransaction(world, msgs, verify)
}

func isHashInCache(hash common.Hash, cache *freecache.Cache) (bool, error) {
_, err := cache.Get(hash.Bytes())
if err == nil {
// found it
return true, nil
}
if errors.Is(err, freecache.ErrNotFound) {
// ignore ErrNotFound, just return false
return false, nil
}
// return all other errors
return false, err
return PostTransaction(world, msgs, validator)
}

func extractTx(ctx *fiber.Ctx, verify SignatureVerification) (*sign.Transaction, *fiber.Error) {
func extractTx(ctx *fiber.Ctx, validator *validator.SignatureValidator) (*sign.Transaction, error) {
var tx *sign.Transaction
var err error
// Parse the request body into a sign.Transaction struct tx := new(Transaction)
// this also calculates the hash
if !verify.IsDisabled {
if validator != nil && !validator.IsDisabled {
// we are doing signature verification, so use sign's Unmarshal which does extra checks
tx, err = sign.UnmarshalTransaction(ctx.Body())
} else {
Expand All @@ -195,65 +142,31 @@ func extractTx(ctx *fiber.Ctx, verify SignatureVerification) (*sign.Transaction,
}
if err != nil {
log.Errorf("body parse failed: %v", err)
return nil, fiber.NewError(fiber.StatusBadRequest, "Bad Request - unparseable body")
}
if !verify.IsDisabled {
txEarliestValidTimestamp := sign.TimestampAt(
time.Now().Add(-(time.Duration(verify.MessageExpirationSeconds) * time.Second)))
// before we even create the hash or validate the signature, check to see if the message has expired
if tx.Timestamp < txEarliestValidTimestamp {
log.Errorf("message older than %d seconds. Got timestamp: %d, current timestamp: %d ",
verify.MessageExpirationSeconds, tx.Timestamp, sign.TimestampNow())
return nil, fiber.NewError(fiber.StatusRequestTimeout, "Request Timeout - signature too old")
}
// check for duplicate message via hash cache
if found, err := isHashInCache(tx.Hash, verify.Cache); err != nil {
log.Errorf("unexpected cache error %v. message %s ignored", err, tx.Hash.String())
return nil, fiber.NewError(fiber.StatusInternalServerError, "Internal Server Error - cache failed")
} else if found {
// if found in the cache, the message hash has already been used, so reject it
log.Errorf("message %s already handled", tx.Hash.String())
return nil, fiber.NewError(fiber.StatusForbidden, "Forbidden - duplicate message")
}
// at this point we know that the generated hash is not in the cache, so this message is not a replay
return nil, eris.Wrap(err, "Bad Request - unparseable body")
}
return tx, nil
}

func lookupSignerAndValidateSignature(world servertypes.ProviderWorld, signerAddress string, tx *Transaction) error {
var err error
if signerAddress == "" {
signerAddress, err = world.GetSignerForPersonaTag(tx.PersonaTag, 0)
if err != nil {
return fmt.Errorf("could not get signer for persona %s: %w", tx.PersonaTag, err)
}
// turns the various errors into an appropriate HTTP result
func httpResultFromError(err error, isSignatureValidation bool) error {
log.Error(err) // log the private internal details
if eris.Is(err, validator.ErrDuplicateMessage) {
return fiber.NewError(fiber.StatusForbidden, "Forbidden - duplicate message")
}
if err = validateSignature(tx, signerAddress, world.Namespace(),
tx.IsSystemTransaction()); err != nil {
return fmt.Errorf("could not validate signature for persona %s: %w", tx.PersonaTag, err)
if eris.Is(err, validator.ErrMessageExpired) {
return fiber.NewError(fiber.StatusRequestTimeout, "Request Timeout - message expired")
}
return nil
}

// validateTx validates the transaction payload
func validateTx(tx *Transaction) error {
// TODO(scott): we should use the validator package here
if tx.PersonaTag == "" {
return ErrNoPersonaTag
if eris.Is(err, validator.ErrBadTimestamp) {
return fiber.NewError(fiber.StatusBadRequest, "Bad Request - bad timestamp")
}
return nil
}

// validateSignature validates that the signature of transaction is valid
func validateSignature(tx *Transaction, signerAddr string, namespace string, systemTx bool) error {
if tx.Namespace != namespace {
return eris.Wrap(ErrWrongNamespace, fmt.Sprintf("expected %q got %q", namespace, tx.Namespace))
if eris.Is(err, validator.ErrNoPersonaTag) {
return fiber.NewError(fiber.StatusBadRequest, "Bad Request - no persona tag")
}
if systemTx && !tx.IsSystemTransaction() {
return eris.Wrap(ErrSystemTransactionRequired, "")
if eris.Is(err, validator.ErrInvalidSignature) {
return fiber.NewError(fiber.StatusUnauthorized, "Unauthorized - signature validation failed")
}
if !systemTx && tx.IsSystemTransaction() {
return eris.Wrap(ErrSystemTransactionForbidden, "")
if isSignatureValidation {
return fiber.NewError(fiber.StatusInternalServerError, "Internal Server Error - signature validation failed")
}
return eris.Wrap(tx.Verify(signerAddr), "")
return fiber.NewError(fiber.StatusInternalServerError, "Internal Server Error - ttl validation failed")
}
10 changes: 5 additions & 5 deletions cardinal/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func DisableSwagger() Option {
// DisableSignatureVerification disables signature verification.
func DisableSignatureVerification() Option {
return func(s *Server) {
s.verify.IsDisabled = true
s.config.isSignatureValidationDisabled = true
}
}

Expand All @@ -31,17 +31,17 @@ func DisableSignatureVerification() Option {
// This setting is ignored if the DisableSignatureVerification option is used
// NOTE: this means that the real time clock for the sender and receiver
// must be synchronized
func WithMessageExpiration(seconds int) Option {
func WithMessageExpiration(seconds uint) Option {
return func(s *Server) {
s.verify.MessageExpirationSeconds = seconds
s.config.messageExpirationSeconds = seconds
}
}

// WithHashCacheSize how big the cache of hashes used for replay protection
// is allowed to be. Default is 1MB.
// This setting is ignored if the DisableSignatureVerification option is used
func WithHashCacheSize(sizeKB int) Option {
func WithHashCacheSize(sizeKB uint) Option {
return func(s *Server) {
s.verify.HashCacheSizeKB = sizeKB
s.config.messageHashCacheSizeKB = sizeKB
}
}
Loading

0 comments on commit 25aee78

Please sign in to comment.