diff --git a/app/app.go b/app/app.go index 48e3886..395c4eb 100644 --- a/app/app.go +++ b/app/app.go @@ -1,7 +1,9 @@ package app import ( + "bytes" "context" + "errors" "fmt" "log/slog" "math" @@ -10,11 +12,13 @@ import ( "strings" "time" + "github.com/mandelsoft/vfs/pkg/vfs" "github.com/nrednav/cuid2" "go.hackfix.me/disco/app/cli" actx "go.hackfix.me/disco/app/context" aerrors "go.hackfix.me/disco/app/errors" + "go.hackfix.me/disco/crypto" "go.hackfix.me/disco/db" "go.hackfix.me/disco/db/queries" "go.hackfix.me/disco/db/store" @@ -89,10 +93,32 @@ func (app *App) Run(args []string) error { // so prevent it by using SQLite's in-memory support. storeDir = ":memory:" } - if err := app.initStores(storeDir); err != nil { + + var encKey *[32]byte + if app.ctx.User != nil { + encKey = app.ctx.User.PrivateKey + } + cmd := app.cli.Command() + // Only read the encryption for specific commands. + encKeyCommands := []string{"get", "set", "ls", "serve", "invite user", "remote add"} + if encKey == nil && slices.Contains(encKeyCommands, cmd) { + var err error + encKey, err = app.readEncryptionKey() + if err != nil { + return aerrors.NewRuntimeError("invalid encryption key", err, "") + } + } + + if err := app.initStores(storeDir, encKey); err != nil { return err } + if app.ctx.User == nil && cmd != "init" { + if err := app.ctx.LoadLocalUser(encKey); err != nil { + return err + } + } + if err := app.cli.Execute(app.ctx); err != nil { return err } @@ -100,6 +126,24 @@ func (app *App) Run(args []string) error { return nil } +func (app *App) readEncryptionKey() (*[32]byte, error) { + encKey, err := crypto.DecodeKey(app.cli.EncryptionKey) + if err != nil { + // Maybe it's a file path + encKeyData, fsErr := vfs.ReadFile(app.ctx.FS, app.cli.EncryptionKey) + if fsErr != nil { + // Unwrap error to avoid potentially logging a secret. + return nil, errors.Unwrap(fsErr) + } + encKey, err = crypto.DecodeKey(string(bytes.TrimSpace(encKeyData))) + if err != nil { + return nil, err + } + } + + return encKey, nil +} + func (app *App) createDataDir(dir string) error { err := app.ctx.FS.MkdirAll(dir, 0o700) if err != nil { @@ -109,7 +153,9 @@ func (app *App) createDataDir(dir string) error { return nil } -func (app *App) initStores(dataDir string) error { +// TODO: Remove encryption key from here. Instead load it only when needed, +// to minimize the amount of time it's stored in RAM. +func (app *App) initStores(dataDir string, encKey *[32]byte) error { var err error if app.ctx.DB == nil { app.ctx.DB, err = initDB(app.ctx.Ctx, dataDir) @@ -123,25 +169,7 @@ func (app *App) initStores(dataDir string) error { app.ctx.VersionInit = version.V } - // Only load the local user if it's not set and we're currrently not - // initializing. If we're initializing, the migrations haven't been run at - // this point, so the schema doesn't exist yet. - cmd := app.cli.Command() - if app.ctx.User == nil && cmd != "init" { - // The encryption key is only required for specific commands. - encKeyCommands := []string{"get", "set", "ls", "serve", "invite user", "remote add"} - readEncKey := slices.Contains(encKeyCommands, cmd) - err = app.ctx.LoadLocalUser(readEncKey) - if err != nil { - return err - } - } - if app.ctx.Store == nil { - var encKey *[32]byte - if app.ctx.User != nil { - encKey = app.ctx.User.PrivateKey - } app.ctx.Store, err = initKVStore(app.ctx.Ctx, dataDir, encKey) if err != nil { return err diff --git a/app/cli/cli.go b/app/cli/cli.go index 3b34280..51e725f 100644 --- a/app/cli/cli.go +++ b/app/cli/cli.go @@ -25,9 +25,10 @@ type CLI struct { Invite Invite `kong:"cmd,help='Manage invitations for remote users.'"` Remote Remote `kong:"cmd,help='Manage remote Disco nodes.'"` - Version kong.VersionFlag `kong:"help='Output Disco version and exit.'"` - DataDir string `kong:"default='${dataDir}',help='Directory to store Disco data in.'"` - EncryptionKey string `kong:"help='Private key used for encrypting and decrypting the local data store. '"` + Version kong.VersionFlag `kong:"help='Output Disco version and exit.'"` + DataDir string `kong:"default='${dataDir}',help='Directory to store Disco data in.'"` + //nolint:lll + EncryptionKey string `kong:"help='Private key used for encrypting and decrypting the local data store. \n It can be the value itself or a file path that contains the value. '"` Log struct { Level slog.Level `enum:"DEBUG,INFO,WARN,ERROR" default:"INFO" help:"Set the app logging level."` } `embed:"" prefix:"log-"` diff --git a/app/context/context.go b/app/context/context.go index 20c51c6..a737473 100644 --- a/app/context/context.go +++ b/app/context/context.go @@ -52,12 +52,11 @@ type Environment interface { Set(string, string) error } -// LoadLocalUser loads the local user from the database into c.User. -// If readEncKey is true, it also reads the private encryption key from the -// environment and validates it against its stored hash. +// LoadLocalUser loads the local user from the database into c.User, +// optionally validating the provided encryption key with the stored hash. // Note that this *must* load a single user. Currently only a single local user // is created, but in the future this might change. -func (c *Context) LoadLocalUser(readEncKey bool) error { +func (c *Context) LoadLocalUser(encKey *[32]byte) error { users, err := models.Users(c.DB.NewContext(), c.DB, types.NewFilter("u.type = ?", []any{models.UserTypeLocal})) if err != nil { @@ -75,26 +74,20 @@ func (c *Context) LoadLocalUser(readEncKey bool) error { fmt.Sprintf("found more than 1 local user: %d", len(users)), nil, "") } - if readEncKey { + if encKey != nil { privKeyHash, privKeyErr := queries.GetEncryptionPrivKeyHash(c.DB.NewContext(), c.DB) if privKeyErr != nil || !privKeyHash.Valid { return aerrors.NewRuntimeError("missing encryption key hash", privKeyErr, "Did you forget to run 'disco init'?") } - privKeyEnc := c.Env.Get("DISCO_ENCRYPTION_KEY") - privKey, err := crypto.DecodeKey(privKeyEnc) - if err != nil { - return aerrors.NewRuntimeError("invalid encryption key", err, "") - } - - inPrivKeyHash := crypto.Hash("encryption key hash", privKey[:]) + inPrivKeyHash := crypto.Hash("encryption key hash", encKey[:]) inPrivKeyHashEnc := base58.Encode(inPrivKeyHash) if privKeyHash.V != inPrivKeyHashEnc { return aerrors.NewRuntimeError("invalid encryption key", errors.New("hash mismatch"), "") } - c.User.PrivateKey = privKey + c.User.PrivateKey = encKey } return nil