From c218db1547d4b47f1a1f90d531913e79ba5c321b Mon Sep 17 00:00:00 2001 From: maccac Date: Wed, 28 Nov 2018 06:43:23 +1100 Subject: [PATCH] Tls support (#28) * Adding support for amqp tls connections * Correcting spacing in the config.go file * Checking what travis ci gives as a docker version * Working around a docker-compose bug that prevents empty defaults * removed docker-compose version check This was being used for debugging a build failure and is no longer needed * Making a rabbit mq connector that should be more testable in theoryr * Refactors and initial tests around CreateConnector - Added ginkgo and gomega for testing - Test that a connector of the correct type is created - Moved the instantiation of the connector into mapping.go - Added a test to mapping.go to assert that the created connector is non nil - Renamed config on TlsRabbitConnector to config to reduce confusion around the config package * Adding some happy case tests * Removed some debug calls that were added by accident * Adding more test coverage * Adding test coverage around the tls connector * Switch to using HasPrefix and add additional test coverage around the basic connector --- Dockerfile | 1 + Gopkg.lock | 158 +++++++++++++++++- Gopkg.toml | 4 + README.md | 14 ++ config/config.go | 3 + connector/connector.go | 142 ++++++++++++++++ connector/connector_suite_test.go | 13 ++ connector/connector_test.go | 260 ++++++++++++++++++++++++++++++ docker-compose.yml | 4 + mapping/mapping.go | 4 +- mapping/mapping_test.go | 4 + rabbitmq/consumer.go | 27 ++-- 12 files changed, 617 insertions(+), 17 deletions(-) create mode 100644 connector/connector.go create mode 100644 connector/connector_suite_test.go create mode 100644 connector/connector_test.go diff --git a/Dockerfile b/Dockerfile index 469066f..d782b7f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,6 +15,7 @@ RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o rabbit-amazon-for FROM alpine:3.8 RUN mkdir -p /config +RUN mkdir -p /certs RUN apk --update upgrade && \ apk add curl ca-certificates && \ update-ca-certificates && \ diff --git a/Gopkg.lock b/Gopkg.lock index 58f73dc..c5aecf5 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -2,6 +2,7 @@ [[projects]] + digest = "1:73c2c04ab63401fddec29dc221b1af00051c40a685b3acf5afd23325d7e29543" name = "github.com/aws/aws-sdk-go" packages = [ "aws", @@ -38,52 +39,203 @@ "service/sns/snsiface", "service/sqs", "service/sqs/sqsiface", - "service/sts" + "service/sts", ] + pruneopts = "UT" revision = "2fa5290a6a8f6a664f2dab5337d5f64d0cfd8f68" version = "v1.14.7" [[projects]] + digest = "1:fb46255681497314debedde38b64be32a75bae50bad107586c22f1662bf2d352" name = "github.com/go-ini/ini" packages = ["."] + pruneopts = "UT" revision = "06f5f3d67269ccec1fe5fe4134ba6e982984f7f5" version = "v1.37.0" [[projects]] + digest = "1:da7a0665d373e11bbc9d5f2015ce01cc8aa610ebf38de553e056f8b65c87243a" + name = "github.com/hpcloud/tail" + packages = [ + ".", + "ratelimiter", + "util", + "watch", + "winfile", + ] + pruneopts = "UT" + revision = "a30252cb686a21eb2d0b98132633053ec2f7f1e5" + version = "v1.0.0" + +[[projects]] + digest = "1:e22af8c7518e1eab6f2eab2b7d7558927f816262586cd6ed9f349c97a6c285c4" name = "github.com/jmespath/go-jmespath" packages = ["."] + pruneopts = "UT" revision = "0b12d6b5" [[projects]] + digest = "1:b1fc68069736db92aefa48adf2dcb7e16d23d13425c8ae80ca600281f7d711b3" + name = "github.com/onsi/ginkgo" + packages = [ + ".", + "config", + "internal/codelocation", + "internal/containernode", + "internal/failer", + "internal/leafnodes", + "internal/remote", + "internal/spec", + "internal/spec_iterator", + "internal/specrunner", + "internal/suite", + "internal/testingtproxy", + "internal/writer", + "reporters", + "reporters/stenographer", + "reporters/stenographer/support/go-colorable", + "reporters/stenographer/support/go-isatty", + "types", + ] + pruneopts = "UT" + revision = "3774a09d95489ccaa16032e0770d08ea77ba6184" + version = "v1.6.0" + +[[projects]] + digest = "1:c4a34ca1d26b1d08ba9199b3c187f9c0b090042b7bdb4e5bf9ad76078a3593ac" + name = "github.com/onsi/gomega" + packages = [ + ".", + "format", + "internal/assertion", + "internal/asyncassertion", + "internal/oraclematcher", + "internal/testingtsupport", + "matchers", + "matchers/support/goraph/bipartitegraph", + "matchers/support/goraph/edge", + "matchers/support/goraph/node", + "matchers/support/goraph/util", + "types", + ] + pruneopts = "UT" + revision = "b6ea1ea48f981d0f615a154a45eabb9dd466556d" + version = "v1.4.1" + +[[projects]] + digest = "1:9e9193aa51197513b3abcb108970d831fbcf40ef96aa845c4f03276e1fa316d2" name = "github.com/sirupsen/logrus" packages = ["."] + pruneopts = "UT" revision = "c155da19408a8799da419ed3eeb0cb5db0ad5dbc" version = "v1.0.5" [[projects]] branch = "master" + digest = "1:08e00568d99ec12096ba60887632eb2b94ed8b3c23e2ed90eb263e12eacf8f3a" name = "github.com/streadway/amqp" packages = ["."] + pruneopts = "UT" revision = "e5adc2ada8b8efff032bf61173a233d143e9318e" [[projects]] branch = "master" + digest = "1:3f3a05ae0b95893d90b9b3b5afdb79a9b3d96e4e36e099d841ae602e4aca0da8" name = "golang.org/x/crypto" packages = ["ssh/terminal"] + pruneopts = "UT" revision = "a8fb68e7206f8c78be19b432c58eb52a6aa34462" [[projects]] branch = "master" + digest = "1:6bab46a020b433450836f84254204364dbce7da8fd944a8c2c40e294b4160826" + name = "golang.org/x/net" + packages = [ + "html", + "html/atom", + "html/charset", + ] + pruneopts = "UT" + revision = "a680a1efc54dd51c040b3b5ce4939ea3cf2ea0d1" + +[[projects]] + branch = "master" + digest = "1:b33745a6adfa9879da2f5bd8450cea94c78041b00cf1a9d3bab7bf2f1769d3dd" name = "golang.org/x/sys" packages = [ "unix", - "windows" + "windows", ] + pruneopts = "UT" revision = "8883426083c04a2627e6e59d84d5f6fb63d16c91" +[[projects]] + digest = "1:00b52d21f87065cee85096f07ba40957cc53719145c224da5f5d113669868213" + name = "golang.org/x/text" + packages = [ + "encoding", + "encoding/charmap", + "encoding/htmlindex", + "encoding/internal", + "encoding/internal/identifier", + "encoding/japanese", + "encoding/korean", + "encoding/simplifiedchinese", + "encoding/traditionalchinese", + "encoding/unicode", + "internal/gen", + "internal/tag", + "internal/utf8internal", + "language", + "runes", + "transform", + "unicode/cldr", + ] + pruneopts = "UT" + revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0" + version = "v0.3.0" + +[[projects]] + digest = "1:abeb38ade3f32a92943e5be54f55ed6d6e3b6602761d74b4aab4c9dd45c18abd" + name = "gopkg.in/fsnotify.v1" + packages = ["."] + pruneopts = "UT" + revision = "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9" + source = "https://github.com/fsnotify/fsnotify/tree/v1.4.7" + version = "v1.4.7" + +[[projects]] + branch = "v1" + digest = "1:0caa92e17bc0b65a98c63e5bc76a9e844cd5e56493f8fdbb28fad101a16254d9" + name = "gopkg.in/tomb.v1" + packages = ["."] + pruneopts = "UT" + revision = "dd632973f1e7218eb1089048e0798ec9ae7dceb8" + +[[projects]] + digest = "1:342378ac4dcb378a5448dd723f0784ae519383532f5e70ade24132c4c8693202" + name = "gopkg.in/yaml.v2" + packages = ["."] + pruneopts = "UT" + revision = "5420a8b6744d3b0345ab293f6fcba19c978f1183" + version = "v2.2.1" + [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "d99dfc0229e80b714bd1f94efc4f1dc9a1114e56a4dadd1ac69300f0f399af92" + input-imports = [ + "github.com/aws/aws-sdk-go/aws", + "github.com/aws/aws-sdk-go/aws/session", + "github.com/aws/aws-sdk-go/service/lambda", + "github.com/aws/aws-sdk-go/service/lambda/lambdaiface", + "github.com/aws/aws-sdk-go/service/sns", + "github.com/aws/aws-sdk-go/service/sns/snsiface", + "github.com/aws/aws-sdk-go/service/sqs", + "github.com/aws/aws-sdk-go/service/sqs/sqsiface", + "github.com/onsi/ginkgo", + "github.com/onsi/gomega", + "github.com/sirupsen/logrus", + "github.com/streadway/amqp", + ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index fa25c4a..5c81a3f 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -37,6 +37,10 @@ branch = "master" name = "github.com/streadway/amqp" +[[override]] + name = "gopkg.in/fsnotify.v1" + source = "https://github.com/fsnotify/fsnotify/tree/v1.4.7" + [prune] go-tests = true unused-packages = true diff --git a/README.md b/README.md index ca8ff14..809b347 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,20 @@ export AWS_ACCESS_KEY_ID=access_key export AWS_SECRET_ACCESS_KEY=secret_key ``` +#### Using TLS with rabbit + +Specify amqps for the rabbit connection ub the mapping file: +``` + "connection" : "amqps://guest:guest@localhost:5671/", +``` + +Additional environment variables for working with TLS and rabbit: +``` +export CA_CERT=/certs/ca_certificate.pem +export CERT_FILE=/certs/client_certificate.pem +export KEY_FILE=/certs/client_key.pem +``` + ### Amazon configuration When making subscription to SNS -> SQS/HTTP/HTTPS set `Raw message delivery` to ensure that json messages are not escaped. diff --git a/config/config.go b/config/config.go index 73d2db9..d980b85 100644 --- a/config/config.go +++ b/config/config.go @@ -3,6 +3,9 @@ package config const ( // MappingFile mapping file environment variable MappingFile = "MAPPING_FILE" + CaCertFile = "CA_CERT_FILE" + CertFile = "CERT_FILE" + KeyFile = "KEY_FILE" ) // RabbitEntry RabbitMQ mapping entry diff --git a/connector/connector.go b/connector/connector.go new file mode 100644 index 0000000..2c8b196 --- /dev/null +++ b/connector/connector.go @@ -0,0 +1,142 @@ +package connector + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "os" + "strings" + + "github.com/AirHelp/rabbit-amazon-forwarder/config" + log "github.com/sirupsen/logrus" + + "github.com/streadway/amqp" +) + +type FileReader interface { + ReadFile(filename string) ([]byte, error) +} + +type IOFileReader struct { +} + +func (i *IOFileReader) ReadFile(filename string) ([]byte, error) { + return ioutil.ReadFile(filename) +} + +type CertPoolMaker interface { + NewCertPoolWithAppendedCa(caCert []byte) *x509.CertPool +} + +type X509CertPoolMaker struct { +} + +func (x *X509CertPoolMaker) NewCertPoolWithAppendedCa(caCert []byte) *x509.CertPool { + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(caCert) + return certPool +} + +type KeyLoader interface { + LoadKeyPair(certFile, keyFile string) (tls.Certificate, error) +} + +type X509KeyPairLoader struct { +} + +func (x *X509KeyPairLoader) LoadKeyPair(certFile string, keyFile string) (tls.Certificate, error) { + return tls.LoadX509KeyPair(certFile, keyFile) +} + +type RabbitDialer interface { + Dial(connectionURL string) (*amqp.Connection, error) +} + +type BasicRabbitDialer struct { +} + +func (s *BasicRabbitDialer) Dial(connectionURL string) (*amqp.Connection, error) { + return amqp.Dial(connectionURL) +} + +type TlsRabbitDialer interface { + DialTLS(connectionURL string, tlsConfig *tls.Config) (*amqp.Connection, error) +} + +type X509TlsDialer struct { +} + +func (s *X509TlsDialer) DialTLS(connectionURL string, tlsConfig *tls.Config) (*amqp.Connection, error) { + return amqp.DialTLS(connectionURL, tlsConfig) +} + +type RabbitConnector interface { + CreateConnection(connectionURL string) (*amqp.Connection, error) +} + +type BasicRabbitConnector struct { + BasicRabbitDialer RabbitDialer +} + +func (c *BasicRabbitConnector) CreateConnection(connectionURL string) (*amqp.Connection, error) { + log.Info("Dialing in") + return c.BasicRabbitDialer.Dial(connectionURL) +} + +type TlsRabbitConnector struct { + TlsConfig *tls.Config + FileReader FileReader + CertPoolMaker CertPoolMaker + KeyLoader KeyLoader + TlsDialer TlsRabbitDialer +} + +func (c *TlsRabbitConnector) CreateConnection(connectionURL string) (*amqp.Connection, error) { + log.Info("Dialing in via TLS") + caCertFilePath := os.Getenv(config.CaCertFile) + + if ca, err := c.FileReader.ReadFile(caCertFilePath); err == nil { + c.TlsConfig.RootCAs = c.CertPoolMaker.NewCertPoolWithAppendedCa(ca) + } else { + log.WithFields(log.Fields{ + "error": err.Error(), + config.CaCertFile: caCertFilePath}).Info("Error loading CA Cert file") + return nil, err + } + + certFilePath := os.Getenv(config.CertFile) + keyFilePath := os.Getenv(config.KeyFile) + if cert, err := c.KeyLoader.LoadKeyPair(certFilePath, keyFilePath); err == nil { + c.TlsConfig.Certificates = append(c.TlsConfig.Certificates, cert) + } else { + log.WithFields(log.Fields{ + "error": err.Error(), + config.CertFile: certFilePath, + config.KeyFile: keyFilePath}).Info("Error loading client certificates") + } + return c.TlsDialer.DialTLS(connectionURL, c.TlsConfig) +} + +func CreateBasicRabbitConnector() *BasicRabbitConnector { + return &BasicRabbitConnector{ + BasicRabbitDialer: &BasicRabbitDialer{}, + } +} + +func CreateTlsRabbitConnector() *TlsRabbitConnector { + return &TlsRabbitConnector{ + TlsConfig: new(tls.Config), + FileReader: &IOFileReader{}, + CertPoolMaker: &X509CertPoolMaker{}, + KeyLoader: &X509KeyPairLoader{}, + TlsDialer: &X509TlsDialer{}, + } +} + +func CreateConnector(connectionURL string) RabbitConnector { + if strings.HasPrefix(connectionURL, "amqps") { + return CreateTlsRabbitConnector() + } else { + return CreateBasicRabbitConnector() + } +} diff --git a/connector/connector_suite_test.go b/connector/connector_suite_test.go new file mode 100644 index 0000000..2ca6b26 --- /dev/null +++ b/connector/connector_suite_test.go @@ -0,0 +1,13 @@ +package connector_test + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestConnector(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Connector Suite") +} diff --git a/connector/connector_test.go b/connector/connector_test.go new file mode 100644 index 0000000..c0cec3b --- /dev/null +++ b/connector/connector_test.go @@ -0,0 +1,260 @@ +package connector_test + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "os" + + "github.com/AirHelp/rabbit-amazon-forwarder/config" + "github.com/AirHelp/rabbit-amazon-forwarder/connector" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/streadway/amqp" +) + +var _ = Describe("Connector", func() { + + Describe("Creating connectors", func() { + Context("With a basic rabbit configuration", func() { + It("should be a BasicRabbitConnector", func() { + actualConnect := connector.CreateConnector("amqp") + Expect(actualConnect).Should(BeAssignableToTypeOf(&connector.BasicRabbitConnector{})) + }) + }) + + Context("With an amqps value somewhere else in the connection url", func() { + It("should be a BasicRabbitConnector", func() { + actualConnect := connector.CreateConnector("amqp://guest:guest@rabbbit-amqps:5672") + Expect(actualConnect).Should(BeAssignableToTypeOf(&connector.BasicRabbitConnector{})) + }) + }) + + Context("With a tls rabbit configuration", func() { + It("should be a TlsRabbitConnector", func() { + actualConnect := connector.CreateConnector("amqps") + Expect(actualConnect).Should(BeAssignableToTypeOf(&connector.TlsRabbitConnector{})) + }) + }) + }) + + Describe("Connecting a basic rabbit connector", func() { + + var ( + rabbitConnector connector.RabbitConnector + dialer *MockBasicRabbitDialer + ) + + BeforeEach(func() { + dialer = &MockBasicRabbitDialer{} + rabbitConnector = createBasicConnector(dialer) + }) + + Context("With no problems creating the connection", func() { + It("Should create a connection", func() { + expectedConnection := createDummyAmqpConnection() + dialer.ReturnedConnection = expectedConnection + + connection, err := rabbitConnector.CreateConnection("any amqp url") + + Expect(connection).Should(Equal(expectedConnection)) + Expect(err).Should(BeNil()) + }) + }) + + Context("With an error creating the connection", func() { + It("Should return an error", func() { + dialer.Error = errors.New("Expected") + dialer.ReturnedConnection = nil + + connection, err := rabbitConnector.CreateConnection("any amqp url") + + Expect(connection).Should(BeNil()) + Expect(err).Should(Equal(dialer.Error)) + }) + }) + }) + + Describe("Connecting a tls rabbit connector", func() { + + var ( + rabbitConnector connector.RabbitConnector + fileReader *MockFileReader + dialer *MockTlsRabbitDialer + tlsConfig *tls.Config + certPoolMaker *MockCertPoolMaker + keyLoader *MockkeyLoader + ) + + BeforeEach(func() { + os.Setenv(config.CaCertFile, "CaName") + os.Setenv(config.CertFile, "CertFile") + os.Setenv(config.KeyFile, "KeyFile") + + dialer = &MockTlsRabbitDialer{} + fileReader = &MockFileReader{ + Error: nil, + DummyFile: []byte("Dummy file"), + } + tlsConfig = new(tls.Config) + certPoolMaker = &MockCertPoolMaker{ + CertPoolToReturn: x509.NewCertPool(), + } + keyLoader = &MockkeyLoader{ + ReturnedCertificate: tls.Certificate{}, + } + rabbitConnector = createTlsConnector(dialer, fileReader, tlsConfig, certPoolMaker, keyLoader) + }) + + Context("With no problems creating the connection", func() { + It("Should create a connection", func() { + expectedConnection := createDummyAmqpConnection() + dialer.ReturnedConnection = expectedConnection + + connection, err := rabbitConnector.CreateConnection("any amqps url") + + // assert that file reader loaded the + Expect(fileReader.FileNameRead).Should(Equal("CaName")) + + // asert that ca is added to root ca + Expect(certPoolMaker.AppendedCaCert).Should(Equal([]byte("Dummy file"))) + Expect(tlsConfig.RootCAs).Should(Equal(certPoolMaker.CertPoolToReturn)) + + // assert that client certifcate is added + Expect(keyLoader.CertFileProvided).Should(Equal("CertFile")) + Expect(keyLoader.KeyFileProvided).Should(Equal("KeyFile")) + Expect(tlsConfig.Certificates).Should(ContainElement(keyLoader.ReturnedCertificate)) + + //assert that connection is created with correct params + Expect(dialer.ConnectionUrlProvided).Should(Equal("any amqps url")) + Expect(dialer.TlsConfigProvided).Should(Equal(tlsConfig)) + + //assert that the connection is returned + Expect(connection).Should(Equal(expectedConnection)) + Expect(err).Should(BeNil()) + }) + }) + + Context("With an error loading the ca certificate", func() { + It("Should return an error", func() { + fileReader.Error = errors.New("Expected") + fileReader.DummyFile = nil + + connection, err := rabbitConnector.CreateConnection("any amqp url") + + Expect(connection).Should(BeNil()) + Expect(err).Should(Equal(fileReader.Error)) + }) + }) + + Context("With an error loading client certificates", func() { + It("Should proceed with creating the connection", func() { + // We can leave the error handling to the TLS protocol + // and log an error indicating that no keys were loaded + var nilCertificate tls.Certificate + expectedConnection := createDummyAmqpConnection() + dialer.ReturnedConnection = expectedConnection + keyLoader.ReturnedCertificate = nilCertificate + keyLoader.Error = errors.New("Expected") + + connection, err := rabbitConnector.CreateConnection("any amqps url") + + // assert that client certifcate is added + Expect(len(tlsConfig.Certificates)).Should(Equal(0)) + + //assert that connection is created with correct params + Expect(dialer.ConnectionUrlProvided).Should(Equal("any amqps url")) + Expect(dialer.TlsConfigProvided).Should(Equal(tlsConfig)) + + //assert that the connection is returned + Expect(connection).Should(Equal(expectedConnection)) + Expect(err).Should(BeNil()) + }) + }) + }) +}) + +func createBasicConnector(mockDialer connector.RabbitDialer) *connector.BasicRabbitConnector { + return &connector.BasicRabbitConnector{ + BasicRabbitDialer: mockDialer, + } +} + +func createTlsConnector( + mockDialer connector.TlsRabbitDialer, + mockFileReader connector.FileReader, + tlsConfig *tls.Config, + certPoolMaker connector.CertPoolMaker, + keyLoader connector.KeyLoader) *connector.TlsRabbitConnector { + return &connector.TlsRabbitConnector{ + TlsConfig: tlsConfig, + FileReader: mockFileReader, + CertPoolMaker: certPoolMaker, + KeyLoader: keyLoader, + TlsDialer: mockDialer, + } +} + +type MockFileReader struct { + FileNameRead string + Error error + DummyFile []byte +} + +func (i *MockFileReader) ReadFile(filename string) ([]byte, error) { + i.FileNameRead = filename + return i.DummyFile, i.Error +} + +type MockkeyLoader struct { + CertFileProvided string + KeyFileProvided string + ReturnedCertificate tls.Certificate + Error error +} + +func (x *MockkeyLoader) LoadKeyPair(certFile string, keyFile string) (tls.Certificate, error) { + x.CertFileProvided = certFile + x.KeyFileProvided = keyFile + return x.ReturnedCertificate, x.Error +} + +type MockTlsRabbitDialer struct { + ConnectionUrlProvided string + TlsConfigProvided *tls.Config + ReturnedConnection *amqp.Connection + Error error +} + +func (s *MockTlsRabbitDialer) DialTLS(connectionURL string, tlsConfig *tls.Config) (*amqp.Connection, error) { + s.ConnectionUrlProvided = connectionURL + s.TlsConfigProvided = tlsConfig + return s.ReturnedConnection, s.Error +} + +type MockBasicRabbitDialer struct { + Called bool + ReturnedConnection *amqp.Connection + Error error +} + +func (s *MockBasicRabbitDialer) Dial(connectionURL string) (*amqp.Connection, error) { + s.Called = true + return s.ReturnedConnection, s.Error +} + +type MockCertPoolMaker struct { + Called bool + AppendedCaCert []byte + CertPoolToReturn *x509.CertPool +} + +func (x *MockCertPoolMaker) NewCertPoolWithAppendedCa(caCert []byte) *x509.CertPool { + x.AppendedCaCert = caCert + x.Called = true + return x.CertPoolToReturn +} + +func createDummyAmqpConnection() *amqp.Connection { + return &amqp.Connection{} +} diff --git a/docker-compose.yml b/docker-compose.yml index 7b7abc7..5e196e1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,11 +6,15 @@ services: - "8080:8080" volumes: - "${MAPPING_FILE}:/config/mapping.json" + - "${CERTS_DIR:-./certs}:/certs" environment: MAPPING_FILE: /config/mapping.json AWS_REGION: ${AWS_REGION} AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID} AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY} + CA_CERT_FILE: ${CA_CERT_FILE:- } + CERT_FILE: ${CERT_FILE:- } + KEY_FILE: ${KEY_FILE:- } tests: build: context: . diff --git a/mapping/mapping.go b/mapping/mapping.go index 2ce7e7d..0399019 100644 --- a/mapping/mapping.go +++ b/mapping/mapping.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "os" + "github.com/AirHelp/rabbit-amazon-forwarder/connector" log "github.com/sirupsen/logrus" "github.com/AirHelp/rabbit-amazon-forwarder/config" @@ -84,7 +85,8 @@ func (h helperImpl) createConsumer(entry config.RabbitEntry) consumer.Client { "consumerName": entry.Name}).Info("Creating consumer") switch entry.Type { case rabbitmq.Type: - return rabbitmq.CreateConsumer(entry) + rabbitConnector := connector.CreateConnector(entry.ConnectionURL) + return rabbitmq.CreateConsumer(entry, rabbitConnector) } return nil } diff --git a/mapping/mapping_test.go b/mapping/mapping_test.go index dc36b92..cfaa4d2 100644 --- a/mapping/mapping_test.go +++ b/mapping/mapping_test.go @@ -57,6 +57,10 @@ func TestCreateConsumer(t *testing.T) { if consumer.Name() != consumerName { t.Errorf("wrong consumer name, expected %s, found %s", consumerName, consumer.Name()) } + rabbitConsumer := consumer.(rabbitmq.Consumer) + if rabbitConsumer.RabbitConnector == nil { + t.Errorf("rabbit consumer should have been set") + } } func TestCreateForwarderSNS(t *testing.T) { diff --git a/rabbitmq/consumer.go b/rabbitmq/consumer.go index 2befba7..92d2e58 100644 --- a/rabbitmq/consumer.go +++ b/rabbitmq/consumer.go @@ -8,6 +8,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/AirHelp/rabbit-amazon-forwarder/config" + "github.com/AirHelp/rabbit-amazon-forwarder/connector" "github.com/AirHelp/rabbit-amazon-forwarder/consumer" "github.com/AirHelp/rabbit-amazon-forwarder/forwarder" "github.com/streadway/amqp" @@ -24,11 +25,12 @@ const ( // Consumer implementation or RabbitMQ consumer type Consumer struct { - name string - ConnectionURL string - ExchangeName string - QueueName string - RoutingKeys []string + name string + ConnectionURL string + ExchangeName string + QueueName string + RoutingKeys []string + RabbitConnector connector.RabbitConnector } // parameters for starting consumer @@ -42,13 +44,12 @@ type workerParams struct { } // CreateConsumer creates consumer from string map -func CreateConsumer(entry config.RabbitEntry) consumer.Client { - // merge RoutingKey with RoutingKeys - if entry.RoutingKey != "" { - entry.RoutingKeys = append(entry.RoutingKeys, entry.RoutingKey) - } - - return Consumer{entry.Name, entry.ConnectionURL, entry.ExchangeName, entry.QueueName, entry.RoutingKeys} +func CreateConsumer(entry config.RabbitEntry, rabbitConnector connector.RabbitConnector) consumer.Client { + // merge RoutingKey with RoutingKeys + if entry.RoutingKey != "" { + entry.RoutingKeys = append(entry.RoutingKeys, entry.RoutingKey) + } + return Consumer{entry.Name, entry.ConnectionURL, entry.ExchangeName, entry.QueueName, entry.RoutingKeys, rabbitConnector} } // Name consumer name @@ -101,7 +102,7 @@ func (c Consumer) initRabbitMQ() (<-chan amqp.Delivery, *amqp.Connection, *amqp. } func (c Consumer) connect() (<-chan amqp.Delivery, *amqp.Connection, *amqp.Channel, error) { - conn, err := amqp.Dial(c.ConnectionURL) + conn, err := c.RabbitConnector.CreateConnection(c.ConnectionURL) if err != nil { return failOnError(err, "Failed to connect to RabbitMQ") }