Skip to content

Commit

Permalink
Merge pull request #38 from SkynetLabs/ivo/more_testing
Browse files Browse the repository at this point in the history
Test LoadConfig and Logger
  • Loading branch information
ro-tex authored Jul 8, 2022
2 parents 985ee64 + d6fcaf4 commit 97da18e
Show file tree
Hide file tree
Showing 11 changed files with 319 additions and 69 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ all: release
count = 1
# pkgs changes which packages the makefile calls operate on. run changes which
# tests are run during testing.
pkgs = ./ ./api ./conf ./database ./skyd ./test ./workers
pkgs = ./ ./api ./conf ./database ./logger ./skyd ./test ./workers

# integration-pkgs defines the packages which contain integration tests
integration-pkgs = ./test ./test/api ./test/database
Expand Down
6 changes: 3 additions & 3 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
"time"

"github.com/julienschmidt/httprouter"
"github.com/sirupsen/logrus"
"github.com/skynetlabs/pinner/database"
"github.com/skynetlabs/pinner/logger"
"github.com/skynetlabs/pinner/skyd"
"gitlab.com/NebulousLabs/errors"
"gitlab.com/SkynetLabs/skyd/build"
Expand All @@ -20,7 +20,7 @@ type (
API struct {
staticServerName string
staticDB *database.DB
staticLogger *logrus.Logger
staticLogger logger.ExtFieldLogger
staticRouter *httprouter.Router
staticSkydClient skyd.Client

Expand All @@ -43,7 +43,7 @@ type (
)

// New returns a new initialised API.
func New(serverName string, db *database.DB, logger *logrus.Logger, skydClient skyd.Client) (*API, error) {
func New(serverName string, db *database.DB, logger logger.ExtFieldLogger, skydClient skyd.Client) (*API, error) {
if db == nil {
return nil, errors.New("no DB provided")
}
Expand Down
28 changes: 17 additions & 11 deletions conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/joho/godotenv"
"github.com/sirupsen/logrus"
"github.com/skynetlabs/pinner/database"
"gitlab.com/NebulousLabs/errors"
"gitlab.com/SkynetLabs/skyd/build"
Expand All @@ -21,7 +22,7 @@ const (
defaultAccountsHost = "10.10.10.70"
defaultAccountsPort = "3000"
defaultLogFile = "" // disabled logging to file
defaultLogLevel = "info"
defaultLogLevel = logrus.InfoLevel
defaultSiaAPIHost = "10.10.10.10"
defaultSiaAPIPort = "9980"
defaultMinPinners = 1
Expand Down Expand Up @@ -69,7 +70,7 @@ type (
// not log to a file.
LogFile string
// LogLevel defines the logging level of the entire service.
LogLevel string
LogLevel logrus.Level
// MinPinners defines the minimum number of pinning servers
// which a skylink needs in order to not be considered underpinned.
// Anything below this value requires more servers to pin the skylink.
Expand Down Expand Up @@ -97,14 +98,15 @@ func LoadConfig() (Config, error) {

// Start with the default values.
cfg := Config{
AccountsHost: defaultAccountsHost,
AccountsPort: defaultAccountsPort,
DBCredentials: database.DBCredentials{},
LogFile: defaultLogFile,
LogLevel: defaultLogLevel,
MinPinners: defaultMinPinners,
SiaAPIHost: defaultSiaAPIHost,
SiaAPIPort: defaultSiaAPIPort,
AccountsHost: defaultAccountsHost,
AccountsPort: defaultAccountsPort,
DBCredentials: database.DBCredentials{},
LogFile: defaultLogFile,
LogLevel: defaultLogLevel,
MinPinners: defaultMinPinners,
SiaAPIHost: defaultSiaAPIHost,
SiaAPIPort: defaultSiaAPIPort,
SleepBetweenScans: 0, // This will be ignored by the scanner.
}

var ok bool
Expand Down Expand Up @@ -141,7 +143,11 @@ func LoadConfig() (Config, error) {
cfg.LogFile = val
}
if val, ok = os.LookupEnv("PINNER_LOG_LEVEL"); ok {
cfg.LogLevel = val
lvl, err := logrus.ParseLevel(val)
if err != nil {
log.Fatalf("PINNER_LOG_LEVEL has an invalid value of '%s'", val)
}
cfg.LogLevel = lvl
}
if val, ok = os.LookupEnv("PINNER_SLEEP_BETWEEN_SCANS"); ok {
// Check for a bare number and interpret that as seconds.
Expand Down
170 changes: 170 additions & 0 deletions conf/configuration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package conf

import (
"encoding/hex"
"math"
"os"
"testing"
"time"

"github.com/sirupsen/logrus"
"gitlab.com/NebulousLabs/fastrand"
)

// TestLoadConfig ensures that LoadConfig works as expected.
func TestLoadConfig(t *testing.T) {
envVarsReq := []string{
"SERVER_DOMAIN",
"SKYNET_DB_USER",
"SKYNET_DB_PASS",
"SKYNET_DB_HOST",
"SKYNET_DB_PORT",
"SIA_API_PASSWORD",
}
envVarsOpt := []string{
"SKYNET_ACCOUNTS_HOST",
"SKYNET_ACCOUNTS_PORT",
"PINNER_LOG_FILE",
"PINNER_LOG_LEVEL",
"PINNER_SLEEP_BETWEEN_SCANS",
"API_HOST",
"API_PORT",
}
envVars := append(envVarsReq, envVarsOpt...)
// Store all env var values.
values := make(map[string]string)
for _, key := range envVars {
val, exists := os.LookupEnv(key)
if exists {
values[key] = val
}
}
// Set all required vars, so the test will pass even if the environment is
// not fully set.
for _, key := range envVarsReq {
err := os.Setenv(key, key+"value")
if err != nil {
t.Fatal(err)
}
}
// Unset all optional vars.
for _, key := range envVarsOpt {
err := os.Unsetenv(key)
if err != nil {
t.Fatal(err)
}
}
// Set them back up at the end of the test.
defer func(vals map[string]string) {
for _, key := range envVars {
val, exists := vals[key]
if exists {
err := os.Setenv(key, val)
if err != nil {
t.Error(err)
}
} else {
err := os.Unsetenv(key)
if err != nil {
t.Error(err)
}
}
}
}(values)
// Get the values without setting any optionals.
cfg, err := LoadConfig()
if err != nil {
t.Fatal(err)
}
// Ensure the required ones match the environment.
if cfg.ServerName != os.Getenv("SERVER_DOMAIN") {
t.Fatal("Bad SERVER_DOMAIN")
}
if cfg.DBCredentials.User != os.Getenv("SKYNET_DB_USER") {
t.Fatal("Bad SKYNET_DB_USER")
}
if cfg.DBCredentials.Password != os.Getenv("SKYNET_DB_PASS") {
t.Fatal("Bad SKYNET_DB_PASS")
}
if cfg.DBCredentials.Host != os.Getenv("SKYNET_DB_HOST") {
t.Fatal("Bad SKYNET_DB_HOST")
}
if cfg.DBCredentials.Port != os.Getenv("SKYNET_DB_PORT") {
t.Fatal("Bad SKYNET_DB_PORT")
}
if cfg.SiaAPIPassword != os.Getenv("SIA_API_PASSWORD") {
t.Fatal("Bad SIA_API_PASSWORD")
}
// Ensure the optional ones have their default values.
if cfg.AccountsHost != defaultAccountsHost {
t.Fatal("Bad AccountsHost")
}
if cfg.AccountsPort != defaultAccountsPort {
t.Fatal("Bad AccountsPort")
}
if cfg.LogFile != defaultLogFile {
t.Fatal("Bad LogFile")
}
if cfg.LogLevel != defaultLogLevel {
t.Fatal("Bad LogLevel")
}
if cfg.SleepBetweenScans != 0 {
t.Fatal("Bad SleepBetweenScans")
}
if cfg.SiaAPIHost != defaultSiaAPIHost {
t.Fatal("Bad SiaAPIHost")
}
if cfg.SiaAPIPort != defaultSiaAPIPort {
t.Fatal("Bad SiaAPIPort")
}

// Set the optionals to custom values.
optionalValues := make(map[string]string)
for _, key := range envVarsOpt {
optionalValues[key] = hex.EncodeToString(fastrand.Bytes(16))
err = os.Setenv(key, optionalValues[key])
if err != nil {
t.Fatal(err)
}
}
// We'll set a special value for PINNER_SLEEP_BETWEEN_SCANS and
// PINNER_LOG_LEVEL because they need to have valid values.
optionalValues["PINNER_SLEEP_BETWEEN_SCANS"] = time.Duration(fastrand.Intn(math.MaxInt)).String()
err = os.Setenv("PINNER_SLEEP_BETWEEN_SCANS", optionalValues["PINNER_SLEEP_BETWEEN_SCANS"])
if err != nil {
t.Fatal(err)
}
// Random log level between 0 (Panic) and 7 (Trace).
optionalValues["PINNER_LOG_LEVEL"] = logrus.Level(fastrand.Intn(int(logrus.TraceLevel) + 1)).String()
err = os.Setenv("PINNER_LOG_LEVEL", optionalValues["PINNER_LOG_LEVEL"])
if err != nil {
t.Fatal(err)
}
// Load the config again.
cfg, err = LoadConfig()
if err != nil {
t.Fatal(err)
}
// Ensure all optionals got the custom values we set for them.
if cfg.AccountsHost != optionalValues["SKYNET_ACCOUNTS_HOST"] {
t.Fatal("Bad AccountsHost")
}
if cfg.AccountsPort != optionalValues["SKYNET_ACCOUNTS_PORT"] {
t.Fatal("Bad AccountsPort")
}
if cfg.LogFile != optionalValues["PINNER_LOG_FILE"] {
t.Fatal("Bad LogFile")
}
if cfg.LogLevel.String() != optionalValues["PINNER_LOG_LEVEL"] {
t.Fatal("Bad LogLevel")
}
if tm, err := time.ParseDuration(optionalValues["PINNER_SLEEP_BETWEEN_SCANS"]); err != nil || cfg.SleepBetweenScans != tm {
t.Fatal("Bad SleepBetweenScans")
}
if cfg.SiaAPIHost != optionalValues["API_HOST"] {
t.Fatal("Bad SiaAPIHost")
}
if cfg.SiaAPIPort != optionalValues["API_PORT"] {
t.Fatal("Bad SiaAPIPort")
}
}
10 changes: 5 additions & 5 deletions database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"fmt"
"time"

"github.com/sirupsen/logrus"
"github.com/skynetlabs/pinner/logger"
"gitlab.com/NebulousLabs/errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
Expand Down Expand Up @@ -41,7 +41,7 @@ type (
DB struct {
staticCtx context.Context
staticDB *mongo.Database
staticLogger *logrus.Logger
staticLogger logger.ExtFieldLogger
}

// DBCredentials is a helper struct that binds together all values needed for
Expand All @@ -55,12 +55,12 @@ type (
)

// New creates a new database connection.
func New(ctx context.Context, creds DBCredentials, logger *logrus.Logger) (*DB, error) {
func New(ctx context.Context, creds DBCredentials, logger logger.ExtFieldLogger) (*DB, error) {
return NewCustomDB(ctx, dbName, creds, logger)
}

// NewCustomDB creates a new database connection to a database with a custom name.
func NewCustomDB(ctx context.Context, dbName string, creds DBCredentials, logger *logrus.Logger) (*DB, error) {
func NewCustomDB(ctx context.Context, dbName string, creds DBCredentials, logger logger.ExtFieldLogger) (*DB, error) {
if ctx == nil {
return nil, errors.New("invalid context provided")
}
Expand Down Expand Up @@ -144,7 +144,7 @@ func (db *DB) SetConfigValue(ctx context.Context, key, value string) error {
// creates them if needed.
// See https://docs.mongodb.com/manual/indexes/
// See https://docs.mongodb.com/manual/core/index-unique/
func ensureDBSchema(ctx context.Context, db *mongo.Database, log *logrus.Logger) error {
func ensureDBSchema(ctx context.Context, db *mongo.Database, log logger.ExtFieldLogger) error {
for collName, models := range schema() {
coll, err := ensureCollection(ctx, db, collName)
if err != nil {
Expand Down
61 changes: 61 additions & 0 deletions logger/logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package logger

import (
"io"
"os"

"github.com/sirupsen/logrus"
"gitlab.com/NebulousLabs/errors"
)

type (
// ExtFieldLogger defines the logger interface we need.
//
// It is identical to logrus.Ext1FieldLogger but we are not using that
// because it's marked as "Do not use". Instead, we're defining our own in
// order to be sure that potential Logrus changes won't break us.
ExtFieldLogger interface {
logrus.FieldLogger
Tracef(format string, args ...interface{})
Trace(args ...interface{})
Traceln(args ...interface{})
}

// Logger is a wrapper of *logrus.Logger which allows logging to a file on
// disk.
Logger struct {
*logrus.Logger
logFile *os.File
}
)

// New creates a new logger that can optionally write to disk.
//
// If the given logfile argument is an empty string, the logger will not write
// to disk.
func New(level logrus.Level, logfile string) (logger *Logger, err error) {
logger = &Logger{
logrus.New(),
nil,
}
logger.SetLevel(level)
// Open and start writing to the log file, unless we have an empty string,
// which signifies "don't log to disk".
if logfile != "" {
logger.logFile, err = os.OpenFile(logfile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
if err != nil {
return nil, errors.AddContext(err, "failed to open log file")
}

logger.SetOutput(io.MultiWriter(os.Stdout, logger.logFile))
}
return logger, nil
}

// Close gracefully closes all resources used by Logger.
func (l *Logger) Close() error {
if l.logFile == nil {
return nil
}
return l.logFile.Close()
}
Loading

0 comments on commit 97da18e

Please sign in to comment.