Skip to content

Commit

Permalink
Check diff
Browse files Browse the repository at this point in the history
  • Loading branch information
Zijun Wang committed Jun 1, 2024
1 parent 9cc4efb commit febd9cf
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 22 deletions.
78 changes: 67 additions & 11 deletions pkg/deploy/lattice/listener_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"reflect"

"github.com/aws/aws-application-networking-k8s/pkg/aws/services"
"github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog"
Expand Down Expand Up @@ -53,30 +54,46 @@ func (d *defaultListenerManager) Upsert(
}

d.log.Infof("Upsert listener %s-%s", modelListener.Spec.K8SRouteName, modelListener.Spec.K8SRouteNamespace)
latticeSvcId := modelSvc.Status.Id
latticeListenerSummary, err := d.findListenerByPort(ctx, latticeSvcId, modelListener.Spec.Port)
if err != nil {
return model.ListenerStatus{}, err
}

latticeListener, err := d.findListenerByPort(ctx, modelSvc.Status.Id, modelListener.Spec.Port)
if latticeListenerSummary == nil {
// listener not found, create new one
return d.create(ctx, latticeSvcId, modelListener, defaultAction)
}

// listener found, the only mutable fild for lattice is defaultAction, check if it needs to be updated
needToUpdateDefaultAction, err := d.needToUpdateDefaultAction(ctx, latticeSvcId, *latticeListenerSummary.Id, defaultAction)
if err != nil {
return model.ListenerStatus{}, err

}
if latticeListener != nil {
// we don't need to do vpclattice.UpdateListener(), if we can find a existing one.
// Since we can make sure the every time when calling the vpclattice.CreateListener() it must have the latest and correct defaultAction.
d.log.Debugf("Found existing listener %s, nothing to update", aws.StringValue(latticeListener.Id))
if needToUpdateDefaultAction {
return d.update(ctx, latticeSvcId, latticeListenerSummary, defaultAction)
} else {
d.log.Debugf("Found existing listener %s, defaultAction is up to date, nothing to update", aws.StringValue(latticeListenerSummary.Id))
return model.ListenerStatus{
Name: aws.StringValue(latticeListener.Name),
ListenerArn: aws.StringValue(latticeListener.Arn),
Id: aws.StringValue(latticeListener.Id),
ServiceId: modelSvc.Status.Id,
Name: aws.StringValue(latticeListenerSummary.Name),
ListenerArn: aws.StringValue(latticeListenerSummary.Arn),
Id: aws.StringValue(latticeListenerSummary.Id),
ServiceId: latticeSvcId,
}, nil
}

}

func (d *defaultListenerManager) create(ctx context.Context, latticeSvcId string, modelListener *model.Listener, defaultAction *vpclattice.RuleAction) (
model.ListenerStatus, error) {
listenerInput := vpclattice.CreateListenerInput{
ClientToken: nil,
DefaultAction: defaultAction,
Name: aws.String(k8sLatticeListenerName(modelListener)),
Port: aws.Int64(modelListener.Spec.Port),
Protocol: aws.String(modelListener.Spec.Protocol),
ServiceIdentifier: aws.String(modelSvc.Status.Id),
ServiceIdentifier: aws.String(latticeSvcId),
Tags: d.cloud.DefaultTags(),
}

Expand All @@ -91,7 +108,29 @@ func (d *defaultListenerManager) Upsert(
Name: aws.StringValue(resp.Name),
ListenerArn: aws.StringValue(resp.Arn),
Id: aws.StringValue(resp.Id),
ServiceId: modelSvc.Status.Id,
ServiceId: latticeSvcId,
}, nil
}

func (d *defaultListenerManager) update(ctx context.Context, latticeSvcId string, listener *vpclattice.ListenerSummary, defaultAction *vpclattice.RuleAction) (
model.ListenerStatus, error) {

d.log.Debugf("Updating listener %s default action", aws.StringValue(listener.Id))
_, err := d.cloud.Lattice().UpdateListenerWithContext(ctx, &vpclattice.UpdateListenerInput{
DefaultAction: defaultAction,
ListenerIdentifier: listener.Id,
ServiceIdentifier: aws.String(latticeSvcId),
})
if err != nil {
return model.ListenerStatus{}, fmt.Errorf("failed to update lattice listener %s due to %s", aws.StringValue(listener.Id), err)
}
d.log.Infof("Success UpdateListenerDefaultAction %s", aws.StringValue(listener.Id))

return model.ListenerStatus{
Name: aws.StringValue(listener.Name),
ListenerArn: aws.StringValue(listener.Arn),
Id: aws.StringValue(listener.Id),
ServiceId: latticeSvcId,
}, nil
}

Expand Down Expand Up @@ -136,6 +175,23 @@ func (d *defaultListenerManager) List(ctx context.Context, serviceID string) ([]
return sdkListeners, nil
}

