diff --git a/aws/alb/updatealb.go b/aws/alb/updatealb.go index db02409..2552f9d 100644 --- a/aws/alb/updatealb.go +++ b/aws/alb/updatealb.go @@ -4,7 +4,8 @@ import ( "context" "errors" "fmt" - "log" + "sort" + "strconv" "defang.io/cloudacme/aws" elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" @@ -14,7 +15,12 @@ import ( var ErrRuleNotFound = errors.New("rule not found") -func DeleteListenerPathRule(ctx context.Context, listenerArn, path string) error { +type RuleCondition struct { + PathPattern []string + HostHeader []string +} + +func DeleteListenerPathRule(ctx context.Context, listenerArn string, target RuleCondition) error { svc := elbv2.NewFromConfig(aws.LoadConfig()) searchInput := &elbv2.DescribeRulesInput{ ListenerArn: &listenerArn, @@ -25,14 +31,24 @@ func DeleteListenerPathRule(ctx context.Context, listenerArn, path string) error } ruleArn := "" +rules: for _, rule := range rulesOutput.Rules { - log.Printf("Condition: %+v", rule.Conditions) - if len(rule.Conditions) > 0 && rule.Conditions[0].PathPatternConfig != nil && rule.Conditions[0].PathPatternConfig.Values[0] == path { - log.Printf("Rule values %+v", rule.Conditions[0].PathPatternConfig.Values[0]) + for _, cond := range rule.Conditions { + if cond.PathPatternConfig != nil && target.PathPattern != nil && sameStringSlicesUnordered(cond.PathPatternConfig.Values, target.PathPattern) { + continue rules + } + if cond.HostHeaderConfig != nil && target.HostHeader != nil && sameStringSlicesUnordered(cond.HostHeaderConfig.Values, target.HostHeader) { + continue rules + } + // Only path and host header conditions are supported for now + if cond.SourceIpConfig != nil || cond.QueryStringConfig != nil || cond.HttpHeaderConfig != nil || cond.HttpRequestMethodConfig != nil { + continue rules + } ruleArn = *rule.RuleArn - break + break rules } } + if ruleArn == "" { return ErrRuleNotFound } @@ -47,8 +63,14 @@ func DeleteListenerPathRule(ctx context.Context, listenerArn, path string) error return nil } -func AddListenerStaticRule(ctx context.Context, listenerArn, path, value string, priority int32) error { +func AddListenerStaticRule(ctx context.Context, listenerArn string, ruleCond RuleCondition, value string) error { svc := elbv2.NewFromConfig(aws.LoadConfig()) + + priority, err := GetNextAvailablePriority(ctx, listenerArn) + if err != nil { + return err + } + input := &elbv2.CreateRuleInput{ Actions: []types.Action{ { @@ -62,23 +84,76 @@ func AddListenerStaticRule(ctx context.Context, listenerArn, path, value string, }, Conditions: []types.RuleCondition{ { - Field: ptr.String("path-pattern"), - PathPatternConfig: &types.PathPatternConditionConfig{ - Values: []string{path}, - }, + Field: ptr.String("path-pattern"), + PathPatternConfig: &types.PathPatternConditionConfig{Values: ruleCond.PathPattern}, + }, + { + Field: ptr.String("host-header"), + HostHeaderConfig: &types.HostHeaderConditionConfig{Values: ruleCond.HostHeader}, }, }, ListenerArn: &listenerArn, Priority: ptr.Int32(priority), } - _, err := svc.CreateRule(ctx, input) - if err != nil { + if _, err := svc.CreateRule(ctx, input); err != nil { return err } return nil } +func GetNextAvailablePriority(ctx context.Context, listenerArn string) (int32, error) { + rules, err := GetAllRules(ctx, listenerArn) + if err != nil { + return 0, err + } + + ps := make([]int, 0, len(rules)) + for _, rule := range rules { + if rule.Priority == nil { + continue + } + p, err := strconv.Atoi(*rule.Priority) + if err != nil { + continue + } + ps = append(ps, p) + } + ps = sort.IntSlice(ps) + priority := 1 + for _, p := range ps { + if priority == p { + priority++ + } else { + break + } + } + return int32(priority), nil +} + +func GetAllRules(ctx context.Context, listenerArn string) ([]types.Rule, error) { + svc := elbv2.NewFromConfig(aws.LoadConfig()) + + var rules []types.Rule + for { + searchInput := &elbv2.DescribeRulesInput{ + ListenerArn: &listenerArn, + PageSize: ptr.Int32(400), + } + + searchOuputput, err := svc.DescribeRules(ctx, searchInput) + if err != nil { + return nil, err + } + rules = append(rules, searchOuputput.Rules...) + + if searchOuputput.NextMarker == nil { + return rules, nil + } + searchInput.Marker = searchOuputput.NextMarker + } +} + func GetListener(ctx context.Context, albArn string, protocol types.ProtocolEnum, port int32) (*types.Listener, error) { svc := elbv2.NewFromConfig(aws.LoadConfig()) input := &elbv2.DescribeListenersInput{ @@ -144,3 +219,22 @@ func GetTargetGroupAlb(ctx context.Context, targetGroupArn string) (string, erro return tg.LoadBalancerArns[0], nil // Only 1 LB per tg possible according to aws docs } + +func sameStringSlicesUnordered(a, b []string) bool { + if len(a) != len(b) { + return false + } + diff := make(map[string]int) + for _, s := range a { + diff[s]++ + } + for _, s := range b { + diff[s]-- + } + for _, v := range diff { + if v != 0 { + return false + } + } + return true +} diff --git a/cmd/inspect/main.go b/cmd/inspect/main.go new file mode 100644 index 0000000..b373238 --- /dev/null +++ b/cmd/inspect/main.go @@ -0,0 +1,51 @@ +package main + +import ( + "context" + "fmt" + + "defang.io/cloudacme/aws" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" +) + +var listenerArn = "arn:aws:elasticloadbalancing:us-west-2:381492210770:listener/app/Defang-dayifu2-beta-alb/6f1d3e4bf5cbac4b/84ccc1071870455c" +var path = "/" + +func main() { + ctx := context.Background() + + svc := elbv2.NewFromConfig(aws.LoadConfig()) + searchInput := &elbv2.DescribeRulesInput{ + ListenerArn: &listenerArn, + } + rulesOutput, err := svc.DescribeRules(ctx, searchInput) + if err != nil { + panic(err) + } + + for _, rule := range rulesOutput.Rules { + fmt.Printf("RuleArn: %v\n", *rule.RuleArn) + for _, condition := range rule.Conditions { + fmt.Printf("Condition Type: %v\n", *condition.Field) + if condition.PathPatternConfig != nil { + fmt.Printf("\tPathPatternConfig: %v\n", condition.PathPatternConfig.Values) + } + if condition.HostHeaderConfig != nil { + fmt.Printf("\tHostHeaderConfig: %v\n", condition.HostHeaderConfig.Values) + } + if condition.HttpHeaderConfig != nil { + fmt.Printf("\tHttpHeaderConfig: %v\n", *condition.HttpHeaderConfig.HttpHeaderName) + } + if condition.HttpRequestMethodConfig != nil { + fmt.Printf("\tHttpRequestMethodConfig: %v\n", condition.HttpRequestMethodConfig.Values) + } + if condition.QueryStringConfig != nil { + fmt.Printf("\tQueryStringConfig: %v\n", condition.QueryStringConfig.Values) + } + if condition.SourceIpConfig != nil { + fmt.Printf("\tSourceIpConfig: %v\n", condition.SourceIpConfig.Values) + } + fmt.Printf("Values: %v\n", condition.Values) + } + } +} diff --git a/cmd/lambda/main.go b/cmd/lambda/main.go index a094c12..4d4a234 100644 --- a/cmd/lambda/main.go +++ b/cmd/lambda/main.go @@ -58,11 +58,17 @@ func HandleALBEvent(ctx context.Context, evt events.ALBTargetGroupRequest) (*eve return nil, fmt.Errorf("failed to get ALB ARN from target group %v: %w", targetGroupArn, err) } - if err := updateAcmeCertificate(ctx, albArn, evt.Headers["host"]); err != nil { + host := evt.Headers["host"] + if err := updateAcmeCertificate(ctx, albArn, host); err != nil { return nil, fmt.Errorf("failed to update certificate: %w", err) } - if err := removeHttpRule(ctx, albArn, "/"); err != nil { + cond := alb.RuleCondition{ + HostHeader: []string{host}, + PathPattern: []string{"/"}, + } + + if err := removeHttpRule(ctx, albArn, cond); err != nil { return nil, fmt.Errorf("failed to remove http rule: %w", err) } @@ -74,12 +80,12 @@ func HandleALBEvent(ctx context.Context, evt events.ALBTargetGroupRequest) (*eve }, nil } -func removeHttpRule(ctx context.Context, albArn, path string) error { +func removeHttpRule(ctx context.Context, albArn string, ruleCond alb.RuleCondition) error { listener, err := alb.GetListener(ctx, albArn, awsalb.ProtocolEnumHttp, 80) if err != nil { return fmt.Errorf("cannot get http listener: %w", err) } - if err := alb.DeleteListenerPathRule(ctx, *listener.ListenerArn, path); err != nil { + if err := alb.DeleteListenerPathRule(ctx, *listener.ListenerArn, ruleCond); err != nil { return fmt.Errorf("failed to delete listener static rule: %w", err) } return nil diff --git a/solver/albhttp01solver.go b/solver/albhttp01solver.go index c97b7a5..5e0ab9c 100644 --- a/solver/albhttp01solver.go +++ b/solver/albhttp01solver.go @@ -14,7 +14,7 @@ import ( "go.uber.org/zap" ) -const DefaultWaitTimeout = 1 * time.Minute +const DefaultWaitTimeout = 5 * time.Minute type AlbHttp01Solver struct { AlbArn string @@ -32,7 +32,12 @@ func (s AlbHttp01Solver) Present(ctx context.Context, chal acme.Challenge) error return fmt.Errorf("cannot get http listener: %w", err) } - if err := alb.AddListenerStaticRule(ctx, *listener.ListenerArn, chal.HTTP01ResourcePath(), chal.KeyAuthorization, 1); err != nil { + ruleCond := alb.RuleCondition{ + HostHeader: s.Domains, + PathPattern: []string{chal.HTTP01ResourcePath()}, + } + + if err := alb.AddListenerStaticRule(ctx, *listener.ListenerArn, ruleCond, chal.KeyAuthorization); err != nil { return fmt.Errorf("failed to add listener static rule: %v", err) } return nil @@ -46,7 +51,13 @@ func (s AlbHttp01Solver) CleanUp(ctx context.Context, chal acme.Challenge) error if err != nil { return fmt.Errorf("cannot get http listener: %w", err) } - err = alb.DeleteListenerPathRule(ctx, *listener.ListenerArn, chal.HTTP01ResourcePath()) + + ruleCond := alb.RuleCondition{ + HostHeader: s.Domains, + PathPattern: []string{chal.HTTP01ResourcePath()}, + } + + err = alb.DeleteListenerPathRule(ctx, *listener.ListenerArn, ruleCond) if errors.Is(err, alb.ErrRuleNotFound) { if s.Logger != nil { s.Logger.Info("Challenge rule not found, skipping cleanup alb rule", zap.String("path", chal.HTTP01ResourcePath()))