diff --git a/pkg/controllers/controllers.go b/pkg/controllers/controllers.go index 350a1cc7b484..2514ebee5aed 100644 --- a/pkg/controllers/controllers.go +++ b/pkg/controllers/controllers.go @@ -92,7 +92,7 @@ func NewControllers( ssminvalidation.NewController(ssmCache, amiProvider), status.NewController[*v1.EC2NodeClass](kubeClient, mgr.GetEventRecorderFor("karpenter"), status.EmitDeprecatedMetrics), opevents.NewController[*corev1.Node](kubeClient, clk), - controllersversion.NewController(versionProvider), + controllersversion.NewController(versionProvider, versionProvider.UpdateVersionWithValidation), } if options.FromContext(ctx).InterruptionQueue != "" { sqsapi := servicesqs.NewFromConfig(cfg) diff --git a/pkg/controllers/providers/version/controller.go b/pkg/controllers/providers/version/controller.go index 711d4c759c24..57fcaab31245 100644 --- a/pkg/controllers/providers/version/controller.go +++ b/pkg/controllers/providers/version/controller.go @@ -28,20 +28,24 @@ import ( "github.com/aws/karpenter-provider-aws/pkg/providers/version" ) +type UpdateVersion func(context.Context) error + type Controller struct { versionProvider *version.DefaultProvider + updateVersion UpdateVersion } -func NewController(versionProvider *version.DefaultProvider) *Controller { +func NewController(versionProvider *version.DefaultProvider, updateVersion UpdateVersion) *Controller { return &Controller{ versionProvider: versionProvider, + updateVersion: updateVersion, } } func (c *Controller) Reconcile(ctx context.Context) (reconcile.Result, error) { ctx = injection.WithControllerName(ctx, "providers.version") - if err := c.versionProvider.UpdateVersion(ctx); err != nil { + if err := c.updateVersion(ctx); err != nil { return reconcile.Result{}, fmt.Errorf("updating version, %w", err) } return reconcile.Result{RequeueAfter: 5 * time.Minute}, nil diff --git a/pkg/controllers/providers/version/suite_test.go b/pkg/controllers/providers/version/suite_test.go index 59c0ddc62eae..005c45d6d8e4 100644 --- a/pkg/controllers/providers/version/suite_test.go +++ b/pkg/controllers/providers/version/suite_test.go @@ -56,7 +56,7 @@ var _ = BeforeSuite(func() { ctx = options.ToContext(ctx, test.Options()) ctx, stop = context.WithCancel(ctx) awsEnv = test.NewEnvironment(ctx, env) - controller = controllersversion.NewController(awsEnv.VersionProvider) + controller = controllersversion.NewController(awsEnv.VersionProvider, awsEnv.VersionProvider.UpdateVersionWithValidation) }) var _ = AfterSuite(func() { diff --git a/pkg/providers/version/suite_test.go b/pkg/providers/version/suite_test.go index 0536e706ece5..29910f8e04e0 100644 --- a/pkg/providers/version/suite_test.go +++ b/pkg/providers/version/suite_test.go @@ -57,7 +57,7 @@ var _ = BeforeSuite(func() { ctx, stop = context.WithCancel(ctx) awsEnv = test.NewEnvironment(ctx, env) testEnv = &environmentaws.Environment{Environment: &common.Environment{KubeClient: env.KubernetesInterface}} - versionController = controllersversion.NewController(awsEnv.VersionProvider) + versionController = controllersversion.NewController(awsEnv.VersionProvider, awsEnv.VersionProvider.UpdateVersionWithValidation) }) var _ = AfterSuite(func() { diff --git a/pkg/providers/version/version.go b/pkg/providers/version/version.go index 9dbeaae0520d..eb0101e43446 100644 --- a/pkg/providers/version/version.go +++ b/pkg/providers/version/version.go @@ -69,7 +69,7 @@ func (p *DefaultProvider) Get(ctx context.Context) string { } func (p *DefaultProvider) UpdateVersion(ctx context.Context) error { - var version, versionSource string + var version string var err error if options.FromContext(ctx).EKSControlPlane { @@ -84,7 +84,15 @@ func (p *DefaultProvider) UpdateVersion(ctx context.Context) error { } } p.version.Store(&version) - if p.cm.HasChanged("kubernetes-version", version) || p.cm.HasChanged("version-source", versionSource) { + return nil +} +func (p *DefaultProvider) UpdateVersionWithValidation(ctx context.Context) error { + err := p.UpdateVersion(ctx) + if err != nil { + return err + } + version := p.Get(ctx) + if p.cm.HasChanged("kubernetes-version", version) { log.FromContext(ctx).WithValues("version", version).V(1).Info("discovered kubernetes version") if err := validateK8sVersion(version); err != nil { return fmt.Errorf("validating kubernetes version, %w", err)