From febd9cfb6545727a18b22c281857388906930175 Mon Sep 17 00:00:00 2001 From: Zijun Wang Date: Sat, 1 Jun 2024 12:58:22 -0700 Subject: [PATCH] Check diff --- pkg/deploy/lattice/listener_manager.go | 78 ++++++++-- pkg/deploy/lattice/listener_synthesizer.go | 4 +- .../lattice/listener_synthesizer_test.go | 140 ++++++++++++++++++ pkg/deploy/lattice/targets_synthesizer.go | 9 +- pkg/gateway/model_build_listener.go | 1 - pkg/gateway/model_build_rule.go | 3 +- pkg/model/core/tlsroute.go | 3 +- pkg/model/core/tlsroute_test.go | 2 +- pkg/model/lattice/targetgroup.go | 4 + 9 files changed, 222 insertions(+), 22 deletions(-) diff --git a/pkg/deploy/lattice/listener_manager.go b/pkg/deploy/lattice/listener_manager.go index d81ffc50..29536459 100644 --- a/pkg/deploy/lattice/listener_manager.go +++ b/pkg/deploy/lattice/listener_manager.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "reflect" "github.com/aws/aws-application-networking-k8s/pkg/aws/services" "github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog" @@ -53,30 +54,46 @@ func (d *defaultListenerManager) Upsert( } d.log.Infof("Upsert listener %s-%s", modelListener.Spec.K8SRouteName, modelListener.Spec.K8SRouteNamespace) + latticeSvcId := modelSvc.Status.Id + latticeListenerSummary, err := d.findListenerByPort(ctx, latticeSvcId, modelListener.Spec.Port) + if err != nil { + return model.ListenerStatus{}, err + } - latticeListener, err := d.findListenerByPort(ctx, modelSvc.Status.Id, modelListener.Spec.Port) + if latticeListenerSummary == nil { + // listener not found, create new one + return d.create(ctx, latticeSvcId, modelListener, defaultAction) + } + + // listener found, the only mutable fild for lattice is defaultAction, check if it needs to be updated + needToUpdateDefaultAction, err := d.needToUpdateDefaultAction(ctx, latticeSvcId, *latticeListenerSummary.Id, defaultAction) if err != nil { return model.ListenerStatus{}, err + } - if latticeListener != nil { - // we don't need to do vpclattice.UpdateListener(), if we can find a existing one. - // Since we can make sure the every time when calling the vpclattice.CreateListener() it must have the latest and correct defaultAction. - d.log.Debugf("Found existing listener %s, nothing to update", aws.StringValue(latticeListener.Id)) + if needToUpdateDefaultAction { + return d.update(ctx, latticeSvcId, latticeListenerSummary, defaultAction) + } else { + d.log.Debugf("Found existing listener %s, defaultAction is up to date, nothing to update", aws.StringValue(latticeListenerSummary.Id)) return model.ListenerStatus{ - Name: aws.StringValue(latticeListener.Name), - ListenerArn: aws.StringValue(latticeListener.Arn), - Id: aws.StringValue(latticeListener.Id), - ServiceId: modelSvc.Status.Id, + Name: aws.StringValue(latticeListenerSummary.Name), + ListenerArn: aws.StringValue(latticeListenerSummary.Arn), + Id: aws.StringValue(latticeListenerSummary.Id), + ServiceId: latticeSvcId, }, nil } +} + +func (d *defaultListenerManager) create(ctx context.Context, latticeSvcId string, modelListener *model.Listener, defaultAction *vpclattice.RuleAction) ( + model.ListenerStatus, error) { listenerInput := vpclattice.CreateListenerInput{ ClientToken: nil, DefaultAction: defaultAction, Name: aws.String(k8sLatticeListenerName(modelListener)), Port: aws.Int64(modelListener.Spec.Port), Protocol: aws.String(modelListener.Spec.Protocol), - ServiceIdentifier: aws.String(modelSvc.Status.Id), + ServiceIdentifier: aws.String(latticeSvcId), Tags: d.cloud.DefaultTags(), } @@ -91,7 +108,29 @@ func (d *defaultListenerManager) Upsert( Name: aws.StringValue(resp.Name), ListenerArn: aws.StringValue(resp.Arn), Id: aws.StringValue(resp.Id), - ServiceId: modelSvc.Status.Id, + ServiceId: latticeSvcId, + }, nil +} + +func (d *defaultListenerManager) update(ctx context.Context, latticeSvcId string, listener *vpclattice.ListenerSummary, defaultAction *vpclattice.RuleAction) ( + model.ListenerStatus, error) { + + d.log.Debugf("Updating listener %s default action", aws.StringValue(listener.Id)) + _, err := d.cloud.Lattice().UpdateListenerWithContext(ctx, &vpclattice.UpdateListenerInput{ + DefaultAction: defaultAction, + ListenerIdentifier: listener.Id, + ServiceIdentifier: aws.String(latticeSvcId), + }) + if err != nil { + return model.ListenerStatus{}, fmt.Errorf("failed to update lattice listener %s due to %s", aws.StringValue(listener.Id), err) + } + d.log.Infof("Success UpdateListenerDefaultAction %s", aws.StringValue(listener.Id)) + + return model.ListenerStatus{ + Name: aws.StringValue(listener.Name), + ListenerArn: aws.StringValue(listener.Arn), + Id: aws.StringValue(listener.Id), + ServiceId: latticeSvcId, }, nil } @@ -136,6 +175,23 @@ func (d *defaultListenerManager) List(ctx context.Context, serviceID string) ([] return sdkListeners, nil } +func (d *defaultListenerManager) needToUpdateDefaultAction( + ctx context.Context, + latticeSvcId string, + latticeListenerId string, + listenerDefaultActionFromStack *vpclattice.RuleAction) (bool, error) { + + resp, err := d.cloud.Lattice().GetListenerWithContext(ctx, &vpclattice.GetListenerInput{ + ServiceIdentifier: &latticeSvcId, + ListenerIdentifier: &latticeListenerId, + }) + if err != nil { + return false, err + } + + return reflect.DeepEqual(resp.DefaultAction, listenerDefaultActionFromStack), nil +} + func (d *defaultListenerManager) findListenerByPort( ctx context.Context, latticeSvcId string, diff --git a/pkg/deploy/lattice/listener_synthesizer.go b/pkg/deploy/lattice/listener_synthesizer.go index 3190b73f..e6528eb5 100644 --- a/pkg/deploy/lattice/listener_synthesizer.go +++ b/pkg/deploy/lattice/listener_synthesizer.go @@ -100,10 +100,10 @@ func (l *listenerSynthesizer) getLatticeListenerDefaultAction(ctx context.Contex }, nil } - // For TLS_PASSTHROUGH listener, we need to fill the stackRules[0].Spec.Action.TargetGroups to the lattice listener's defaultAction tgs + // For TLS_PASSTHROUGH listener, we need to fill the stackRules[0].Spec.Action.TargetGroups to the lattice listener's defaultAction ForwardAction TargetGroups , + // since the TLS_PASSTHROUGH listener only has the defaultAction and no extra listener rules var stackRules []*model.Rule _ = l.stack.ListResources(&stackRules) - // Fill the default action target groups for TLS_PASSTHROUGH listener, since TLS_PASSTHROUGH listener only has the defaultAction and no extra listener rules if err := l.tgManager.ResolveRuleTgIds(ctx, stackRules[0], l.stack); err != nil { return nil, fmt.Errorf("failed to resolve rule tg ids, err = %v", err) } diff --git a/pkg/deploy/lattice/listener_synthesizer_test.go b/pkg/deploy/lattice/listener_synthesizer_test.go index 5d1d4ea1..7362db83 100644 --- a/pkg/deploy/lattice/listener_synthesizer_test.go +++ b/pkg/deploy/lattice/listener_synthesizer_test.go @@ -92,3 +92,143 @@ func Test_SynthesizeListenerCreateWithReconcile(t *testing.T) { err := ls.Synthesize(ctx) assert.Nil(t, err) } + +//func Test_SynthesizeTlsPassthroughListenerCreateWithReconcile(t *testing.T) { +// c := gomock.NewController(t) +// defer c.Finish() +// ctx := context.TODO() +// mockListenerMgr := NewMockListenerManager(c) +// mockTargetGroupManager := NewMockTargetGroupManager(c) +// +// stack := core.NewDefaultStack(core.StackID{Name: "foo", Namespace: "bar"}) +// +// svc := &model.Service{ +// ResourceMeta: core.NewResourceMeta(stack, "AWS:VPCServiceNetwork::Service", "stack-svc-id"), +// Status: &model.ServiceStatus{Id: "svc-id"}, +// } +// assert.NoError(t, stack.AddResource(svc)) +// +// l := &model.Listener{ +// ResourceMeta: core.NewResourceMeta(stack, "AWS:VPCServiceNetwork::Listener", "l-id"), +// Spec: model.ListenerSpec{ +// StackServiceId: "stack-svc-id", +// Port: 80, +// Protocol: "TLS_PASSTHROUGH", +// }, +// } +// assert.NoError(t, stack.AddResource(l)) +// +// mockListenerMgr.EXPECT().Upsert(ctx, gomock.Any(), gomock.Any(), gomock.Any()).Return( +// model.ListenerStatus{Id: "new-listener-id"}, nil) +// +// mockListenerMgr.EXPECT().List(ctx, gomock.Any()).Return([]*vpclattice.ListenerSummary{ +// { +// Id: aws.String("to-delete-id"), +// Port: aws.Int64(443), // <-- makes this listener unique +// }, +// }, nil) +// +// mockListenerMgr.EXPECT().Delete(ctx, gomock.Any()).DoAndReturn( +// func(ctx context.Context, ml *model.Listener) error { +// assert.Equal(t, "to-delete-id", ml.Status.Id) +// assert.Equal(t, "svc-id", ml.Status.ServiceId) +// return nil +// }) +// +// ls := NewListenerSynthesizer(gwlog.FallbackLogger, mockListenerMgr, mockTargetGroupManager, stack) +// err := ls.Synthesize(ctx) +// assert.Nil(t, err) +//} + +func Test_listenerSynthesizer_getLatticeListenerDefaultAction(t *testing.T) { + + tests := []struct { + name string + listenerProtocol string + want *vpclattice.RuleAction + wantErr error + }{ + { + name: "HTTP protocol Listener has the 404 fixed response default action", + listenerProtocol: "HTTP", + want: &vpclattice.RuleAction{ + FixedResponse: &vpclattice.FixedResponseAction{ + StatusCode: aws.Int64(404), + }, + }, + wantErr: nil, + }, + { + name: "HTTPS protocol Listener has the 404 fixed response default action", + listenerProtocol: "HTTPS", + want: &vpclattice.RuleAction{ + FixedResponse: &vpclattice.FixedResponseAction{ + StatusCode: aws.Int64(404), + }, + }, + wantErr: nil, + }, + { + name: "TLS_PASSTHROUGH protocol Listener has the default action forward to the StackRules[0] target group", + listenerProtocol: "TLS_PASSTHROUGH", + want: &vpclattice.RuleAction{ + Forward: &vpclattice.ForwardAction{ + TargetGroups: []*vpclattice.WeightedTargetGroup{ + { + TargetGroupIdentifier: aws.String("tg-id-1"), + }, + }, + }, + }, + wantErr: nil, + }, + } + + c := gomock.NewController(t) + defer c.Finish() + mockListenerMgr := NewMockListenerManager(c) + mockTargetGroupManager := NewMockTargetGroupManager(c) + + stack := core.NewDefaultStack(core.StackID{Name: "foo", Namespace: "bar"}) + + stackRule := &model.Rule{ + ResourceMeta: core.NewResourceMeta(stack, "AWS:VPCServiceNetwork::Rule", "rule-id"), + Spec: model.RuleSpec{ + Action: model.RuleAction{ + TargetGroups: []*model.RuleTargetGroup{ + { + + SvcImportTG: &model.SvcImportTargetGroup{ + K8SClusterName: "cluster-name", + K8SServiceName: "svc-name", + K8SServiceNamespace: "ns", + VpcId: "vpc-id", + }, + }, + { + StackTargetGroupId: "stack-tg-id", + }, + { + StackTargetGroupId: model.InvalidBackendRefTgId, + }, + }, + }, + }, + } + assert.NoError(t, stack.AddResource(stackRule)) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &listenerSynthesizer{ + log: gwlog.FallbackLogger, + listenerMgr: mockListenerMgr, + tgManager: mockTargetGroupManager, + stack: stack, + } + got, err := l.getLatticeListenerDefaultAction(context.TODO(), tt.listenerProtocol) + + assert.Equalf(t, tt.want, got, "getLatticeListenerDefaultAction() listenerProtocol: %v", tt.listenerProtocol) + assert.Equal(t, tt.wantErr, err) + }) + } +} diff --git a/pkg/deploy/lattice/targets_synthesizer.go b/pkg/deploy/lattice/targets_synthesizer.go index a686578a..ede71ec9 100644 --- a/pkg/deploy/lattice/targets_synthesizer.go +++ b/pkg/deploy/lattice/targets_synthesizer.go @@ -4,15 +4,16 @@ import ( "context" "fmt" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/vpclattice" + corev1 "k8s.io/api/core/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/aws/aws-application-networking-k8s/pkg/model/core" model "github.com/aws/aws-application-networking-k8s/pkg/model/lattice" "github.com/aws/aws-application-networking-k8s/pkg/utils" "github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog" "github.com/aws/aws-application-networking-k8s/pkg/webhook" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/vpclattice" - corev1 "k8s.io/api/core/v1" - "sigs.k8s.io/controller-runtime/pkg/client" ) const ( diff --git a/pkg/gateway/model_build_listener.go b/pkg/gateway/model_build_listener.go index 9ba48ec8..95a1a1fa 100644 --- a/pkg/gateway/model_build_listener.go +++ b/pkg/gateway/model_build_listener.go @@ -46,7 +46,6 @@ func (t *latticeServiceModelBuildTask) extractListenerInfo( protocol := section.Protocol if section.TLS != nil && section.TLS.Mode != nil && *section.TLS.Mode == gwv1.TLSModePassthrough { t.log.Debugf("Found TLS passthrough section %v", section.TLS) - // if the k8s gw.listener section has tls.mode: Passthrough, the lattice service listener protocol should be TLS_PASSTHROUGH protocol = vpclattice.ListenerProtocolTlsPassthrough } return int64(listenerPort), string(protocol), nil diff --git a/pkg/gateway/model_build_rule.go b/pkg/gateway/model_build_rule.go index c82db00e..572a71c6 100644 --- a/pkg/gateway/model_build_rule.go +++ b/pkg/gateway/model_build_rule.go @@ -33,8 +33,7 @@ const ( func (t *latticeServiceModelBuildTask) buildRules(ctx context.Context, stackListenerId string) error { // note we only build rules for non-deleted routes t.log.Debugf("Processing %d rules", len(t.route.Spec().Rules())) - - if t.route.GroupKind().Kind == "TLSRoute" && len(t.route.Spec().Rules()) > 1 { + if t.route.GroupKind().Kind == core.TlsRouteKind && len(t.route.Spec().Rules()) > 1 { return errors.New("TLSRoute only supports 1 rule") } diff --git a/pkg/model/core/tlsroute.go b/pkg/model/core/tlsroute.go index 4734956b..35faa16b 100644 --- a/pkg/model/core/tlsroute.go +++ b/pkg/model/core/tlsroute.go @@ -15,6 +15,7 @@ import ( const ( TlsRouteType RouteType = "tls" + TlsRouteKind string = "TLSRoute" ) type TLSRoute struct { @@ -82,7 +83,7 @@ func (r *TLSRoute) Inner() *gwv1alpha2.TLSRoute { func (r *TLSRoute) GroupKind() metav1.GroupKind { return metav1.GroupKind{ Group: gwv1beta1.GroupName, - Kind: "TLSRoute", + Kind: TlsRouteKind, } } diff --git a/pkg/model/core/tlsroute_test.go b/pkg/model/core/tlsroute_test.go index 0b77f589..f3e6055d 100644 --- a/pkg/model/core/tlsroute_test.go +++ b/pkg/model/core/tlsroute_test.go @@ -132,7 +132,7 @@ func TestTLSRouteBackendRef_Equals(t *testing.T) { expectEqual: true, }, { - description: "Instances populatd with the same values are equal", + description: "es are equal", backendRef1: &TLSBackendRef{ r: gwv1alpha2.BackendRef{ Weight: weight1, diff --git a/pkg/model/lattice/targetgroup.go b/pkg/model/lattice/targetgroup.go index ce9bdf3f..9dcfa161 100644 --- a/pkg/model/lattice/targetgroup.go +++ b/pkg/model/lattice/targetgroup.go @@ -171,6 +171,10 @@ func (t *TargetGroupSpec) Validate() error { t.Protocol, t.VpcId, t.K8SClusterName, t.IpAddressType, string(t.K8SSourceType)} + if t.Protocol != "TCP" { + requiredFields = append(requiredFields, t.ProtocolVersion) + } + for _, s := range requiredFields { if s == "" { return errors.New("one or more required fields are missing")