func (d *defaultListenerManager) needToUpdateDefaultAction(
ctx context.Context,
latticeSvcId string,
latticeListenerId string,
listenerDefaultActionFromStack *vpclattice.RuleAction) (bool, error) {

resp, err := d.cloud.Lattice().GetListenerWithContext(ctx, &vpclattice.GetListenerInput{
ServiceIdentifier: &latticeSvcId,
ListenerIdentifier: &latticeListenerId,
})
if err != nil {
return false, err
}

return reflect.DeepEqual(resp.DefaultAction, listenerDefaultActionFromStack), nil
}

func (d *defaultListenerManager) findListenerByPort(
ctx context.Context,
latticeSvcId string,
Expand Down
4 changes: 2 additions & 2 deletions pkg/deploy/lattice/listener_synthesizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ func (l *listenerSynthesizer) getLatticeListenerDefaultAction(ctx context.Contex
}, nil
}

// For TLS_PASSTHROUGH listener, we need to fill the stackRules[0].Spec.Action.TargetGroups to the lattice listener's defaultAction tgs
// For TLS_PASSTHROUGH listener, we need to fill the stackRules[0].Spec.Action.TargetGroups to the lattice listener's defaultAction ForwardAction TargetGroups ,
// since the TLS_PASSTHROUGH listener only has the defaultAction and no extra listener rules
var stackRules []*model.Rule
_ = l.stack.ListResources(&stackRules)
// Fill the default action target groups for TLS_PASSTHROUGH listener, since TLS_PASSTHROUGH listener only has the defaultAction and no extra listener rules
if err := l.tgManager.ResolveRuleTgIds(ctx, stackRules[0], l.stack); err != nil {
return nil, fmt.Errorf("failed to resolve rule tg ids, err = %v", err)
}
Expand Down
140 changes: 140 additions & 0 deletions pkg/deploy/lattice/listener_synthesizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,143 @@ func Test_SynthesizeListenerCreateWithReconcile(t *testing.T) {
err := ls.Synthesize(ctx)
assert.Nil(t, err)
}

//func Test_SynthesizeTlsPassthroughListenerCreateWithReconcile(t *testing.T) {
// c := gomock.NewController(t)
// defer c.Finish()
// ctx := context.TODO()
// mockListenerMgr := NewMockListenerManager(c)
// mockTargetGroupManager := NewMockTargetGroupManager(c)
//
// stack := core.NewDefaultStack(core.StackID{Name: "foo", Namespace: "bar"})
//
// svc := &model.Service{
// ResourceMeta: core.NewResourceMeta(stack, "AWS:VPCServiceNetwork::Service", "stack-svc-id"),
// Status: &model.ServiceStatus{Id: "svc-id"},
// }
// assert.NoError(t, stack.AddResource(svc))
//
// l := &model.Listener{
// ResourceMeta: core.NewResourceMeta(stack, "AWS:VPCServiceNetwork::Listener", "l-id"),
// Spec: model.ListenerSpec{
// StackServiceId: "stack-svc-id",
// Port: 80,
// Protocol: "TLS_PASSTHROUGH",
// },
// }
// assert.NoError(t, stack.AddResource(l))
//
// mockListenerMgr.EXPECT().Upsert(ctx, gomock.Any(), gomock.Any(), gomock.Any()).Return(
// model.ListenerStatus{Id: "new-listener-id"}, nil)
//
// mockListenerMgr.EXPECT().List(ctx, gomock.Any()).Return([]*vpclattice.ListenerSummary{
// {
// Id: aws.String("to-delete-id"),
// Port: aws.Int64(443), // <-- makes this listener unique
// },
// }, nil)
//
// mockListenerMgr.EXPECT().Delete(ctx, gomock.Any()).DoAndReturn(
// func(ctx context.Context, ml *model.Listener) error {
// assert.Equal(t, "to-delete-id", ml.Status.Id)
// assert.Equal(t, "svc-id", ml.Status.ServiceId)
// return nil
// })
//
// ls := NewListenerSynthesizer(gwlog.FallbackLogger, mockListenerMgr, mockTargetGroupManager, stack)
// err := ls.Synthesize(ctx)
// assert.Nil(t, err)
//}

