// File: authentik/internal/config/config.go

package config
import (
"fmt"
"io/ioutil"
"net/url"
"os"
"reflect"
"strings"
env "github.com/Netflix/go-env"
"github.com/imdario/mergo"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)
// cfg holds the shared configuration handed out by Get; nil until first use.
var cfg *Config

// Get returns the shared configuration, creating it from defaultConfig on the
// first call and returning that same pointer on every later call.
//
// NOTE(review): there is no locking here, so concurrent first calls may race.
func Get() *Config {
	if cfg != nil {
		return cfg
	}
	cfg = defaultConfig()
	return cfg
}
// defaultConfig allocates a new Config populated with the built-in defaults
// used by Get: listeners on localhost, media under ./media, "info" logging,
// and error reporting turned off.
func defaultConfig() *Config {
	web := WebConfig{
		Listen:    "localhost:9000",
		ListenTLS: "localhost:9443",
	}
	reporting := ErrorReportingConfig{
		// Disabled by default; DSN and SampleRate presumably only matter once
		// Enabled is switched on.
		Enabled:    false,
		DSN:        "https://a579bb09306d4f8b8d8847c052d3a1d3@sentry.beryju.org/8",
		SampleRate: 1,
	}
	return &Config{
		Debug:          false,
		LogLevel:       "info",
		Web:            web,
		Paths:          PathsConfig{Media: "./media"},
		ErrorReporting: reporting,
	}
}
// Setup populates c from the given YAML files, in order, then from
// environment variables, and finally applies the logging settings.
//
// Each file is overlaid on the values loaded so far. A file that cannot be
// read, parsed or merged is logged and skipped rather than aborting, and the
// environment is merged after all files. Setup never returns an error;
// failures are only logged.
func (c *Config) Setup(paths ...string) {
	for _, path := range paths {
		if err := c.LoadConfig(path); err != nil {
			// LoadConfig's parse/merge errors do not name the file, so attach it here.
			log.WithError(err).WithField("path", path).Info("failed to load config, skipping")
		}
	}
	if err := c.fromEnv(); err != nil {
		log.WithError(err).Info("failed to load env vars")
	}
	// Configure logging last so LogLevel/Debug from files and env take effect.
	c.configureLogger()
}
// LoadConfig reads the YAML file at path and overlays it onto c.
//
// The file is decoded into a fresh Config and merged into c with
// mergo.WithOverride, so non-empty values from the file replace existing
// ones. Afterwards env:// and file:// references in c's string fields are
// resolved via walkScheme. The returned error wraps the underlying read,
// YAML-parse, or merge error.
func (c *Config) LoadConfig(path string) error {
	raw, err := ioutil.ReadFile(path)
	if err != nil {
		return fmt.Errorf("failed to load config file: %w", err)
	}
	nc := Config{}
	if err := yaml.Unmarshal(raw, &nc); err != nil {
		return fmt.Errorf("failed to parse YAML: %w", err)
	}
	if err := mergo.Merge(c, nc, mergo.WithOverride); err != nil {
		return fmt.Errorf("failed to overlay config: %w", err)
	}
	c.walkScheme(c)
	log.WithField("path", path).Debug("Loaded config")
	return nil
}
func (c *Config) fromEnv() error {
nc := Config{}
_, err := env.UnmarshalFromEnviron(&nc)
if err != nil {
return fmt.Errorf("failed to load environment variables: %w", err)
}
if err := mergo.Merge(c, nc, mergo.WithOverride); err != nil {
return fmt.Errorf("failed to overlay config: %w", err)
}
c.walkScheme(c)
log.Debug("Loaded config from environment")
return nil
}
// walkScheme resolves env:// and file:// references in place.
//
// v must be a non-nil pointer to a struct; anything else is ignored. Every
// settable string field is replaced by c.parseScheme(value), and struct-typed
// fields are walked recursively. Unexported fields are skipped, because
// reflect cannot modify them and SetString would panic. Pointer, slice and
// map fields are not descended into.
func (c *Config) walkScheme(v interface{}) {
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		return
	}
	rv = rv.Elem()
	if rv.Kind() != reflect.Struct {
		return
	}
	for i := 0; i < rv.NumField(); i++ {
		field := rv.Field(i)
		switch field.Kind() {
		case reflect.Struct:
			// Unexported struct fields cannot be turned back into an interface.
			if !field.Addr().CanInterface() {
				continue
			}
			c.walkScheme(field.Addr().Interface())
		case reflect.String:
			// Unexported string fields are read-only through reflect.
			if !field.CanSet() {
				continue
			}
			field.SetString(c.parseScheme(field.String()))
		}
	}
}
// parseScheme resolves a single configuration value:
//
//   - "env://NAME?fallback" yields the environment variable NAME if it is set,
//     otherwise the raw query string ("fallback", or "" when absent).
//   - "file:///path?fallback" yields the contents of /path with surrounding
//     whitespace trimmed, or the raw query string if the file cannot be read.
//   - Any other value, including one that does not parse as a URL, is
//     returned unchanged.
func (c *Config) parseScheme(rawVal string) string {
	u, err := url.Parse(rawVal)
	if err != nil {
		return rawVal
	}
	switch u.Scheme {
	case "env":
		if e, ok := os.LookupEnv(u.Host); ok {
			return e
		}
		return u.RawQuery
	case "file":
		d, err := ioutil.ReadFile(u.Path)
		if err != nil {
			return u.RawQuery
		}
		// Files written by echo, editors or secret mounts usually end in a
		// newline, which must not leak into passwords, keys or DSNs.
		return strings.TrimSpace(string(d))
	}
	return rawVal
}
// configureLogger applies c.LogLevel and c.Debug to the global logrus logger.
//
// LogLevel is trimmed of surrounding whitespace and parsed case-insensitively
// with logrus.ParseLevel, so every logrus level name is accepted ("trace",
// "debug", "info", "warn"/"warning", "error", "fatal", "panic"). Any
// unrecognised value falls back to debug. When Debug is set, logs use the
// text formatter; otherwise JSON. In both cases the message and time keys are
// renamed to "event" and "timestamp".
func (c *Config) configureLogger() {
	level, err := log.ParseLevel(strings.TrimSpace(c.LogLevel))
	if err != nil {
		// Unknown level names deliberately fall back to the most verbose level.
		level = log.DebugLevel
	}
	log.SetLevel(level)

	fm := log.FieldMap{
		log.FieldKeyMsg:  "event",
		log.FieldKeyTime: "timestamp",
	}
	if c.Debug {
		log.SetFormatter(&log.TextFormatter{FieldMap: fm})
	} else {
		log.SetFormatter(&log.JSONFormatter{FieldMap: fm, DisableHTMLEscape: true})
	}
}