From 2bc0da67003cbeb1cbdca8d252175a7f8abce727 Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 4 Jul 2019 16:43:07 +0200 Subject: [PATCH] Fix clean shutdown with portmaster-control Also, only download portmaster-core on initial run of portmaster-control. Fixes #7 --- main.go | 2 +- pmctl/get.go | 63 ++++++++++++++++----------- pmctl/main.go | 23 +++++++++- pmctl/run.go | 104 ++++++++++++++++++++++++--------------------- pmctl/upgrade.go | 10 ++++- updates/updater.go | 78 +++++++++++++++++++++------------- 6 files changed, 174 insertions(+), 106 deletions(-) diff --git a/main.go b/main.go index aed019c75..1600e3635 100644 --- a/main.go +++ b/main.go @@ -31,7 +31,7 @@ func init() { func main() { // Set Info - info.Set("Portmaster", "0.3.0", "AGPLv3", true) + info.Set("Portmaster", "0.3.1", "AGPLv3", true) // Start err := modules.Start() diff --git a/pmctl/get.go b/pmctl/get.go index 6457ddba8..248e8f4f8 100644 --- a/pmctl/get.go +++ b/pmctl/get.go @@ -1,17 +1,18 @@ package main import ( + "errors" "fmt" - "os" + "time" - "github.com/safing/portbase/utils" "github.com/safing/portmaster/updates" ) -func getFile(identifier string) (*updates.File, error) { +func getFile(opts *Options) (*updates.File, error) { // get newest local file updates.LoadLatest() - file, err := updates.GetPlatformFile(identifier) + + file, err := updates.GetLocalPlatformFile(opts.Identifier) if err == nil { return file, nil } @@ -19,28 +20,42 @@ func getFile(identifier string) (*updates.File, error) { return nil, err } - fmt.Printf("%s downloading %s...\n", logPrefix, identifier) - - // if no matching file exists, load index - err = updates.LoadIndexes() - if err != nil { - if os.IsNotExist(err) { - // create dirs - err = utils.EnsureDirectory(updateStoragePath, 0755) - if err != nil { - return nil, err - } - - // download indexes - err = updates.CheckForUpdates() - if err != nil { - return nil, err - } - } else { + // download + if opts.AllowDownload { + fmt.Printf("%s downloading %s...\n", logPrefix, opts.Identifier) + + // download indexes + err = updates.UpdateIndexes() + if err != nil { return nil, err } + + // download file + file, err := updates.GetPlatformFile(opts.Identifier) + if err != nil { + return nil, err + } + return file, nil } - // get file - return updates.GetPlatformFile(identifier) + // wait for 30 seconds + fmt.Printf("%s waiting for download of %s (by Portmaster Core) to complete...\n", logPrefix, opts.Identifier) + + // try every 0.5 secs + for tries := 0; tries < 60; tries++ { + time.Sleep(500 * time.Millisecond) + + // reload local files + updates.LoadLatest() + + // get file + file, err := updates.GetLocalPlatformFile(opts.Identifier) + if err == nil { + return file, nil + } + if err != updates.ErrNotFound { + return nil, err + } + } + return nil, errors.New("please try again later or check the Portmaster logs") } diff --git a/pmctl/main.go b/pmctl/main.go index 320d2ccfe..c6f661095 100644 --- a/pmctl/main.go +++ b/pmctl/main.go @@ -5,7 +5,10 @@ import ( "flag" "fmt" "os" + "os/user" "path/filepath" + "runtime" + "strings" "github.com/safing/portbase/info" "github.com/safing/portbase/log" @@ -55,7 +58,7 @@ func main() { // }() // set meta info - info.Set("Portmaster Control", "0.2.0", "AGPLv3", true) + info.Set("Portmaster Control", "0.2.1", "AGPLv3", true) // check if meta info is ok err := info.CheckVersion() @@ -86,7 +89,23 @@ func initPmCtl(cmd *cobra.Command, args []string) error { return errors.New("please supply the database directory using the --db flag") } - err := removeOldBin() + // check if we are root/admin for self upgrade + userInfo, err := user.Current() + if err != nil { + return nil + } + switch runtime.GOOS { + case "linux": + if userInfo.Username != "root" { + return nil + } + case "windows": + if !strings.HasSuffix(userInfo.Username, "SYSTEM") { // is this correct? + return nil + } + } + + err = removeOldBin() if err != nil { fmt.Printf("%s warning: failed to remove old upgrade: %s\n", logPrefix, err) } diff --git a/pmctl/run.go b/pmctl/run.go index 978f47b1a..054d13ef4 100644 --- a/pmctl/run.go +++ b/pmctl/run.go @@ -5,12 +5,20 @@ import ( "io" "os" "os/exec" + "os/signal" "runtime" "strings" + "syscall" "github.com/spf13/cobra" ) +// Options for starting component +type Options struct { + Identifier string + AllowDownload bool +} + func init() { rootCmd.AddCommand(runCmd) runCmd.AddCommand(runCore) @@ -27,7 +35,10 @@ var runCore = &cobra.Command{ Use: "core", Short: "Run the Portmaster Core", RunE: func(cmd *cobra.Command, args []string) error { - return run("core/portmaster-core", cmd, false) + return run(cmd, &Options{ + Identifier: "core/portmaster-core", + AllowDownload: true, + }) }, FParseErrWhitelist: cobra.FParseErrWhitelist{ // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags @@ -39,7 +50,10 @@ var runApp = &cobra.Command{ Use: "app", Short: "Run the Portmaster App", RunE: func(cmd *cobra.Command, args []string) error { - return run("app/portmaster-app", cmd, true) + return run(cmd, &Options{ + Identifier: "app/portmaster-app", + AllowDownload: false, + }) }, FParseErrWhitelist: cobra.FParseErrWhitelist{ // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags @@ -51,7 +65,10 @@ var runNotifier = &cobra.Command{ Use: "notifier", Short: "Run the Portmaster Notifier", RunE: func(cmd *cobra.Command, args []string) error { - return run("notifier/portmaster-notifier", cmd, true) + return run(cmd, &Options{ + Identifier: "notifier/portmaster-notifier", + AllowDownload: false, + }) }, FParseErrWhitelist: cobra.FParseErrWhitelist{ // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags @@ -59,68 +76,42 @@ var runNotifier = &cobra.Command{ }, } -func run(identifier string, cmd *cobra.Command, filterDatabaseFlag bool) error { +func run(cmd *cobra.Command, opts *Options) error { // get original arguments - if len(os.Args) <= 3 { - return cmd.Help() - } var args []string - - // filter out database flag - if filterDatabaseFlag { - skip := false - for _, arg := range os.Args[3:] { - if skip { - skip = false - continue - } - - if arg == "--db" { - // flag is seperated, skip two arguments - skip = true - continue - } - - if strings.HasPrefix(arg, "--db=") { - // flag is one string, skip one argument - continue - } - - args = append(args, arg) - } - } else { - args = os.Args[3:] + if len(os.Args) < 4 { + return cmd.Help() } + args = os.Args[3:] // adapt identifier if windows() { - identifier += ".exe" + opts.Identifier += ".exe" } // run for { - file, err := getFile(identifier) + file, err := getFile(opts) if err != nil { - return fmt.Errorf("%s could not get component: %s", logPrefix, err) + return fmt.Errorf("could not get component: %s", err) } // check permission if !windows() { info, err := os.Stat(file.Path()) if err != nil { - return fmt.Errorf("%s failed to get file info on %s: %s", logPrefix, file.Path(), err) + return fmt.Errorf("failed to get file info on %s: %s", file.Path(), err) } if info.Mode() != 0755 { err := os.Chmod(file.Path(), 0755) if err != nil { - return fmt.Errorf("%s failed to set exec permissions on %s: %s", logPrefix, file.Path(), err) + return fmt.Errorf("failed to set exec permissions on %s: %s", file.Path(), err) } } } fmt.Printf("%s starting %s %s\n", logPrefix, file.Path(), strings.Join(args, " ")) - // os.Exit(0) // create command exc := exec.Command(file.Path(), args...) @@ -128,17 +119,17 @@ func run(identifier string, cmd *cobra.Command, filterDatabaseFlag bool) error { // consume stdout/stderr stdout, err := exc.StdoutPipe() if err != nil { - return fmt.Errorf("%s failed to connect stdout: %s", logPrefix, err) + return fmt.Errorf("failed to connect stdout: %s", err) } stderr, err := exc.StderrPipe() if err != nil { - return fmt.Errorf("%s failed to connect stderr: %s", logPrefix, err) + return fmt.Errorf("failed to connect stderr: %s", err) } // start err = exc.Start() if err != nil { - return fmt.Errorf("%s failed to start %s: %s", logPrefix, identifier, err) + return fmt.Errorf("failed to start %s: %s", opts.Identifier, err) } // start output writers @@ -149,6 +140,24 @@ func run(identifier string, cmd *cobra.Command, filterDatabaseFlag bool) error { io.Copy(os.Stderr, stderr) }() + // catch interrupt for clean shutdown + signalCh := make(chan os.Signal) + signal.Notify( + signalCh, + os.Interrupt, + os.Kill, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT, + ) + go func() { + for { + sig := <-signalCh + fmt.Printf("%s got %s signal (ignoring), waiting for %s to exit...\n", logPrefix, sig, opts.Identifier) + } + }() + // wait for completion err = exc.Wait() if err != nil { @@ -157,31 +166,30 @@ func run(identifier string, cmd *cobra.Command, filterDatabaseFlag bool) error { switch exErr.ProcessState.ExitCode() { case 0: // clean exit - fmt.Printf("%s clean exit of %s, but with error: %s\n", logPrefix, identifier, err) + fmt.Printf("%s clean exit of %s, but with error: %s\n", logPrefix, opts.Identifier, err) os.Exit(1) case 1: // error exit - fmt.Printf("%s error during execution of %s: %s\n", logPrefix, identifier, err) + fmt.Printf("%s error during execution of %s: %s\n", logPrefix, opts.Identifier, err) os.Exit(1) case 2357427: // Leet Speak for "restart" // restart request - fmt.Printf("%s restarting %s\n", logPrefix, identifier) + fmt.Printf("%s restarting %s\n", logPrefix, opts.Identifier) continue default: - fmt.Printf("%s unexpected error during execution of %s: %s\n", logPrefix, identifier, err) + fmt.Printf("%s unexpected error during execution of %s: %s\n", logPrefix, opts.Identifier, err) os.Exit(exErr.ProcessState.ExitCode()) } } else { - fmt.Printf("%s unexpected error type during execution of %s: %s\n", logPrefix, identifier, err) + fmt.Printf("%s unexpected error type during execution of %s: %s\n", logPrefix, opts.Identifier, err) os.Exit(1) } } - // clean exit break } - fmt.Printf("%s %s completed successfully\n", logPrefix, identifier) + fmt.Printf("%s %s completed successfully\n", logPrefix, opts.Identifier) return nil } diff --git a/pmctl/upgrade.go b/pmctl/upgrade.go index c404ec057..142840a8a 100644 --- a/pmctl/upgrade.go +++ b/pmctl/upgrade.go @@ -11,6 +11,10 @@ import ( "github.com/safing/portmaster/updates" ) +var ( + oldBinSuffix = "-old" +) + func checkForUpgrade() (update *updates.File) { info := info.GetInfo() file, err := updates.GetLocalPlatformFile("control/portmaster-control") @@ -25,6 +29,8 @@ func checkForUpgrade() (update *updates.File) { func doSelfUpgrade(file *updates.File) error { + // FIXME: fix permissions if needed + // get destination dst, err := os.Executable() if err != nil { @@ -36,7 +42,7 @@ func doSelfUpgrade(file *updates.File) error { } // mv destination - err = os.Rename(dst, dst+"_old") + err = os.Rename(dst, dst+oldBinSuffix) if err != nil { return err } @@ -105,7 +111,7 @@ func removeOldBin() error { } // delete old - err = os.Remove(dst + "_old") + err = os.Remove(dst + oldBinSuffix) if err != nil { if !os.IsNotExist(err) { return err diff --git a/updates/updater.go b/updates/updater.go index d10984c95..2a01babe2 100644 --- a/updates/updater.go +++ b/updates/updater.go @@ -11,14 +11,19 @@ import ( "time" "github.com/safing/portbase/log" + "github.com/safing/portbase/utils" ) func updater() { time.Sleep(10 * time.Second) for { - err := CheckForUpdates() + err := UpdateIndexes() if err != nil { - log.Warningf("updates: failed to check for updates: %s", err) + log.Warningf("updates: updating index failed: %s", err) + } + err = DownloadUpdates() + if err != nil { + log.Warningf("updates: downloading updates failed: %s", err) } time.Sleep(1 * time.Hour) } @@ -40,10 +45,9 @@ func markPlatformFileForDownload(identifier string) { markFileForDownload(identifier) } -// CheckForUpdates checks if updates are available and downloads updates of used components. -func CheckForUpdates() (err error) { - - // download new index +// UpdateIndexes downloads the current update indexes. +func UpdateIndexes() (err error) { + // download new indexes var data []byte for tries := 0; tries < 3; tries++ { data, err = fetchData("stable.json", tries) @@ -52,39 +56,72 @@ func CheckForUpdates() (err error) { } } if err != nil { - return err + return fmt.Errorf("failed to download: %s", err) } newStableUpdates := make(map[string]string) err = json.Unmarshal(data, &newStableUpdates) if err != nil { - return err + return fmt.Errorf("failed to parse: %s", err) } if len(newStableUpdates) == 0 { - return errors.New("stable.json is empty") + return errors.New("index is empty") } + // update stable index + updatesLock.Lock() + stableUpdates = newStableUpdates + updatesLock.Unlock() + + // check dir + err = utils.EnsureDirectory(updateStoragePath, 0755) + if err != nil { + return err + } + + // save stable index + err = ioutil.WriteFile(filepath.Join(updateStoragePath, "stable.json"), data, 0644) + if err != nil { + log.Warningf("updates: failed to save new version of stable.json: %s", err) + } + + // update version status + updatesLock.RLock() + updateStatus(versionClassStable, stableUpdates) + updatesLock.RUnlock() + // FIXME IN STABLE: correct log line - log.Infof("updates: downloaded new update index: stable.json (alpha until we actually reach stable)") + log.Infof("updates: updated index stable.json (alpha/beta until we actually reach stable)") + + return nil +} + +// DownloadUpdates checks if updates are available and downloads updates of used components. +func DownloadUpdates() (err error) { // ensure important components are always updated updatesLock.Lock() if runtime.GOOS == "windows" { + markPlatformFileForDownload("core/portmaster-core.exe") markPlatformFileForDownload("control/portmaster-control.exe") markPlatformFileForDownload("app/portmaster-app.exe") markPlatformFileForDownload("notifier/portmaster-notifier.exe") } else { + markPlatformFileForDownload("core/portmaster-core") markPlatformFileForDownload("control/portmaster-control") markPlatformFileForDownload("app/portmaster-app") markPlatformFileForDownload("notifier/portmaster-notifier") } updatesLock.Unlock() + // RLock for the remaining function + updatesLock.RLock() + defer updatesLock.RUnlock() + // update existing files log.Tracef("updates: updating existing files") - updatesLock.RLock() - for identifier, newVersion := range newStableUpdates { + for identifier, newVersion := range stableUpdates { oldVersion, ok := localUpdates[identifier] if ok && newVersion != oldVersion { @@ -103,24 +140,7 @@ func CheckForUpdates() (err error) { } } - updatesLock.RUnlock() log.Tracef("updates: finished updating existing files") - // update stable index - updatesLock.Lock() - stableUpdates = newStableUpdates - updatesLock.Unlock() - - // save stable index - err = ioutil.WriteFile(filepath.Join(updateStoragePath, "stable.json"), data, 0644) - if err != nil { - log.Warningf("updates: failed to save new version of stable.json: %s", err) - } - - // update version status - updatesLock.RLock() - defer updatesLock.RUnlock() - updateStatus(versionClassStable, stableUpdates) - return nil }