diff --git a/pkg/filewatcher/filewatcher.go b/pkg/filewatcher/filewatcher.go index 52a8335ee0..2fc5a00efa 100644 --- a/pkg/filewatcher/filewatcher.go +++ b/pkg/filewatcher/filewatcher.go @@ -25,6 +25,11 @@ import ( "k8s.io/klog/v2" ) +// exit is a separate function to handle program termination +var exit = func(code int) { + os.Exit(code) +} + var watchCertificateFileOnce sync.Once // WatchFileForChanges watches the file, fileToWatch, for changes. If the file contents have changed, the pod this @@ -64,7 +69,7 @@ func checkForFileChanges(path string) error { case event, ok := <-watcher.Events: if ok && (event.Has(fsnotify.Write) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove)) { klog.V(2).Infof("file, %s, was modified, exiting...", event.Name) - os.Exit(0) + exit(0) } case err, ok := <-watcher.Errors: if ok { diff --git a/pkg/filewatcher/filewatcher_test.go b/pkg/filewatcher/filewatcher_test.go new file mode 100644 index 0000000000..303e7dad57 --- /dev/null +++ b/pkg/filewatcher/filewatcher_test.go @@ -0,0 +1,69 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filewatcher + +import ( + "os" + "testing" + "time" +) + +func TestWatchFileForChanges(t *testing.T) { + // Capture stdout + old := os.Stdout + _, w, _ := os.Pipe() + os.Stdout = w + + // Replace exit function with mock function + var exitCode int + exit = func(code int) { + exitCode = code + } + + // Create a temporary file to watch + tmpfile, err := os.CreateTemp("", "testfile") + if err != nil { + t.Fatal(err) + } + + // Test the WatchFileForChanges function + err = WatchFileForChanges(tmpfile.Name()) + if err != nil { + t.Errorf("Failed to watch file: %v", err) + } + + // Simulate a file change + err = os.WriteFile(tmpfile.Name(), []byte("new content"), 0644) + if err != nil { + t.Fatal(err) + } + + if exitCode != 0 { + t.Errorf("Expected exit code 0, but got %d", exitCode) + } + + os.Remove(tmpfile.Name()) + + time.Sleep(1 * time.Second) + + // Restore stdout + w.Close() + os.Stdout = old + exit = func(code int) { + os.Exit(code) + } +}