Skip to content

Commit

Permalink
Merge pull request #42 from aryszka/master
Browse files Browse the repository at this point in the history
support multiple routing tables in a zone
  • Loading branch information
szuecs authored Jun 23, 2021
2 parents 2ffcf9f + 9cd4950 commit 0a658d3
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 61 deletions.
205 changes: 151 additions & 54 deletions provider/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net"
"sort"
"strings"
"time"

Expand Down Expand Up @@ -62,6 +63,17 @@ type AWSProvider struct {
logger *log.Entry
}

type stackSpec struct {
name string
vpcID string
internetGatewayID string
tableID map[string]string
timeoutInMinutes uint
template string
stackTerminationProtection bool
tags []*cloudformation.Tag
}

func NewAWSProvider(clusterID, controllerID string, dry bool, vpcID string, natCidrBlocks, availabilityZones []string, stackTerminationProtection bool, additionalStackTags map[string]string) *AWSProvider {
// TODO: find vpcID at startup
p := defaultConfigProvider()
Expand Down Expand Up @@ -189,11 +201,19 @@ func getCIDRsFromTemplate(template string) map[string]struct{} {
return cidrs
}

func findTagByKey(tags []*ec2.Tag, key string) string {
for _, t := range tags {
if aws.StringValue(t.Key) == key {
return aws.StringValue(t.Value)
}
}

return ""
}

func (p *AWSProvider) generateStackSpec(configs map[provider.Resource]map[string]*net.IPNet) (*stackSpec, error) {
spec := &stackSpec{
name: normalizeStackName(p.clusterID),
template: p.generateTemplate(configs),
tableID: make(map[string]string),
timeoutInMinutes: 10,
stackTerminationProtection: p.stackTerminationProtection,
}
Expand Down Expand Up @@ -232,18 +252,97 @@ func (p *AWSProvider) generateStackSpec(configs map[provider.Resource]map[string
return nil, err
}

// adding route tables to spec
for _, table := range rt {
for _, tag := range table.Tags {
if tagDefaultAZKeyRouteTableID == aws.StringValue(tag.Key) {
// eu-central-1a -> rtb-b738aadc
spec.tableID[aws.StringValue(tag.Value)] = aws.StringValue(table.RouteTableId)
}
// [supporting multiple routing tables]
// as a migration step, in order to preserve the current indexes CloudFormation template names of the
// route-to-nat resources, we need to order them first, they are only identifiable by their standard
// name: dmz-eu-central-1a.
sort.SliceStable(rt, func(i, j int) bool {
rti, rtj := rt[i], rt[j]

zonei, ok := routeTableZone(rti)
if !ok {
return false
}

zonej, ok := routeTableZone(rtj)
if !ok {
return true
}

namei := findTagByKey(rti.Tags, "Name")
namej := findTagByKey(rtj.Tags, "Name")
standardi := namei == fmt.Sprintf("%s-%s", tagDefaultTypeValueRouteTableID, zonei)
standardj := namej == fmt.Sprintf("%s-%s", tagDefaultTypeValueRouteTableID, zonej)

if !standardi {
return false
}

if standardi && !standardj {
return true
}

zoneIndexi, ok := zoneIndex(p.availabilityZones, zonei)
if !ok {
return false
}

zoneIndexj, ok := zoneIndex(p.availabilityZones, zonej)
if !ok {
return true
}

return zoneIndexi < zoneIndexj
})

var paramOrder []string
tableZoneIndexes := make(map[string]int)
tableID := make(map[string]string)
for i, table := range rt {
zone, ok := routeTableZone(table)
if !ok {
continue
}

zindex, ok := zoneIndex(p.availabilityZones, zone)
if !ok {
return nil, fmt.Errorf(
"unrecognized availability zone in routing table tags: %s",
zone,
)
}

paramName := fmt.Sprintf("AZ%dRouteTableIDParameter", i+1)
paramOrder = append(paramOrder, paramName)
tableZoneIndexes[paramName] = zindex
tableID[paramName] = aws.StringValue(table.RouteTableId)
}

spec.template = p.generateTemplate(configs, paramOrder, tableZoneIndexes)
spec.tableID = tableID
return spec, nil
}

func routeTableZone(rt *ec2.RouteTable) (string, bool) {
for _, tag := range rt.Tags {
if tagDefaultAZKeyRouteTableID == aws.StringValue(tag.Key) {
return aws.StringValue(tag.Value), true
}
}

return "", false
}

func zoneIndex(zones []string, zone string) (int, bool) {
for i, z := range zones {
if z == zone {
return i, true
}
}

return 0, false
}

func (p *AWSProvider) findVPC() (string, error) {
// provided by the user
if p.vpcID != "" {
Expand All @@ -269,18 +368,11 @@ func (p *AWSProvider) findVPC() (string, error) {
return "", fmt.Errorf("VPC not found")
}

type stackSpec struct {
name string
vpcID string
internetGatewayID string
tableID map[string]string
timeoutInMinutes uint
template string
stackTerminationProtection bool
tags []*cloudformation.Tag
}

func (p *AWSProvider) generateTemplate(configs map[provider.Resource]map[string]*net.IPNet) string {
func (p *AWSProvider) generateTemplate(
configs map[provider.Resource]map[string]*net.IPNet,
routeTableParamOrder []string,
routeTableZoneIndexes map[string]int,
) string {
template := cft.NewTemplate()
template.Description = "Static Egress Stack"
template.Outputs = map[string]*cft.Output{}
Expand All @@ -294,11 +386,6 @@ func (p *AWSProvider) generateTemplate(configs map[provider.Resource]map[string]
}

for i := 1; i <= len(p.availabilityZones); i++ {
template.Parameters[fmt.Sprintf("AZ%dRouteTableIDParameter", i)] = &cft.Parameter{
Description: fmt.Sprintf(
"Route Table ID Availability Zone %d", i),
Type: "String",
}
template.AddResource(fmt.Sprintf("NATGateway%d", i), &cft.EC2NatGateway{
SubnetId: cft.Ref(
fmt.Sprintf("NATSubnet%d", i)).String(),
Expand Down Expand Up @@ -351,21 +438,26 @@ func (p *AWSProvider) generateTemplate(configs map[provider.Resource]map[string]
}

nets := provider.GenerateRoutes(configs)

for cidrEntry := range nets {
cleanCidrEntry := strings.Replace(cidrEntry, "/", "y", -1)
cleanCidrEntry = strings.Replace(cleanCidrEntry, ".", "x", -1)
for i := 1; i <= len(p.availabilityZones); i++ {
p.logger.Debugf("RouteToNAT%dz%s", i, cleanCidrEntry)
template.AddResource(fmt.Sprintf("RouteToNAT%dz%s", i, cleanCidrEntry), &cft.EC2Route{
RouteTableId: cft.Ref(
fmt.Sprintf("AZ%dRouteTableIDParameter", i)).String(),
for i, routeTableParam := range routeTableParamOrder {
template.Parameters[routeTableParam] = &cft.Parameter{
Description: fmt.Sprintf("Route Table ID No %d", i+1),
Type: "String",
}

template.AddResource(fmt.Sprintf("RouteToNAT%dz%s", i+1, cleanCidrEntry), &cft.EC2Route{
RouteTableId: cft.Ref(routeTableParam).String(),
DestinationCidrBlock: cft.String(cidrEntry),
NatGatewayId: cft.Ref(
fmt.Sprintf("NATGateway%d", i)).String(),
NatGatewayId: cft.Ref(fmt.Sprintf(
"NATGateway%d",
routeTableZoneIndexes[routeTableParam]+1,
)).String(),
})
}
}

stack, _ := json.Marshal(template)
return string(stack)
}
Expand Down Expand Up @@ -424,19 +516,17 @@ func (p *AWSProvider) deleteCFStack(stackName string) error {
func (p *AWSProvider) updateCFStack(spec *stackSpec) error {
params := &cloudformation.UpdateStackInput{
StackName: aws.String(spec.name),
Parameters: []*cloudformation.Parameter{
cfParam(parameterVPCIDParameter, spec.vpcID),
cfParam(parameterInternetGatewayIDParameter, spec.internetGatewayID),
},
Parameters: append(
[]*cloudformation.Parameter{
cfParam(parameterVPCIDParameter, spec.vpcID),
cfParam(parameterInternetGatewayIDParameter, spec.internetGatewayID),
},
routeTableParams(spec)...,
),
TemplateBody: aws.String(spec.template),
Tags: spec.tags,
}
for i, az := range p.availabilityZones {
params.Parameters = append(params.Parameters,
cfParam(
fmt.Sprintf("AZ%dRouteTableIDParameter", i+1),
spec.tableID[az]))
}

if !p.dry {
// ensure the stack termination protection is set
if spec.stackTerminationProtection {
Expand Down Expand Up @@ -475,21 +565,19 @@ func (p *AWSProvider) createCFStack(spec *stackSpec) error {
params := &cloudformation.CreateStackInput{
StackName: aws.String(spec.name),
OnFailure: aws.String(cloudformation.OnFailureDelete),
Parameters: []*cloudformation.Parameter{
cfParam(parameterVPCIDParameter, spec.vpcID),
cfParam(parameterInternetGatewayIDParameter, spec.internetGatewayID),
},
Parameters: append(
[]*cloudformation.Parameter{
cfParam(parameterVPCIDParameter, spec.vpcID),
cfParam(parameterInternetGatewayIDParameter, spec.internetGatewayID),
},
routeTableParams(spec)...,
),
TemplateBody: aws.String(spec.template),
TimeoutInMinutes: aws.Int64(int64(spec.timeoutInMinutes)),
EnableTerminationProtection: aws.Bool(spec.stackTerminationProtection),
Tags: spec.tags,
}
for i, az := range p.availabilityZones {
params.Parameters = append(params.Parameters,
cfParam(
fmt.Sprintf("AZ%dRouteTableIDParameter", i+1),
spec.tableID[az]))
}

if !p.dry {
_, err := p.cloudformation.CreateStack(params)
if err != nil {
Expand All @@ -512,6 +600,15 @@ func (p *AWSProvider) createCFStack(spec *stackSpec) error {

}

func routeTableParams(s *stackSpec) []*cloudformation.Parameter {
var params []*cloudformation.Parameter
for paramName, routeTableID := range s.tableID {
params = append(params, cfParam(paramName, routeTableID))
}

return params
}

func (p *AWSProvider) getStackByName(stackName string) (*cloudformation.Stack, error) {
params := &cloudformation.DescribeStacksInput{
StackName: aws.String(stackName),
Expand Down
32 changes: 25 additions & 7 deletions provider/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,12 @@ func TestGenerateStackSpec(t *testing.T) {
if stackSpec.internetGatewayID != expectedInternetGatewayId {
t.Errorf("Expect: %s,\n but got %s", expectedInternetGatewayId, stackSpec.internetGatewayID)
}
if stackSpec.tableID["eu-central-1a"] != expectedRouteTableId {
t.Errorf("Expect: %s,\n but got %s", expectedRouteTableId, stackSpec.tableID["eu-central-1a"])
if stackSpec.tableID["AZ1RouteTableIDParameter"] != expectedRouteTableId {
t.Errorf(
"Expect: %s,\n but got %s",
expectedRouteTableId,
stackSpec.tableID["AZ1RouteTableIDParameter"],
)
}
// sort tags to ensure stable comparison
sort.Slice(stackSpec.tags, func(i, j int) bool {
Expand All @@ -139,10 +143,14 @@ func TestGenerateTemplate(t *testing.T) {
},
}
p := NewAWSProvider("cluster-x", "controller-x", true, "", natCidrBlocks, availabilityZones, false, nil)
expect := `{"AWSTemplateFormatVersion":"2010-09-09","Description":"Static Egress Stack","Parameters":{"AZ1RouteTableIDParameter":{"Type":"String","Description":"Route Table ID Availability Zone 1"},"InternetGatewayIDParameter":{"Type":"String","Description":"Internet Gateway ID"},"VPCIDParameter":{"Type":"AWS::EC2::VPC::Id","Description":"VPC ID"}},"Resources":{"EIP1":{"Type":"AWS::EC2::EIP","Properties":{"Domain":"vpc"}},"NATGateway1":{"Type":"AWS::EC2::NatGateway","Properties":{"AllocationId":{"Fn::GetAtt":["EIP1","AllocationId"]},"SubnetId":{"Ref":"NATSubnet1"}}},"NATSubnet1":{"Type":"AWS::EC2::Subnet","Properties":{"AvailabilityZone":"eu-central-1a","CidrBlock":"172.31.64.0/28","Tags":[{"Key":"Name","Value":"nat-eu-central-1a"}],"VpcId":{"Ref":"VPCIDParameter"}}},"NATSubnetRoute1":{"Type":"AWS::EC2::Route","Properties":{"DestinationCidrBlock":"0.0.0.0/0","GatewayId":{"Ref":"InternetGatewayIDParameter"},"RouteTableId":{"Ref":"NATSubnetRouteTable1"}}},"NATSubnetRouteTable1":{"Type":"AWS::EC2::RouteTable","Properties":{"VpcId":{"Ref":"VPCIDParameter"},"Tags":[{"Key":"Name","Value":"nat-eu-central-1a"}]}},"NATSubnetRouteTableAssociation1":{"Type":"AWS::EC2::SubnetRouteTableAssociation","Properties":{"RouteTableId":{"Ref":"NATSubnetRouteTable1"},"SubnetId":{"Ref":"NATSubnet1"}}},"RouteToNAT1z213x95x138x236y32":{"Type":"AWS::EC2::Route","Properties":{"DestinationCidrBlock":"213.95.138.236/32","NatGatewayId":{"Ref":"NATGateway1"},"RouteTableId":{"Ref":"AZ1RouteTableIDParameter"}}}},"Outputs":{"EIP1":{"Description":"external IP of the NATGateway1","Value":{"Ref":"EIP1"}}}}`
template := p.generateTemplate(destinationCidrBlocks)
expect := `{"AWSTemplateFormatVersion":"2010-09-09","Description":"Static Egress Stack","Parameters":{"AZ1RouteTableIDParameter":{"Type":"String","Description":"Route Table ID No 1"},"InternetGatewayIDParameter":{"Type":"String","Description":"Internet Gateway ID"},"VPCIDParameter":{"Type":"AWS::EC2::VPC::Id","Description":"VPC ID"}},"Resources":{"EIP1":{"Type":"AWS::EC2::EIP","Properties":{"Domain":"vpc"}},"NATGateway1":{"Type":"AWS::EC2::NatGateway","Properties":{"AllocationId":{"Fn::GetAtt":["EIP1","AllocationId"]},"SubnetId":{"Ref":"NATSubnet1"}}},"NATSubnet1":{"Type":"AWS::EC2::Subnet","Properties":{"AvailabilityZone":"eu-central-1a","CidrBlock":"172.31.64.0/28","Tags":[{"Key":"Name","Value":"nat-eu-central-1a"}],"VpcId":{"Ref":"VPCIDParameter"}}},"NATSubnetRoute1":{"Type":"AWS::EC2::Route","Properties":{"DestinationCidrBlock":"0.0.0.0/0","GatewayId":{"Ref":"InternetGatewayIDParameter"},"RouteTableId":{"Ref":"NATSubnetRouteTable1"}}},"NATSubnetRouteTable1":{"Type":"AWS::EC2::RouteTable","Properties":{"VpcId":{"Ref":"VPCIDParameter"},"Tags":[{"Key":"Name","Value":"nat-eu-central-1a"}]}},"NATSubnetRouteTableAssociation1":{"Type":"AWS::EC2::SubnetRouteTableAssociation","Properties":{"RouteTableId":{"Ref":"NATSubnetRouteTable1"},"SubnetId":{"Ref":"NATSubnet1"}}},"RouteToNAT1z213x95x138x236y32":{"Type":"AWS::EC2::Route","Properties":{"DestinationCidrBlock":"213.95.138.236/32","NatGatewayId":{"Ref":"NATGateway1"},"RouteTableId":{"Ref":"AZ1RouteTableIDParameter"}}}},"Outputs":{"EIP1":{"Description":"external IP of the NATGateway1","Value":{"Ref":"EIP1"}}}}`
template := p.generateTemplate(
destinationCidrBlocks,
[]string{"AZ1RouteTableIDParameter"},
map[string]int{"AZ1RouteTableIDParameter": 0},
)
if template != expect {
t.Errorf("Expect:\n %s,\n but got %s", expect, template)
t.Errorf("Expect:\n %s,\n but got:\n %s", expect, template)
}

}
Expand Down Expand Up @@ -370,7 +378,13 @@ func TestEnsure(tt *testing.T) {
describeRouteTables: &ec2.DescribeRouteTablesOutput{
RouteTables: []*ec2.RouteTable{
{
RouteTableId: aws.String(""),
RouteTableId: aws.String("foo"),
Tags: []*ec2.Tag{
{
Key: aws.String("AvailabilityZone"),
Value: aws.String("eu-central-1a"),
},
},
},
},
},
Expand Down Expand Up @@ -432,7 +446,11 @@ func TestEnsure(tt *testing.T) {
describeRouteTables: &ec2.DescribeRouteTablesOutput{
RouteTables: []*ec2.RouteTable{
{
RouteTableId: aws.String(""),
RouteTableId: aws.String("foo"),
Tags: []*ec2.Tag{{
Key: aws.String("AvailabilityZone"),
Value: aws.String("eu-central-1a"),
}},
},
},
},
Expand Down

0 comments on commit 0a658d3

Please sign in to comment.