func Test_listenerSynthesizer_getLatticeListenerDefaultAction(t *testing.T) {

tests := []struct {
name string
listenerProtocol string
want *vpclattice.RuleAction
wantErr error
}{
{
name: "HTTP protocol Listener has the 404 fixed response default action",
listenerProtocol: "HTTP",
want: &vpclattice.RuleAction{
FixedResponse: &vpclattice.FixedResponseAction{
StatusCode: aws.Int64(404),
},
},
wantErr: nil,
},
{
name: "HTTPS protocol Listener has the 404 fixed response default action",
listenerProtocol: "HTTPS",
want: &vpclattice.RuleAction{
FixedResponse: &vpclattice.FixedResponseAction{
StatusCode: aws.Int64(404),
},
},
wantErr: nil,
},
{
name: "TLS_PASSTHROUGH protocol Listener has the default action forward to the StackRules[0] target group",
listenerProtocol: "TLS_PASSTHROUGH",
want: &vpclattice.RuleAction{
Forward: &vpclattice.ForwardAction{
TargetGroups: []*vpclattice.WeightedTargetGroup{
{
TargetGroupIdentifier: aws.String("tg-id-1"),
},
},
},
},
wantErr: nil,
},
}

c := gomock.NewController(t)
defer c.Finish()
mockListenerMgr := NewMockListenerManager(c)
mockTargetGroupManager := NewMockTargetGroupManager(c)

stack := core.NewDefaultStack(core.StackID{Name: "foo", Namespace: "bar"})

stackRule := &model.Rule{
ResourceMeta: core.NewResourceMeta(stack, "AWS:VPCServiceNetwork::Rule", "rule-id"),
Spec: model.RuleSpec{
Action: model.RuleAction{
TargetGroups: []*model.RuleTargetGroup{
{

SvcImportTG: &model.SvcImportTargetGroup{
K8SClusterName: "cluster-name",
K8SServiceName: "svc-name",
K8SServiceNamespace: "ns",
VpcId: "vpc-id",
},
},
{
StackTargetGroupId: "stack-tg-id",
},
{
StackTargetGroupId: model.InvalidBackendRefTgId,
},
},
},
},
}
assert.NoError(t, stack.AddResource(stackRule))

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &listenerSynthesizer{
log: gwlog.FallbackLogger,
listenerMgr: mockListenerMgr,
tgManager: mockTargetGroupManager,
stack: stack,
}
got, err := l.getLatticeListenerDefaultAction(context.TODO(), tt.listenerProtocol)

assert.Equalf(t, tt.want, got, "getLatticeListenerDefaultAction() listenerProtocol: %v", tt.listenerProtocol)
assert.Equal(t, tt.wantErr, err)
})
}
}
9 changes: 5 additions & 4 deletions pkg/deploy/lattice/targets_synthesizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@ import (
"context"
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/vpclattice"
corev1 "k8s.io/api/core/v1"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/aws/aws-application-networking-k8s/pkg/model/core"
model "github.com/aws/aws-application-networking-k8s/pkg/model/lattice"
"github.com/aws/aws-application-networking-k8s/pkg/utils"
"github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog"
"github.com/aws/aws-application-networking-k8s/pkg/webhook"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/vpclattice"
corev1 "k8s.io/api/core/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
)

const (
Expand Down
1 change: 0 additions & 1 deletion pkg/gateway/model_build_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ func (t *latticeServiceModelBuildTask) extractListenerInfo(
protocol := section.Protocol
if section.TLS != nil && section.TLS.Mode != nil && *section.TLS.Mode == gwv1.TLSModePassthrough {
t.log.Debugf("Found TLS passthrough section %v", section.TLS)
// if the k8s gw.listener section has tls.mode: Passthrough, the lattice service listener protocol should be TLS_PASSTHROUGH
protocol = vpclattice.ListenerProtocolTlsPassthrough
}
return int64(listenerPort), string(protocol), nil
Expand Down
3 changes: 1 addition & 2 deletions pkg/gateway/model_build_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ const (
func (t *latticeServiceModelBuildTask) buildRules(ctx context.Context, stackListenerId string) error {
// note we only build rules for non-deleted routes
t.log.Debugf("Processing %d rules", len(t.route.Spec().Rules()))

if t.route.GroupKind().Kind == "TLSRoute" && len(t.route.Spec().Rules()) > 1 {
if t.route.GroupKind().Kind == core.TlsRouteKind && len(t.route.Spec().Rules()) > 1 {
return errors.New("TLSRoute only supports 1 rule")
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/model/core/tlsroute.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

const (
TlsRouteType RouteType = "tls"
TlsRouteKind string = "TLSRoute"
)

type TLSRoute struct {
Expand Down Expand Up @@ -82,7 +83,7 @@ func (r *TLSRoute) Inner() *gwv1alpha2.TLSRoute {
func (r *TLSRoute) GroupKind() metav1.GroupKind {
return metav1.GroupKind{
Group: gwv1beta1.GroupName,
Kind: "TLSRoute",
Kind: TlsRouteKind,
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/model/core/tlsroute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func TestTLSRouteBackendRef_Equals(t *testing.T) {
expectEqual: true,
},
{
description: "Instances populatd with the same values are equal",
description: "es are equal",
backendRef1: &TLSBackendRef{
r: gwv1alpha2.BackendRef{
Weight: weight1,
Expand Down
4 changes: 4 additions & 0 deletions pkg/model/lattice/targetgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ func (t *TargetGroupSpec) Validate() error {
t.Protocol, t.VpcId, t.K8SClusterName, t.IpAddressType,
string(t.K8SSourceType)}

if t.Protocol != "TCP" {
requiredFields = append(requiredFields, t.ProtocolVersion)
}

for _, s := range requiredFields {
if s == "" {
return errors.New("one or more required fields are missing")
Expand Down

0 comments on commit febd9cf

Please sign in to comment.