diff --git a/consumer/consumer.go b/consumer/consumer.go index 940dea7..c4cc22c 100644 --- a/consumer/consumer.go +++ b/consumer/consumer.go @@ -5,5 +5,5 @@ import "github.com/AirHelp/rabbit-amazon-forwarder/forwarder" // Client intarface for consuming messages type Client interface { Name() string - Consume(forwarder.Client) error + Start(forwarder.Client, chan bool, chan bool) error } diff --git a/healthcheck/health.go b/healthcheck/health.go deleted file mode 100644 index 81f4576..0000000 --- a/healthcheck/health.go +++ /dev/null @@ -1,13 +0,0 @@ -package health - -import "net/http" - -const ( - success = "success" -) - -// Check verifies if application is working properly -func Check(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - w.Write([]byte(success)) -} diff --git a/healthcheck/health_test.go b/healthcheck/health_test.go deleted file mode 100644 index 5b94725..0000000 --- a/healthcheck/health_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package health - -import ( - "net/http" - "net/http/httptest" - "testing" -) - -func TestHealthCheck(t *testing.T) { - req, err := http.NewRequest("GET", "/health", nil) - if err != nil { - t.Fatal(err) - } - - rr := httptest.NewRecorder() - handler := http.HandlerFunc(Check) - - handler.ServeHTTP(rr, req) - - if status := rr.Code; status != http.StatusOK { - t.Errorf("handler returned wrong status code: got %v want %v", - status, http.StatusOK) - } - - if body := rr.Body.String(); body != success { - t.Errorf("handler returned wrong message body: got %v want %v", - body, success) - } -} diff --git a/mapping/mapping.go b/mapping/mapping.go index 7b3da33..4afdea2 100644 --- a/mapping/mapping.go +++ b/mapping/mapping.go @@ -44,26 +44,24 @@ func New(helpers ...Helper) Client { return Client{helper} } -// LoadAndStart loads and starts mappings -func (c Client) LoadAndStart() error { +// Load loads mappings +func (c Client) Load() (map[consumer.Client]forwarder.Client, error) { + consumerForwaderMap := make(map[consumer.Client]forwarder.Client) data, err := c.loadFile() if err != nil { - return err + return consumerForwaderMap, err } var pairsList pairs if err = json.Unmarshal(data, &pairsList); err != nil { - return err + return consumerForwaderMap, err } - log.Print("Starting consumer->forwader pairs") + log.Print("Loading consumer->forwader pairs") for _, pair := range pairsList { consumer := c.helper.createConsumer(pair.Source) forwarder := c.helper.createForwarder(pair.Destination) - log.Printf("Starting consumer:%s with forwader:%s", consumer.Name(), forwarder.Name()) - if err := consumer.Consume(forwarder); err != nil { - return err - } + consumerForwaderMap[consumer] = forwarder } - return nil + return consumerForwaderMap, nil } func (c Client) loadFile() ([]byte, error) { @@ -73,7 +71,7 @@ func (c Client) loadFile() ([]byte, error) { } func (h helperImpl) createConsumer(item common.Item) consumer.Client { - log.Print("Creating consumer: ", item.Type) + log.Printf("Creating consumer: [%s, %s]", item.Type, item.Name) switch item.Type { case rabbitmq.Type: return rabbitmq.CreateConsumer(item) @@ -82,7 +80,7 @@ func (h helperImpl) createConsumer(item common.Item) consumer.Client { } func (h helperImpl) createForwarder(item common.Item) forwarder.Client { - log.Print("Creating forwarder: ", item.Type) + log.Printf("Creating forwarder: [%s, %s]", item.Type, item.Name) switch item.Type { case sns.Type: return sns.CreateForwarder(item) diff --git a/mapping/mapping_test.go b/mapping/mapping_test.go index 33afcc6..8ea16b7 100644 --- a/mapping/mapping_test.go +++ b/mapping/mapping_test.go @@ -18,12 +18,17 @@ const ( snsType = "sns" ) -func TestLoadAndStart(t *testing.T) { +func TestLoad(t *testing.T) { os.Setenv(common.MappingFile, "../tests/rabbit_to_sns.json") client := New(MockMappingHelper{}) - if err := client.LoadAndStart(); err != nil { + var consumerForwarderMap map[consumer.Client]forwarder.Client + var err error + if consumerForwarderMap, err = client.Load(); err != nil { t.Errorf("could not load mapping and start mocked rabbit->sns pair: %s", err.Error()) } + if len(consumerForwarderMap) != 1 { + t.Errorf("wrong consumerForwarderMap size, expected 1, got %d", len(consumerForwarderMap)) + } } func TestLoadFile(t *testing.T) { @@ -118,7 +123,7 @@ func (c MockRabbitConsumer) Name() string { return rabbitType } -func (c MockRabbitConsumer) Consume(forwarder.Client) error { +func (c MockRabbitConsumer) Start(client forwarder.Client, check chan bool, stop chan bool) error { return nil } diff --git a/rabbitmq/consumer.go b/rabbitmq/consumer.go index 880372f..364582c 100644 --- a/rabbitmq/consumer.go +++ b/rabbitmq/consumer.go @@ -35,19 +35,17 @@ func (c Consumer) Name() string { // TODO gracefull shotdown // Consume consumes messages from Rabbit queue -func (c Consumer) Consume(forwarder forwarder.Client) error { +func (c Consumer) Start(forwarder forwarder.Client, check chan bool, stop chan bool) error { log.Print("Starting consumer with params: ", c) conn, err := amqp.Dial(c.ConnectionURL) if err != nil { failOnError(err, "Failed to connect to RabbitMQ") } - // defer conn.Close() ch, err := conn.Channel() if err != nil { failOnError(err, "Failed to open a channel") } - // defer ch.Close() err = ch.ExchangeDeclare( c.ExchangeName, // name @@ -96,20 +94,30 @@ func (c Consumer) Consume(forwarder forwarder.Client) error { return failOnError(err, "Failed to register a consumer") } - go c.push(msgs, forwarder) + go c.push(forwarder, msgs, check, stop, conn, ch) return nil } -func (c Consumer) push(msgs <-chan amqp.Delivery, forwarder forwarder.Client) { +func (c Consumer) push(forwarder forwarder.Client, msgs <-chan amqp.Delivery, check chan bool, stop chan bool, conn *amqp.Connection, ch *amqp.Channel) { log.Printf("[%s] Started forwarding messages to %s", c.Name(), forwarder.Name()) - for d := range msgs { - log.Printf("[%s] Message to forward: %v", c.Name(), d.MessageId) - err := forwarder.Push(string(d.Body)) - if err != nil { - log.Printf("[%s] Could not forward message. Error: %s", forwarder.Name(), err.Error()) - } else { - d.Ack(true) + for { + select { + case d := <-msgs: + log.Printf("[%s] Message to forward: %v", c.Name(), d.MessageId) + err := forwarder.Push(string(d.Body)) + if err != nil { + log.Printf("[%s] Could not forward message. Error: %s", forwarder.Name(), err.Error()) + } else { + d.Ack(true) + } + case <-check: + log.Printf("[%s] Checking", forwarder.Name()) + case <-stop: + log.Printf("[%s] Closing", forwarder.Name()) + ch.Close() + conn.Close() + return } } } diff --git a/server.go b/server.go index 922a205..c2575ca 100644 --- a/server.go +++ b/server.go @@ -4,16 +4,21 @@ import ( "log" "net/http" - "github.com/AirHelp/rabbit-amazon-forwarder/healthcheck" "github.com/AirHelp/rabbit-amazon-forwarder/mapping" + "github.com/AirHelp/rabbit-amazon-forwarder/supervisor" ) func main() { - http.HandleFunc("/health", health.Check) - err := mapping.New().LoadAndStart() + consumerForwarderMap, err := mapping.New().Load() if err != nil { - log.Fatalf("Could not load and start consumer->forwader pairs. Error: " + err.Error()) + log.Fatalf("Could not load consumer->forwader pairs. Error: " + err.Error()) } + supervisor := supervisor.New(consumerForwarderMap) + if err := supervisor.Start(); err != nil { + log.Fatal("Could not start supervisor. Error: ", err.Error()) + } + http.HandleFunc("/restart", supervisor.Restart) + http.HandleFunc("/health", supervisor.Check) log.Print("Starting http server") log.Fatal(http.ListenAndServe(":8080", nil)) } diff --git a/sns/forwader.go b/sns/forwader.go index f561789..305d7d0 100644 --- a/sns/forwader.go +++ b/sns/forwader.go @@ -35,7 +35,6 @@ func (f Forwarder) Name() string { // Push pushes message to forwarding infrastructure func (f Forwarder) Push(message string) error { - log.Print("Topic: ", f.topic) params := &sns.PublishInput{ Message: aws.String(message), TargetArn: aws.String(f.topic), diff --git a/sqs/forwader.go b/sqs/forwader.go index 9f18e55..b01cea2 100644 --- a/sqs/forwader.go +++ b/sqs/forwader.go @@ -35,8 +35,6 @@ func (f Forwarder) Name() string { // Push pushes message to forwarding infrastructure func (f Forwarder) Push(message string) error { - log.Print("Queue: ", f.queue) - params := &sqs.SendMessageInput{ MessageBody: aws.String(message), // Required QueueUrl: aws.String(f.queue), // Required diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go new file mode 100644 index 0000000..fd5a200 --- /dev/null +++ b/supervisor/supervisor.go @@ -0,0 +1,91 @@ +package supervisor + +import ( + "fmt" + "log" + "net/http" + "time" + + "github.com/AirHelp/rabbit-amazon-forwarder/consumer" + "github.com/AirHelp/rabbit-amazon-forwarder/forwarder" +) + +type consumerChannel struct { + name string + check chan bool + stop chan bool +} + +// Client supervisor client +type Client struct { + mappings map[consumer.Client]forwarder.Client + consumers map[string]*consumerChannel +} + +// New client for supervisor +func New(consumerForwarderMap map[consumer.Client]forwarder.Client) Client { + return Client{mappings: consumerForwarderMap} +} + +// Start starts supervisor +func (c *Client) Start() error { + c.consumers = make(map[string]*consumerChannel) + for consumer, forwarder := range c.mappings { + channel := makeConsumerChannel(forwarder.Name()) + c.consumers[forwarder.Name()] = channel + if err := consumer.Start(forwarder, channel.check, channel.stop); err != nil { + return err + } + log.Printf("Started consumer:%s with forwader:%s", consumer.Name(), forwarder.Name()) + } + return nil +} + +// Check checks running consumers +func (c *Client) Check(w http.ResponseWriter, r *http.Request) { + stopped := 0 + for _, consumer := range c.consumers { + if len(consumer.check) > 0 { + stopped = stopped + 1 + continue + } + consumer.check <- true + time.Sleep(500 * time.Millisecond) + if len(consumer.check) > 0 { + stopped = stopped + 1 + } + } + if stopped > 0 { + w.WriteHeader(500) + message := fmt.Sprintf("Number of failed consumers: %d", stopped) + w.Write([]byte(message)) + return + } + w.WriteHeader(200) + w.Write([]byte("success")) +} + +// Restart restarts every consumer +func (c *Client) Restart(w http.ResponseWriter, r *http.Request) { + c.stop() + if err := c.Start(); err != nil { + w.WriteHeader(500) + w.Write([]byte(err.Error())) + return + } + w.WriteHeader(200) + w.Write([]byte("success")) +} + +func (c *Client) stop() { + for _, consumer := range c.consumers { + consumer.stop <- true + } + +} + +func makeConsumerChannel(name string) *consumerChannel { + check := make(chan bool) + stop := make(chan bool) + return &consumerChannel{name: name, check: check, stop: stop} +} diff --git a/supervisor/supervisor_test.go b/supervisor/supervisor_test.go new file mode 100644 index 0000000..ed4027f --- /dev/null +++ b/supervisor/supervisor_test.go @@ -0,0 +1,113 @@ +package supervisor + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/AirHelp/rabbit-amazon-forwarder/consumer" + "github.com/AirHelp/rabbit-amazon-forwarder/forwarder" +) + +func TestStart(t *testing.T) { + supervisor := New(prepareConsumers()) + if err := supervisor.Start(); err != nil { + t.Error("could not start supervised consumer->forwader pairs, error: ", err.Error()) + } + if len(supervisor.consumers) != 2 { + t.Errorf("wrong number of consumer-forwarder pairs, expected:%d, got:%d: ", 2, len(supervisor.consumers)) + } +} + +func TestRestart(t *testing.T) { + supervisor := New(prepareConsumers()) + req, err := http.NewRequest("GET", "/restart", nil) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + handler := http.HandlerFunc(supervisor.Restart) + + handler.ServeHTTP(rr, req) + + if rr.Code != 200 { + t.Errorf("wrong status code, expected:%d, got:%d", rr.Code, 200) + } + if rr.Body.String() != "success" { + t.Errorf("wrong response body, expected:%s, got:%v", "success", rr.Body.String()) + } +} + +func TestCheck(t *testing.T) { + supervisor := New(prepareConsumers()) + if err := supervisor.Start(); err != nil { + t.Error("could not start supervised consumer->forwader pairs, error: ", err.Error()) + } + req, err := http.NewRequest("GET", "/check", nil) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + handler := http.HandlerFunc(supervisor.Check) + + handler.ServeHTTP(rr, req) + + if rr.Code != 200 { + t.Errorf("wrong status code, expected:%d, got:%d", rr.Code, 200) + } + if rr.Body.String() != "success" { + t.Errorf("wrong response body, expected:%s, got:%v", "success", rr.Body.String()) + } +} + +func prepareConsumers() map[consumer.Client]forwarder.Client { + consumers := make(map[consumer.Client]forwarder.Client) + consumers[MockRabbitConsumer{"rabbit1"}] = MockSNSForwarder{"sns"} + consumers[MockRabbitConsumer{"rabbit2"}] = MockSQSForwarder{"sqs"} + return consumers +} + +type MockRabbitConsumer struct { + name string +} + +type MockSNSForwarder struct { + name string +} + +type MockSQSForwarder struct { + name string +} + +func (c MockRabbitConsumer) Name() string { + return c.name +} + +func (c MockRabbitConsumer) Start(client forwarder.Client, check chan bool, stop chan bool) error { + go func() { + for { + select { + case <-check: + fmt.Print("Checked") + } + } + }() + return nil +} + +func (f MockSNSForwarder) Name() string { + return f.name +} + +func (f MockSNSForwarder) Push(message string) error { + return nil +} + +func (f MockSQSForwarder) Name() string { + return f.name +} + +func (f MockSQSForwarder) Push(message string) error { + return nil +}