Skip to content

Commit

Permalink
Use more specific rule condition for adding and removal
Browse files Browse the repository at this point in the history
  • Loading branch information
edw-defang committed Apr 2, 2024
1 parent 96bca78 commit e481361
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 20 deletions.
117 changes: 104 additions & 13 deletions aws/alb/updatealb.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ import (
"context"
"errors"
"fmt"
"log"
"sort"
"strconv"

"defang.io/cloudacme/aws"
elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
Expand All @@ -14,7 +15,12 @@ import (

var ErrRuleNotFound = errors.New("rule not found")

func DeleteListenerPathRule(ctx context.Context, listenerArn, path string) error {
type RuleCondition struct {
PathPattern []string
HostHeader []string
}

func DeleteListenerPathRule(ctx context.Context, listenerArn string, target RuleCondition) error {
svc := elbv2.NewFromConfig(aws.LoadConfig())
searchInput := &elbv2.DescribeRulesInput{
ListenerArn: &listenerArn,
Expand All @@ -25,14 +31,24 @@ func DeleteListenerPathRule(ctx context.Context, listenerArn, path string) error
}

ruleArn := ""
rules:
for _, rule := range rulesOutput.Rules {
log.Printf("Condition: %+v", rule.Conditions)
if len(rule.Conditions) > 0 && rule.Conditions[0].PathPatternConfig != nil && rule.Conditions[0].PathPatternConfig.Values[0] == path {
log.Printf("Rule values %+v", rule.Conditions[0].PathPatternConfig.Values[0])
for _, cond := range rule.Conditions {
if cond.PathPatternConfig != nil && target.PathPattern != nil && sameStringSlicesUnordered(cond.PathPatternConfig.Values, target.PathPattern) {
continue rules
}
if cond.HostHeaderConfig != nil && target.HostHeader != nil && sameStringSlicesUnordered(cond.HostHeaderConfig.Values, target.HostHeader) {
continue rules
}
// Only path and host header conditions are supported for now
if cond.SourceIpConfig != nil || cond.QueryStringConfig != nil || cond.HttpHeaderConfig != nil || cond.HttpRequestMethodConfig != nil {
continue rules
}
ruleArn = *rule.RuleArn
break
break rules
}
}

if ruleArn == "" {
return ErrRuleNotFound
}
Expand All @@ -47,8 +63,14 @@ func DeleteListenerPathRule(ctx context.Context, listenerArn, path string) error
return nil
}

func AddListenerStaticRule(ctx context.Context, listenerArn, path, value string, priority int32) error {
func AddListenerStaticRule(ctx context.Context, listenerArn string, ruleCond RuleCondition, value string) error {
svc := elbv2.NewFromConfig(aws.LoadConfig())

priority, err := GetNextAvailablePriority(ctx, listenerArn)
if err != nil {
return err
}

input := &elbv2.CreateRuleInput{
Actions: []types.Action{
{
Expand All @@ -62,23 +84,73 @@ func AddListenerStaticRule(ctx context.Context, listenerArn, path, value string,
},
Conditions: []types.RuleCondition{
{
Field: ptr.String("path-pattern"),
PathPatternConfig: &types.PathPatternConditionConfig{
Values: []string{path},
},
Field: ptr.String("path-pattern"),
PathPatternConfig: &types.PathPatternConditionConfig{Values: ruleCond.PathPattern},
HostHeaderConfig: &types.HostHeaderConditionConfig{Values: ruleCond.HostHeader},
},
},
ListenerArn: &listenerArn,
Priority: ptr.Int32(priority),
}

_, err := svc.CreateRule(ctx, input)
if err != nil {
if _, err := svc.CreateRule(ctx, input); err != nil {
return err
}
return nil
}

func GetNextAvailablePriority(ctx context.Context, listenerArn string) (int32, error) {
rules, err := GetAllRules(ctx, listenerArn)
if err != nil {
return 0, err
}

ps := make([]int, 0, len(rules))
for _, rule := range rules {
if rule.Priority == nil {
continue
}
p, err := strconv.Atoi(*rule.Priority)
if err != nil {
continue
}
ps = append(ps, p)
}
ps = sort.IntSlice(ps)
priority := 1
for _, p := range ps {
if priority == p {
priority++
} else {
break
}
}
return int32(priority), nil
}

func GetAllRules(ctx context.Context, listenerArn string) ([]types.Rule, error) {
svc := elbv2.NewFromConfig(aws.LoadConfig())

var rules []types.Rule
for {
searchInput := &elbv2.DescribeRulesInput{
ListenerArn: &listenerArn,
PageSize: ptr.Int32(400),
}

searchOuputput, err := svc.DescribeRules(ctx, searchInput)
if err != nil {
return nil, err
}
rules = append(rules, searchOuputput.Rules...)

if searchOuputput.NextMarker == nil {
return rules, nil
}
searchInput.Marker = searchOuputput.NextMarker
}
}

func GetListener(ctx context.Context, albArn string, protocol types.ProtocolEnum, port int32) (*types.Listener, error) {
svc := elbv2.NewFromConfig(aws.LoadConfig())
input := &elbv2.DescribeListenersInput{
Expand Down Expand Up @@ -144,3 +216,22 @@ func GetTargetGroupAlb(ctx context.Context, targetGroupArn string) (string, erro

return tg.LoadBalancerArns[0], nil // Only 1 LB per tg possible according to aws docs
}

func sameStringSlicesUnordered(a, b []string) bool {
if len(a) != len(b) {
return false
}
diff := make(map[string]int)
for _, s := range a {
diff[s]++
}
for _, s := range b {
diff[s]--
}
for _, v := range diff {
if v != 0 {
return false
}
}
return true
}
51 changes: 51 additions & 0 deletions cmd/inspect/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package main

import (
"context"
"fmt"

"defang.io/cloudacme/aws"
elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
)

var listenerArn = "arn:aws:elasticloadbalancing:us-west-2:381492210770:listener/app/Defang-dayifu2-beta-alb/6f1d3e4bf5cbac4b/84ccc1071870455c"
var path = "/"

func main() {
ctx := context.Background()

svc := elbv2.NewFromConfig(aws.LoadConfig())
searchInput := &elbv2.DescribeRulesInput{
ListenerArn: &listenerArn,
}
rulesOutput, err := svc.DescribeRules(ctx, searchInput)
if err != nil {
panic(err)
}

for _, rule := range rulesOutput.Rules {
fmt.Printf("RuleArn: %v\n", *rule.RuleArn)
for _, condition := range rule.Conditions {
fmt.Printf("Condition Type: %v\n", *condition.Field)
if condition.PathPatternConfig != nil {
fmt.Printf("\tPathPatternConfig: %v\n", condition.PathPatternConfig.Values)
}
if condition.HostHeaderConfig != nil {
fmt.Printf("\tHostHeaderConfig: %v\n", condition.HostHeaderConfig.Values)
}
if condition.HttpHeaderConfig != nil {
fmt.Printf("\tHttpHeaderConfig: %v\n", *condition.HttpHeaderConfig.HttpHeaderName)
}
if condition.HttpRequestMethodConfig != nil {
fmt.Printf("\tHttpRequestMethodConfig: %v\n", condition.HttpRequestMethodConfig.Values)
}
if condition.QueryStringConfig != nil {
fmt.Printf("\tQueryStringConfig: %v\n", condition.QueryStringConfig.Values)
}
if condition.SourceIpConfig != nil {
fmt.Printf("\tSourceIpConfig: %v\n", condition.SourceIpConfig.Values)
}
fmt.Printf("Values: %v\n", condition.Values)
}
}
}
14 changes: 10 additions & 4 deletions cmd/lambda/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,17 @@ func HandleALBEvent(ctx context.Context, evt events.ALBTargetGroupRequest) (*eve
return nil, fmt.Errorf("failed to get ALB ARN from target group %v: %w", targetGroupArn, err)
}

if err := updateAcmeCertificate(ctx, albArn, evt.Headers["host"]); err != nil {
host := evt.Headers["host"]
if err := updateAcmeCertificate(ctx, albArn, host); err != nil {
return nil, fmt.Errorf("failed to update certificate: %w", err)
}

if err := removeHttpRule(ctx, albArn, "/"); err != nil {
cond := alb.RuleCondition{
HostHeader: []string{host},
PathPattern: []string{"/"},
}

if err := removeHttpRule(ctx, albArn, cond); err != nil {
return nil, fmt.Errorf("failed to remove http rule: %w", err)
}

Expand All @@ -74,12 +80,12 @@ func HandleALBEvent(ctx context.Context, evt events.ALBTargetGroupRequest) (*eve
}, nil
}

func removeHttpRule(ctx context.Context, albArn, path string) error {
func removeHttpRule(ctx context.Context, albArn string, ruleCond alb.RuleCondition) error {
listener, err := alb.GetListener(ctx, albArn, awsalb.ProtocolEnumHttp, 80)
if err != nil {
return fmt.Errorf("cannot get http listener: %w", err)
}
if err := alb.DeleteListenerPathRule(ctx, *listener.ListenerArn, path); err != nil {
if err := alb.DeleteListenerPathRule(ctx, *listener.ListenerArn, ruleCond); err != nil {
return fmt.Errorf("failed to delete listener static rule: %w", err)
}
return nil
Expand Down
17 changes: 14 additions & 3 deletions solver/albhttp01solver.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"go.uber.org/zap"
)

const DefaultWaitTimeout = 1 * time.Minute
const DefaultWaitTimeout = 5 * time.Minute

type AlbHttp01Solver struct {
AlbArn string
Expand All @@ -32,7 +32,12 @@ func (s AlbHttp01Solver) Present(ctx context.Context, chal acme.Challenge) error
return fmt.Errorf("cannot get http listener: %w", err)
}

if err := alb.AddListenerStaticRule(ctx, *listener.ListenerArn, chal.HTTP01ResourcePath(), chal.KeyAuthorization, 1); err != nil {
ruleCond := alb.RuleCondition{
HostHeader: s.Domains,
PathPattern: []string{chal.HTTP01ResourcePath()},
}

if err := alb.AddListenerStaticRule(ctx, *listener.ListenerArn, ruleCond, chal.KeyAuthorization); err != nil {
return fmt.Errorf("failed to add listener static rule: %v", err)
}
return nil
Expand All @@ -46,7 +51,13 @@ func (s AlbHttp01Solver) CleanUp(ctx context.Context, chal acme.Challenge) error
if err != nil {
return fmt.Errorf("cannot get http listener: %w", err)
}
err = alb.DeleteListenerPathRule(ctx, *listener.ListenerArn, chal.HTTP01ResourcePath())

ruleCond := alb.RuleCondition{
HostHeader: s.Domains,
PathPattern: []string{chal.HTTP01ResourcePath()},
}

err = alb.DeleteListenerPathRule(ctx, *listener.ListenerArn, ruleCond)
if errors.Is(err, alb.ErrRuleNotFound) {
if s.Logger != nil {
s.Logger.Info("Challenge rule not found, skipping cleanup alb rule", zap.String("path", chal.HTTP01ResourcePath()))
Expand Down

0 comments on commit e481361

Please sign in to comment.