diff --git a/cli/agent/command.go b/cli/agent/command.go index 9a09f292..d61055ba 100644 --- a/cli/agent/command.go +++ b/cli/agent/command.go @@ -81,6 +81,9 @@ Examples: if conf.Listeners[i].Protocol == "" { conf.Listeners[i].Protocol = config.ListenerProtocolHTTP } + if conf.Listeners[i].AccessLog.Level == "" { + conf.Listeners[i].AccessLog.Level = "info" + } } if err := conf.Validate(); err != nil { diff --git a/cli/agent/http.go b/cli/agent/http.go index 9c65d309..8ade40e5 100644 --- a/cli/agent/http.go +++ b/cli/agent/http.go @@ -35,11 +35,12 @@ Examples: `, } - accessLog := log.AccessLogConfig{ + accessLogConfig := log.AccessLogConfig{ + Level: "info", Disable: false, } flags := cmd.Flags() - accessLog.RegisterFlags(flags, "") + accessLogConfig.RegisterFlags(flags, "") var timeout time.Duration flags.DurationVar( @@ -59,7 +60,7 @@ Timeout forwarding incoming HTTP requests to the upstream.`, EndpointID: args[0], Addr: args[1], Protocol: config.ListenerProtocolHTTP, - AccessLog: accessLog, + AccessLog: accessLogConfig, Timeout: timeout, }} diff --git a/cli/agent/tcp.go b/cli/agent/tcp.go index fae82b92..38634520 100644 --- a/cli/agent/tcp.go +++ b/cli/agent/tcp.go @@ -32,13 +32,25 @@ Examples: `, } - var disableAccessLogging bool + accessLogConfig := log.AccessLogConfig{ + Level: "info", + Disable: false, + } + cmd.Flags().StringVar( + &accessLogConfig.Level, + "access-log.level", + accessLogConfig.Level, + ` +The record log level for audit log entries. + +The available levels are 'debug', 'info', 'warn' and 'error'.`, + ) cmd.Flags().BoolVar( - &disableAccessLogging, + &accessLogConfig.Disable, "access-log.disable", - false, + accessLogConfig.Disable, ` -Disables logging all incoming connections as 'info' logs. For more options, use a configuration file.`, +Disable the access log, so requests will not be logged.`, ) var timeout time.Duration @@ -59,10 +71,8 @@ Timeout connecting to the upstream.`, EndpointID: args[0], Addr: args[1], Protocol: config.ListenerProtocolTCP, - AccessLog: log.AccessLogConfig{ - Disable: disableAccessLogging, - }, - Timeout: timeout, + AccessLog: accessLogConfig, + Timeout: timeout, }} var err error diff --git a/pkg/log/config.go b/pkg/log/config.go index bbe23b11..c6634273 100644 --- a/pkg/log/config.go +++ b/pkg/log/config.go @@ -20,7 +20,7 @@ func (c *Config) Validate() error { if c.Level == "" { return fmt.Errorf("missing level") } - if _, err := zapLevelFromString(c.Level); err != nil { + if _, err := ZapLevelFromString(c.Level); err != nil { return err } return nil @@ -52,51 +52,70 @@ Such as you can enable 'gossip' logs with '--log.subsystems gossip'.`, } type AccessLogHeaderConfig struct { - // Prevent these headers from being logged. - // You can only define one of Allowlist or Blocklist. - Blocklist []string `json:"blocklist" yaml:"blocklist"` - - // Log only these headers. - // You can only define one of Allowlist or Blocklist. - Allowlist []string `json:"allowlist" yaml:"allowlist"` + // BlockList contains headers that will be redacted from the audit log. + // + // You must only define one of AllowList or BlockList. + BlockList []string `json:"block_list" yaml:"block_list"` + + // AllowList contains the ONLY headers that will be logged. + // + // You must only define one of AllowList or BlockList. + AllowList []string `json:"allow_list" yaml:"allow_list"` } func (c *AccessLogHeaderConfig) Validate() error { - if len(c.Allowlist) > 0 && len(c.Blocklist) > 0 { - return fmt.Errorf("cannot define both allowlist and blocklist") + if len(c.AllowList) > 0 && len(c.BlockList) > 0 { + return fmt.Errorf("cannot define both allow list and block list") } - return nil } func (c *AccessLogHeaderConfig) RegisterFlags(fs *pflag.FlagSet, prefix string) { fs.StringSliceVar( - &c.Allowlist, - prefix+"allowlist", - c.Allowlist, + &c.BlockList, + prefix+"block-list", + c.BlockList, ` -Log only these headers`, +Block these headers from being logged. + +You must only define one of block list and allow list.`, ) fs.StringSliceVar( - &c.Blocklist, - prefix+"blocklist", - c.Blocklist, + &c.AllowList, + prefix+"allow-list", + c.AllowList, ` -Block these headers from being logged`, +The ONLY headers that will be logged. + +You must only define one of block list and allow list.`, ) } type AccessLogConfig struct { - // If disabled, logs will be emitted with the 'debug' log level, - // while respecting the header allow and block lists. - Disable bool `json:"disable" yaml:"disable"` + // Level is the record log level for audit log entries. Either 'debug', + // 'info', 'warn' or 'error'. + Level string `json:"level" yaml:"level"` RequestHeaders AccessLogHeaderConfig `json:"request_headers" yaml:"request_headers"` ResponseHeaders AccessLogHeaderConfig `json:"response_headers" yaml:"response_headers"` + + // Disable disables the access log, so requests will not be logged. + Disable bool `json:"disable" yaml:"disable"` } func (c *AccessLogConfig) Validate() error { + if c.Disable { + return nil + } + + if c.Level == "" { + return fmt.Errorf("missing level") + } + if _, err := ZapLevelFromString(c.Level); err != nil { + return err + } + if err := c.RequestHeaders.Validate(); err != nil { return fmt.Errorf("request headers: %w", err) } @@ -113,13 +132,22 @@ func (c *AccessLogConfig) RegisterFlags(fs *pflag.FlagSet, prefix string) { } else { prefix = "access-log." } + fs.StringVar( + &c.Level, + prefix+"level", + c.Level, + ` +The record log level for audit log entries. + +The available levels are 'debug', 'info', 'warn' and 'error'.`, + ) + c.RequestHeaders.RegisterFlags(fs, prefix+"request-headers.") + c.ResponseHeaders.RegisterFlags(fs, prefix+"response-headers.") fs.BoolVar( &c.Disable, prefix+"disable", - false, + c.Disable, ` -If Access logging is disabled`, +Disable the access log, so requests will not be logged.`, ) - c.RequestHeaders.RegisterFlags(fs, prefix+"request-headers.") - c.ResponseHeaders.RegisterFlags(fs, prefix+"response-headers.") } diff --git a/pkg/log/logger.go b/pkg/log/logger.go index 3fe7d84f..c27528bc 100644 --- a/pkg/log/logger.go +++ b/pkg/log/logger.go @@ -39,6 +39,7 @@ type Logger interface { Info(msg string, fields ...zap.Field) Warn(msg string, fields ...zap.Field) Error(msg string, fields ...zap.Field) + Log(level zapcore.Level, msg string, fields ...zap.Field) Sync() error // StdLogger returns a standard library log.Logger that logs records using // with the given level. @@ -58,7 +59,7 @@ type logger struct { // NewLogger creates a new logger filtering using the given log level and // enabled subsystems. func NewLogger(lvl string, enabledSubsystems []string) (Logger, error) { - zapLevel, err := zapLevelFromString(lvl) + zapLevel, err := ZapLevelFromString(lvl) if err != nil { return nil, err } @@ -114,25 +115,23 @@ func (l *logger) With(fields ...zap.Field) Logger { } func (l *logger) Debug(msg string, fields ...zap.Field) { - if ce := l.check(zap.DebugLevel, msg); ce != nil { - ce.Write(fields...) - } + l.Log(zap.DebugLevel, msg, fields...) } func (l *logger) Info(msg string, fields ...zap.Field) { - if ce := l.check(zap.InfoLevel, msg); ce != nil { - ce.Write(fields...) - } + l.Log(zap.InfoLevel, msg, fields...) } func (l *logger) Warn(msg string, fields ...zap.Field) { - if ce := l.check(zap.WarnLevel, msg); ce != nil { - ce.Write(fields...) - } + l.Log(zap.WarnLevel, msg, fields...) } func (l *logger) Error(msg string, fields ...zap.Field) { - if ce := l.check(zap.ErrorLevel, msg); ce != nil { + l.Log(zap.ErrorLevel, msg, fields...) +} + +func (l *logger) Log(level zapcore.Level, msg string, fields ...zap.Field) { + if ce := l.check(level, msg); ce != nil { ce.Write(fields...) } } @@ -214,6 +213,9 @@ func (l *nopLogger) Warn(_ string, _ ...zap.Field) { func (l *nopLogger) Error(_ string, _ ...zap.Field) { } +func (l *nopLogger) Log(_ zapcore.Level, _ string, _ ...zap.Field) { +} + func (l *nopLogger) Sync() error { return nil } @@ -234,7 +236,7 @@ func subsystemMatch(subsystem string, enabled []string) bool { return false } -func zapLevelFromString(s string) (zapcore.Level, error) { +func ZapLevelFromString(s string) (zapcore.Level, error) { switch s { case "debug": return zap.DebugLevel, nil diff --git a/pkg/middleware/logger.go b/pkg/middleware/logger.go index 638e2029..4ef40f0e 100644 --- a/pkg/middleware/logger.go +++ b/pkg/middleware/logger.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" "go.uber.org/zap" + "go.uber.org/zap/zapcore" "github.com/andydunstall/piko/pkg/log" ) @@ -24,32 +25,92 @@ type loggedRequest struct { } type logHeaderFilter struct { - allowList map[string]string - blockList map[string]string + allowList map[string]struct{} + blockList map[string]struct{} } -type loggerConfig struct { - RequestHeader logHeaderFilter - ResponseHeader logHeaderFilter +func newLogHeaderFilter(allowList []string, blockList []string) logHeaderFilter { + var filter logHeaderFilter + if len(allowList) > 0 { + filter.allowList = make(map[string]struct{}) + for _, header := range allowList { + filter.allowList[textproto.CanonicalMIMEHeaderKey(header)] = struct{}{} + } + } + + if len(blockList) > 0 { + filter.blockList = make(map[string]struct{}) + for _, header := range blockList { + filter.blockList[textproto.CanonicalMIMEHeaderKey(header)] = struct{}{} + } + } + + return filter +} + +// Filter filters the given headers based on the allow list and block list. +// +// Note this WILL modify the given headers. +func (l *logHeaderFilter) Filter(h http.Header) http.Header { + if len(l.allowList) > 0 { + for name := range h { + if _, ok := l.allowList[name]; !ok { + h.Del(name) + } + } + return h + } + + if len(l.blockList) > 0 { + for name := range h { + if _, ok := l.blockList[name]; ok { + h.Del(name) + } + } + return h + } + + return h } -// NewLogger creates logging middleware that logs every request. -func NewLogger(config log.AccessLogConfig, l log.Logger) gin.HandlerFunc { - l = l.WithSubsystem(l.Subsystem() + ".access") +// NewLogger creates logging middleware for the access log. +func NewLogger(config log.AccessLogConfig, logger log.Logger) gin.HandlerFunc { + logger = logger.WithSubsystem(logger.Subsystem() + ".access") + + requestHeaderFilter := newLogHeaderFilter( + config.RequestHeaders.AllowList, + config.RequestHeaders.BlockList, + ) + responseHeaderFilter := newLogHeaderFilter( + config.ResponseHeaders.AllowList, + config.ResponseHeaders.BlockList, + ) + + level, err := log.ZapLevelFromString(config.Level) + if err != nil { + // Validated on boot so must not happen. + panic("invalid log level") + } - lc := newLoggerConfig(config) return func(c *gin.Context) { s := time.Now() c.Next() + if config.Disable { + // Access log disabled. + return + } + // Ignore internal endpoints. if strings.HasPrefix(c.Request.URL.Path, "/_piko") { return } - requestHeaders := lc.RequestHeader.Filter(c.Request.Header) - responseHeaders := lc.ResponseHeader.Filter(c.Writer.Header()) + // Note filter will modify the request/response headers, though + // they have already been written so it doesn't matter. + requestHeaders := requestHeaderFilter.Filter(c.Request.Header) + responseHeaders := responseHeaderFilter.Filter(c.Writer.Header()) req := &loggedRequest{ Proto: c.Request.Proto, @@ -61,58 +122,14 @@ func NewLogger(config log.AccessLogConfig, l log.Logger) gin.HandlerFunc { Status: c.Writer.Status(), Duration: time.Since(s).String(), } - if c.Writer.Status() >= http.StatusInternalServerError { - l.Warn("request", zap.Any("request", req)) - } else if config.Disable { - l.Debug("request", zap.Any("request", req)) - } else { - l.Info("request", zap.Any("request", req)) - } - } -} - -func (l *logHeaderFilter) New(allowList []string, blockList []string) { - if len(allowList) > 0 { - l.allowList = make(map[string]string) - for _, el := range allowList { - h := textproto.CanonicalMIMEHeaderKey(el) - l.allowList[h] = h - } - } - - if len(blockList) > 0 { - l.blockList = make(map[string]string) - for _, el := range blockList { - h := textproto.CanonicalMIMEHeaderKey(el) - l.blockList[h] = h - } - } -} -func (l *logHeaderFilter) Filter(h http.Header) http.Header { - if len(l.allowList) > 0 { - for name := range h { - // Use the map created during validation to hasten lookups. - if _, ok := l.allowList[name]; !ok { - h.Del(name) - } + recordLevel := level + // If the response is a server error, increase the log level to a + // minimum of 'warn'. + if c.Writer.Status() >= http.StatusInternalServerError && recordLevel < zapcore.WarnLevel { + recordLevel = zapcore.WarnLevel } - return h - } - if len(l.blockList) > 0 { - for _, blocked := range l.blockList { - h.Del(blocked) - } - return h + logger.Log(recordLevel, "request", zap.Any("request", req)) } - - return h -} - -func newLoggerConfig(c log.AccessLogConfig) loggerConfig { - l := loggerConfig{} - l.RequestHeader.New(c.RequestHeaders.Allowlist, c.RequestHeaders.Blocklist) - l.ResponseHeader.New(c.ResponseHeaders.Allowlist, c.ResponseHeaders.Blocklist) - return l } diff --git a/server/config/config.go b/server/config/config.go index 732bb726..e7191f58 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -548,6 +548,7 @@ func Default() *Config { BindAddr: ":8000", Timeout: time.Second * 30, AccessLog: log.AccessLogConfig{ + Level: "info", Disable: false, }, HTTP: HTTPConfig{ diff --git a/server/config/config_test.go b/server/config/config_test.go index 5063f442..93d15e03 100644 --- a/server/config/config_test.go +++ b/server/config/config_test.go @@ -29,15 +29,16 @@ proxy: advertise_addr: 1.2.3.4:8000 timeout: 20s access_log: - disable: true + level: debug request_headers: - blocklist: + block_list: - abc - xyz response_headers: - allowlist: + allow_list: - def - ghi + disable: false http: read_timeout: 5s @@ -143,13 +144,14 @@ grace_period: 2m AdvertiseAddr: "1.2.3.4:8000", Timeout: time.Second * 20, AccessLog: log.AccessLogConfig{ - Disable: true, + Level: "debug", RequestHeaders: log.AccessLogHeaderConfig{ - Blocklist: []string{"abc", "xyz"}, + BlockList: []string{"abc", "xyz"}, }, ResponseHeaders: log.AccessLogHeaderConfig{ - Allowlist: []string{"def", "ghi"}, + AllowList: []string{"def", "ghi"}, }, + Disable: false, }, HTTP: HTTPConfig{ ReadTimeout: time.Second * 5, @@ -256,9 +258,10 @@ func TestConfig_LoadFlags(t *testing.T) { "--proxy.bind-addr", "10.15.104.25:8000", "--proxy.advertise-addr", "1.2.3.4:8000", "--proxy.timeout", "20s", + "--proxy.access-log.level", "debug", + "--proxy.access-log.request-headers.allow-list", "abc,def", + "--proxy.access-log.response-headers.block-list", "xyz,ghi", "--proxy.access-log.disable", - "--proxy.access-log.request-headers.allowlist", "abc,def", - "--proxy.access-log.response-headers.blocklist", "xyz,ghi", "--proxy.http.read-timeout", "5s", "--proxy.http.read-header-timeout", "5s", "--proxy.http.write-timeout", "5s", @@ -319,13 +322,14 @@ func TestConfig_LoadFlags(t *testing.T) { AdvertiseAddr: "1.2.3.4:8000", Timeout: time.Second * 20, AccessLog: log.AccessLogConfig{ - Disable: true, + Level: "debug", RequestHeaders: log.AccessLogHeaderConfig{ - Allowlist: []string{"abc", "def"}, + AllowList: []string{"abc", "def"}, }, ResponseHeaders: log.AccessLogHeaderConfig{ - Blocklist: []string{"xyz", "ghi"}, + BlockList: []string{"xyz", "ghi"}, }, + Disable: true, }, HTTP: HTTPConfig{ ReadTimeout: time.Second * 5,