Skip to content
This repository has been archived by the owner on Feb 10, 2025. It is now read-only.

Commit

Permalink
Automatic consumer recovery + sqs/sns tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Filip committed Mar 16, 2017
1 parent 1402f32 commit d0adb06
Show file tree
Hide file tree
Showing 9 changed files with 753 additions and 24 deletions.
5 changes: 5 additions & 0 deletions forwarder/forwarder.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package forwarder

const (
// EmptyMessageError empty error message
EmptyMessageError = "message is empty"
)

// Client interface to forwarding messages
type Client interface {
Name() string
Expand Down
20 changes: 13 additions & 7 deletions rabbitmq/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,31 +103,37 @@ func (c Consumer) Start(forwarder forwarder.Client, check chan bool, stop chan b
return failOnError(err, "Failed to register a consumer")
}
params := workerParams{forwarder, msgs, check, stop, conn, ch}
go c.push(params)
go c.startForwarding(&params)

return nil
}

func (c Consumer) push(params workerParams) {
func (c Consumer) startForwarding(params *workerParams) {
forwarderName := params.forwarder.Name()
defer params.ch.Close()
defer params.conn.Close()
log.Printf("[%s] Started forwarding messages to %s", c.Name(), forwarderName)
for {
select {
case d := <-params.msgs:
case d, ok := <-params.msgs:
if !ok { // channel already closed
go c.Start(params.forwarder, params.check, params.stop)
return
}
log.Printf("[%s] Message to forward: %v", c.Name(), d.MessageId)
err := params.forwarder.Push(string(d.Body))
if err != nil {
log.Printf("[%s] Could not forward message. Error: %s", forwarderName, err.Error())
} else {
d.Ack(true)
if err := d.Ack(true); err != nil {
log.Println("Could not ack message with id:", d.MessageId)
}
}
case <-params.check:
log.Printf("[%s] Checking", forwarderName)
case <-params.stop:
log.Printf("[%s] Closing", forwarderName)
params.ch.Close()
params.conn.Close()
return
break
}
}
}
Expand Down
23 changes: 14 additions & 9 deletions sns/forwader.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package sns

import (
"errors"
"log"

"github.com/AirHelp/rabbit-amazon-forwarder/config"
"github.com/AirHelp/rabbit-amazon-forwarder/forwarder"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sns"
"github.com/aws/aws-sdk-go/service/sns/snsiface"
)

const (
Expand All @@ -16,13 +18,18 @@ const (

type Forwarder struct {
name string
snsClient *sns.SNS
snsClient snsiface.SNSAPI
topic string
}

// CreateForwarder creates instance of forwarder
func CreateForwarder(entry config.AmazonEntry) forwarder.Client {
client := awsClient()
func CreateForwarder(entry config.AmazonEntry, snsClient ...snsiface.SNSAPI) forwarder.Client {
var client snsiface.SNSAPI
if len(snsClient) > 0 {
client = snsClient[0]
} else {
client = sns.New(session.Must(session.NewSession()))
}
forwarder := Forwarder{entry.Name, client, entry.Target}
log.Print("Created forwarder: ", forwarder.Name())
return forwarder
Expand All @@ -35,6 +42,9 @@ func (f Forwarder) Name() string {

// Push pushes message to forwarding infrastructure
func (f Forwarder) Push(message string) error {
if message == "" {
return errors.New(forwarder.EmptyMessageError)
}
params := &sns.PublishInput{
Message: aws.String(message),
TargetArn: aws.String(f.topic),
Expand All @@ -45,11 +55,6 @@ func (f Forwarder) Push(message string) error {
log.Printf("[%s] Could not forward message. Error: %s", f.Name(), err.Error())
return err
}
log.Printf("[%s] Forward succeeded. Response: %s", f.Name(), resp)
log.Printf("[%s] Forward succeeded. Response: %v", f.Name(), resp)
return nil
}

func awsClient() *sns.SNS {
sess := session.New()
return sns.New(sess)
}
79 changes: 79 additions & 0 deletions sns/forwarder_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
package sns

import (
"errors"
"testing"

"github.com/AirHelp/rabbit-amazon-forwarder/config"
"github.com/AirHelp/rabbit-amazon-forwarder/forwarder"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sns"
"github.com/aws/aws-sdk-go/service/sns/snsiface"
)

var badRequest = "Bad request"

func TestCreateForwarder(t *testing.T) {
entry := config.AmazonEntry{Type: "SNS",
Name: "sns-test",
Expand All @@ -16,3 +23,75 @@ func TestCreateForwarder(t *testing.T) {
t.Errorf("wrong forwarder name, expected:%s, found: %s", entry.Name, forwarder.Name())
}
}

func TestPush(t *testing.T) {
topicName := "topic1"
entry := config.AmazonEntry{Type: "SNS",
Name: "sns-test",
Target: topicName,
}
scenarios := []struct {
name string
mock snsiface.SNSAPI
message string
topic string
err error
}{
{
name: "empty message",
mock: mockAmazonSNS{resp: sns.PublishOutput{MessageId: aws.String("messageId")}, topic: topicName, message: ""},
message: "",
topic: topicName,
err: errors.New(forwarder.EmptyMessageError),
},
{
name: "bad request",
mock: mockAmazonSNS{resp: sns.PublishOutput{MessageId: aws.String("messageId")}, topic: topicName, message: badRequest},
message: badRequest,
topic: topicName,
err: errors.New(badRequest),
},
{
name: "success",
mock: mockAmazonSNS{resp: sns.PublishOutput{MessageId: aws.String("messageId")}, topic: topicName, message: "abc"},
message: "abc",
topic: topicName,
err: nil,
},
}
for _, scenario := range scenarios {
t.Log("Scenario name: ", scenario.name)
forwarder := CreateForwarder(entry, scenario.mock)
err := forwarder.Push(scenario.message)
if scenario.err == nil && err != nil {
t.Errorf("Error should not occur")
return
}
if scenario.err == err {
return
}
if err.Error() != scenario.err.Error() {
t.Errorf("Wrong error, expecting:%v, got:%v", scenario.err, err)
}
}
}

type mockAmazonSNS struct {
snsiface.SNSAPI
resp sns.PublishOutput
topic string
message string
}

func (m mockAmazonSNS) Publish(input *sns.PublishInput) (*sns.PublishOutput, error) {
if *input.TargetArn != m.topic {
return nil, errors.New("Wrong topic name")
}
if *input.Message != m.message {
return nil, errors.New("Wrong message body")
}
if *input.Message == badRequest {
return nil, errors.New(badRequest)
}
return &m.resp, nil
}
21 changes: 13 additions & 8 deletions sqs/forwader.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package sqs

import (
"errors"
"log"

"github.com/AirHelp/rabbit-amazon-forwarder/config"
"github.com/AirHelp/rabbit-amazon-forwarder/forwarder"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
)

const (
Expand All @@ -16,13 +18,18 @@ const (

type Forwarder struct {
name string
sqsClient *sqs.SQS
sqsClient sqsiface.SQSAPI
queue string
}

// CreateForwarder creates instance of forwarder
func CreateForwarder(entry config.AmazonEntry) forwarder.Client {
client := awsClient()
func CreateForwarder(entry config.AmazonEntry, sqsClient ...sqsiface.SQSAPI) forwarder.Client {
var client sqsiface.SQSAPI
if len(sqsClient) > 0 {
client = sqsClient[0]
} else {
client = sqs.New(session.Must(session.NewSession()))
}
forwarder := Forwarder{entry.Name, client, entry.Target}
log.Print("Created forwarder: ", forwarder.Name())
return forwarder
Expand All @@ -35,6 +42,9 @@ func (f Forwarder) Name() string {

// Push pushes message to forwarding infrastructure
func (f Forwarder) Push(message string) error {
if message == "" {
return errors.New(forwarder.EmptyMessageError)
}
params := &sqs.SendMessageInput{
MessageBody: aws.String(message), // Required
QueueUrl: aws.String(f.queue), // Required
Expand All @@ -49,8 +59,3 @@ func (f Forwarder) Push(message string) error {
log.Printf("[%s] Forward succeeded. Response: %s", f.Name(), resp)
return nil
}

func awsClient() *sqs.SQS {
sess := session.New()
return sqs.New(sess)
}
79 changes: 79 additions & 0 deletions sqs/forwarder_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
package sqs

import (
"errors"
"testing"

"github.com/AirHelp/rabbit-amazon-forwarder/config"
"github.com/AirHelp/rabbit-amazon-forwarder/forwarder"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
)

var badRequest = "Bad request"

func TestCreateForwarder(t *testing.T) {
entry := config.AmazonEntry{Type: "SQS",
Name: "sqs-test",
Expand All @@ -16,3 +23,75 @@ func TestCreateForwarder(t *testing.T) {
t.Errorf("wrong forwarder name, expected:%s, found: %s", entry.Name, forwarder.Name())
}
}

func TestPush(t *testing.T) {
queueName := "queue1"
entry := config.AmazonEntry{Type: "SQS",
Name: "sqs-test",
Target: queueName,
}
scenarios := []struct {
name string
mock sqsiface.SQSAPI
message string
queue string
err error
}{
{
name: "empty message",
mock: mockAmazonSQS{resp: sqs.SendMessageOutput{MessageId: aws.String("messageId")}, queue: queueName, message: ""},
message: "",
queue: queueName,
err: errors.New(forwarder.EmptyMessageError),
},
{
name: "bad request",
mock: mockAmazonSQS{resp: sqs.SendMessageOutput{MessageId: aws.String("messageId")}, queue: queueName, message: badRequest},
message: badRequest,
queue: queueName,
err: errors.New(badRequest),
},
{
name: "success",
mock: mockAmazonSQS{resp: sqs.SendMessageOutput{MessageId: aws.String("messageId")}, queue: queueName, message: "abc"},
message: "abc",
queue: queueName,
err: nil,
},
}
for _, scenario := range scenarios {
t.Log("Scenario name: ", scenario.name)
forwarder := CreateForwarder(entry, scenario.mock)
err := forwarder.Push(scenario.message)
if scenario.err == nil && err != nil {
t.Errorf("Error should not occur")
return
}
if scenario.err == err {
return
}
if err.Error() != scenario.err.Error() {
t.Errorf("Wrong error, expecting:%v, got:%v", scenario.err, err)
}
}
}

type mockAmazonSQS struct {
sqsiface.SQSAPI
resp sqs.SendMessageOutput
queue string
message string
}

func (m mockAmazonSQS) SendMessage(input *sqs.SendMessageInput) (*sqs.SendMessageOutput, error) {
if *input.QueueUrl != m.queue {
return nil, errors.New("Wrong queue name")
}
if *input.MessageBody != m.message {
return nil, errors.New("Wrong message body")
}
if *input.MessageBody == badRequest {
return nil, errors.New(badRequest)
}
return &m.resp, nil
}
Loading

0 comments on commit d0adb06

Please sign in to comment.