diff --git a/internal/aghos/service.go b/internal/aghos/service.go new file mode 100644 index 00000000..4be05dd2 --- /dev/null +++ b/internal/aghos/service.go @@ -0,0 +1,6 @@ +package aghos + +// PreCheckActionStart performs the service start action pre-check. +func PreCheckActionStart() (err error) { + return preCheckActionStart() +} diff --git a/internal/aghos/service_darwin.go b/internal/aghos/service_darwin.go new file mode 100644 index 00000000..b87f95ed --- /dev/null +++ b/internal/aghos/service_darwin.go @@ -0,0 +1,32 @@ +//go:build darwin + +package aghos + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/AdguardTeam/golibs/log" +) + +// preCheckActionStart performs the service start action pre-check. It warns +// user that the service should be installed into Applications directory. +func preCheckActionStart() (err error) { + exe, err := os.Executable() + if err != nil { + return fmt.Errorf("getting executable path: %v", err) + } + + exe, err = filepath.EvalSymlinks(exe) + if err != nil { + return fmt.Errorf("evaluating executable symlinks: %v", err) + } + + if !strings.HasPrefix(exe, "/Applications/") { + log.Info("warning: service must be started from within the /Applications directory") + } + + return err +} diff --git a/internal/aghos/service_others.go b/internal/aghos/service_others.go new file mode 100644 index 00000000..0869f53f --- /dev/null +++ b/internal/aghos/service_others.go @@ -0,0 +1,8 @@ +//go:build !darwin + +package aghos + +// preCheckActionStart performs the service start action pre-check. +func preCheckActionStart() (err error) { + return nil +} diff --git a/internal/home/config.go b/internal/home/config.go index bd3d7e49..8d9fa422 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -399,19 +399,23 @@ func (c *configuration) getConfigFilename() string { return configFile } -// getLogSettings reads logging settings from the config file. -// we do it in a separate method in order to configure logger before the actual configuration is parsed and applied. -func getLogSettings() logSettings { - l := logSettings{} +// readLogSettings reads logging settings from the config file. We do it in a +// separate method in order to configure logger before the actual configuration +// is parsed and applied. +func readLogSettings() (ls *logSettings) { + ls = &logSettings{} + yamlFile, err := readConfigFile() if err != nil { - return l + return ls } - err = yaml.Unmarshal(yamlFile, &l) + + err = yaml.Unmarshal(yamlFile, ls) if err != nil { log.Error("Couldn't get logging settings from the configuration: %s", err) } - return l + + return ls } // validateBindHosts returns error if any of binding hosts from configuration is diff --git a/internal/home/home.go b/internal/home/home.go index 150d0011..5f1dd6f2 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -37,6 +37,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/stringutil" "golang.org/x/exp/slices" "gopkg.in/natefinch/lumberjack.v2" ) @@ -145,7 +146,9 @@ func Main(clientBuildFS fs.FS) { run(opts, clientBuildFS) } -func setupContext(opts options) { +// setupContext initializes [Context] fields. It also reads and upgrades +// config file if necessary. +func setupContext(opts options) (err error) { setupContextFlags(opts) Context.tlsRoots = aghtls.SystemRootCAs() @@ -162,10 +165,15 @@ func setupContext(opts options) { }, } + Context.mux = http.NewServeMux() + if !Context.firstRun { // Do the upgrade if necessary. - err := upgradeConfig() - fatalOnError(err) + err = upgradeConfig() + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } if err = parseConfig(); err != nil { log.Error("parsing configuration file: %s", err) @@ -181,11 +189,14 @@ func setupContext(opts options) { if !opts.noEtcHosts && config.Clients.Sources.HostsFile { err = setupHostsContainer() - fatalOnError(err) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } } } - Context.mux = http.NewServeMux() + return nil } // setupContextFlags sets global flags and prints their status to the log. @@ -287,78 +298,27 @@ func setupHostsContainer() (err error) { return nil } -func setupConfig(opts options) (err error) { - config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts - config.DNS.DnsfilterConf.ConfigModified = onConfigModified - config.DNS.DnsfilterConf.HTTPRegister = httpRegister - config.DNS.DnsfilterConf.DataDir = Context.getDataDir() - config.DNS.DnsfilterConf.Filters = slices.Clone(config.Filters) - config.DNS.DnsfilterConf.WhitelistFilters = slices.Clone(config.WhitelistFilters) - config.DNS.DnsfilterConf.UserRules = slices.Clone(config.UserRules) - config.DNS.DnsfilterConf.HTTPClient = Context.client - - const ( - dnsTimeout = 3 * time.Second - - sbService = "safe browsing" - defaultSafeBrowsingServer = `https://family.adguard-dns.com/dns-query` - sbTXTSuffix = `sb.dns.adguard.com.` - - pcService = "parental control" - defaultParentalServer = `https://family.adguard-dns.com/dns-query` - pcTXTSuffix = `pc.dns.adguard.com.` - ) - - cacheTime := time.Duration(config.DNS.DnsfilterConf.CacheTime) * time.Minute - - upsOpts := &upstream.Options{ - Timeout: dnsTimeout, - ServerIPAddrs: []net.IP{ - {94, 140, 14, 15}, - {94, 140, 15, 16}, - net.ParseIP("2a10:50c0::bad1:ff"), - net.ParseIP("2a10:50c0::bad2:ff"), - }, +// setupOpts sets up command-line options. +func setupOpts(opts options) (err error) { + err = setupBindOpts(opts) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err } - sbUps, err := upstream.AddressToUpstream(defaultSafeBrowsingServer, upsOpts) - if err != nil { - return fmt.Errorf("converting safe browsing server: %w", err) + if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) { + Context.pidFileName = opts.pidFile } - safeBrowsing := hashprefix.New(&hashprefix.Config{ - Upstream: sbUps, - ServiceName: sbService, - TXTSuffix: sbTXTSuffix, - CacheTime: cacheTime, - CacheSize: config.DNS.DnsfilterConf.SafeBrowsingCacheSize, - }) + return nil +} - parUps, err := upstream.AddressToUpstream(defaultParentalServer, upsOpts) +// initContextClients initializes Context clients and related fields. +func initContextClients() (err error) { + err = setupDNSFilteringConf(config.DNS.DnsfilterConf) if err != nil { - return fmt.Errorf("converting parental server: %w", err) - } - - parentalControl := hashprefix.New(&hashprefix.Config{ - Upstream: parUps, - ServiceName: pcService, - TXTSuffix: pcTXTSuffix, - CacheTime: cacheTime, - CacheSize: config.DNS.DnsfilterConf.SafeBrowsingCacheSize, - }) - - config.DNS.DnsfilterConf.SafeBrowsingChecker = safeBrowsing - config.DNS.DnsfilterConf.ParentalControlChecker = parentalControl - - config.DNS.DnsfilterConf.SafeSearchConf.CustomResolver = safeSearchResolver{} - config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefault( - config.DNS.DnsfilterConf.SafeSearchConf, - "default", - config.DNS.DnsfilterConf.SafeSearchCacheSize, - time.Minute*time.Duration(config.DNS.DnsfilterConf.CacheTime), - ) - if err != nil { - return fmt.Errorf("initializing safesearch: %w", err) + // Don't wrap the error, because it's informative enough as is. + return err } //lint:ignore SA1019 Migration is not over. @@ -393,8 +353,19 @@ func setupConfig(opts options) (err error) { arpdb = aghnet.NewARPDB() } - Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb, config.DNS.DnsfilterConf) + Context.clients.Init( + config.Clients.Persistent, + Context.dhcpServer, + Context.etcHosts, + arpdb, + config.DNS.DnsfilterConf, + ) + return nil +} + +// setupBindOpts overrides bind host/port from the opts. +func setupBindOpts(opts options) (err error) { if opts.bindPort != 0 { config.BindPort = opts.bindPort @@ -405,12 +376,83 @@ func setupConfig(opts options) (err error) { } } - // override bind host/port from the console if opts.bindHost.IsValid() { config.BindHost = opts.bindHost } - if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) { - Context.pidFileName = opts.pidFile + + return nil +} + +// setupDNSFilteringConf sets up DNS filtering configuration settings. +func setupDNSFilteringConf(conf *filtering.Config) (err error) { + const ( + dnsTimeout = 3 * time.Second + + sbService = "safe browsing" + defaultSafeBrowsingServer = `https://family.adguard-dns.com/dns-query` + sbTXTSuffix = `sb.dns.adguard.com.` + + pcService = "parental control" + defaultParentalServer = `https://family.adguard-dns.com/dns-query` + pcTXTSuffix = `pc.dns.adguard.com.` + ) + + conf.EtcHosts = Context.etcHosts + conf.ConfigModified = onConfigModified + conf.HTTPRegister = httpRegister + conf.DataDir = Context.getDataDir() + conf.Filters = slices.Clone(config.Filters) + conf.WhitelistFilters = slices.Clone(config.WhitelistFilters) + conf.UserRules = slices.Clone(config.UserRules) + conf.HTTPClient = Context.client + + cacheTime := time.Duration(conf.CacheTime) * time.Minute + + upsOpts := &upstream.Options{ + Timeout: dnsTimeout, + ServerIPAddrs: []net.IP{ + {94, 140, 14, 15}, + {94, 140, 15, 16}, + net.ParseIP("2a10:50c0::bad1:ff"), + net.ParseIP("2a10:50c0::bad2:ff"), + }, + } + + sbUps, err := upstream.AddressToUpstream(defaultSafeBrowsingServer, upsOpts) + if err != nil { + return fmt.Errorf("converting safe browsing server: %w", err) + } + + conf.SafeBrowsingChecker = hashprefix.New(&hashprefix.Config{ + Upstream: sbUps, + ServiceName: sbService, + TXTSuffix: sbTXTSuffix, + CacheTime: cacheTime, + CacheSize: conf.SafeBrowsingCacheSize, + }) + + parUps, err := upstream.AddressToUpstream(defaultParentalServer, upsOpts) + if err != nil { + return fmt.Errorf("converting parental server: %w", err) + } + + conf.ParentalControlChecker = hashprefix.New(&hashprefix.Config{ + Upstream: parUps, + ServiceName: pcService, + TXTSuffix: pcTXTSuffix, + CacheTime: cacheTime, + CacheSize: conf.SafeBrowsingCacheSize, + }) + + conf.SafeSearchConf.CustomResolver = safeSearchResolver{} + conf.SafeSearch, err = safesearch.NewDefault( + conf.SafeSearchConf, + "default", + conf.SafeSearchCacheSize, + cacheTime, + ) + if err != nil { + return fmt.Errorf("initializing safesearch: %w", err) } return nil @@ -487,14 +529,16 @@ func fatalOnError(err error) { // run configures and starts AdGuard Home. func run(opts options, clientBuildFS fs.FS) { - // configure config filename + // Configure config filename. initConfigFilename(opts) - // configure working dir and config path - initWorkingDir(opts) + // Configure working dir and config path. + err := initWorkingDir(opts) + fatalOnError(err) - // configure log level and output - configureLogger(opts) + // Configure log level and output. + err = configureLogger(opts) + fatalOnError(err) // Print the first message after logger is configured. log.Info(version.Full()) @@ -503,25 +547,29 @@ func run(opts options, clientBuildFS fs.FS) { log.Info("AdGuard Home is running as a service") } - setupContext(opts) - - err := configureOS(config) + err = setupContext(opts) fatalOnError(err) - // clients package uses filtering package's static data (filtering.BlockedSvcKnown()), - // so we have to initialize filtering's static data first, - // but also avoid relying on automatic Go init() function + err = configureOS(config) + fatalOnError(err) + + // Clients package uses filtering package's static data + // (filtering.BlockedSvcKnown()), so we have to initialize filtering static + // data first, but also to avoid relying on automatic Go init() function. filtering.InitModule() - err = setupConfig(opts) + err = initContextClients() fatalOnError(err) - // TODO(e.burkov): This could be made earlier, probably as the option's + err = setupOpts(opts) + fatalOnError(err) + + // TODO(e.burkov): This could be made earlier, probably as the option's // effect. cmdlineUpdate(opts) if !Context.firstRun { - // Save the updated config + // Save the updated config. err = config.write() fatalOnError(err) @@ -531,33 +579,15 @@ func run(opts options, clientBuildFS fs.FS) { } } - err = os.MkdirAll(Context.getDataDir(), 0o755) - if err != nil { - log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err) - } + dir := Context.getDataDir() + err = os.MkdirAll(dir, 0o755) + fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dir)) - sessFilename := filepath.Join(Context.getDataDir(), "sessions.db") GLMode = opts.glinetMode - var rateLimiter *authRateLimiter - if config.AuthAttempts > 0 && config.AuthBlockMin > 0 { - rateLimiter = newAuthRateLimiter( - time.Duration(config.AuthBlockMin)*time.Minute, - config.AuthAttempts, - ) - } else { - log.Info("authratelimiter is disabled") - } - Context.auth = InitAuth( - sessFilename, - config.Users, - config.WebSessionTTLHours*60*60, - rateLimiter, - ) - if Context.auth == nil { - log.Fatalf("Couldn't initialize Auth module") - } - config.Users = nil + // Init auth module. + Context.auth, err = initUsers() + fatalOnError(err) Context.tls, err = newTLSManager(config.TLS) if err != nil { @@ -575,10 +605,10 @@ func run(opts options, clientBuildFS fs.FS) { Context.tls.start() go func() { - serr := startDNSServer() - if serr != nil { + sErr := startDNSServer() + if sErr != nil { closeDNSServer() - fatalOnError(serr) + fatalOnError(sErr) } }() @@ -592,10 +622,33 @@ func run(opts options, clientBuildFS fs.FS) { Context.web.start() - // wait indefinitely for other go-routines to complete their job + // Wait indefinitely for other goroutines to complete their job. select {} } +// initUsers initializes context auth module. Clears config users field. +func initUsers() (auth *Auth, err error) { + sessFilename := filepath.Join(Context.getDataDir(), "sessions.db") + + var rateLimiter *authRateLimiter + if config.AuthAttempts > 0 && config.AuthBlockMin > 0 { + blockDur := time.Duration(config.AuthBlockMin) * time.Minute + rateLimiter = newAuthRateLimiter(blockDur, config.AuthAttempts) + } else { + log.Info("authratelimiter is disabled") + } + + sessionTTL := config.WebSessionTTLHours * 60 * 60 + auth = InitAuth(sessFilename, config.Users, sessionTTL, rateLimiter) + if auth == nil { + return nil, errors.Error("initializing auth module failed") + } + + config.Users = nil + + return auth, nil +} + func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) { var anonFunc aghnet.IPMutFunc if c.DNS.AnonymizeClientIP { @@ -668,22 +721,19 @@ func writePIDFile(fn string) bool { return true } +// initConfigFilename sets up context config file path. This file path can be +// overridden by command-line arguments, or is set to default. func initConfigFilename(opts options) { - // config file path can be overridden by command-line arguments: - if opts.confFilename != "" { - Context.configFilename = opts.confFilename - } else { - // Default config file name - Context.configFilename = "AdGuardHome.yaml" - } + Context.configFilename = stringutil.Coalesce(opts.confFilename, "AdGuardHome.yaml") } -// initWorkingDir initializes the workDir -// if no command-line arguments specified, we use the directory where our binary file is located -func initWorkingDir(opts options) { +// initWorkingDir initializes the workDir. If no command-line arguments are +// specified, the directory with the binary file is used. +func initWorkingDir(opts options) (err error) { execPath, err := os.Executable() if err != nil { - panic(err) + // Don't wrap the error, because it's informative enough as is. + return err } if opts.workDir != "" { @@ -695,34 +745,20 @@ func initWorkingDir(opts options) { workDir, err := filepath.EvalSymlinks(Context.workDir) if err != nil { - panic(err) + // Don't wrap the error, because it's informative enough as is. + return err } Context.workDir = workDir + + return nil } -// configureLogger configures logger level and output -func configureLogger(opts options) { - ls := getLogSettings() +// configureLogger configures logger level and output. +func configureLogger(opts options) (err error) { + ls := getLogSettings(opts) - // command-line arguments can override config settings - if opts.verbose || config.Verbose { - ls.Verbose = true - } - if opts.logFile != "" { - ls.File = opts.logFile - } else if config.File != "" { - ls.File = config.File - } - - // Handle default log settings overrides - ls.Compress = config.Compress - ls.LocalTime = config.LocalTime - ls.MaxBackups = config.MaxBackups - ls.MaxSize = config.MaxSize - ls.MaxAge = config.MaxAge - - // log.SetLevel(log.INFO) - default + // Configure logger level. if ls.Verbose { log.SetLevel(log.DEBUG) } @@ -731,38 +767,63 @@ func configureLogger(opts options) { // happen pretty quickly. log.SetFlags(log.LstdFlags | log.Lmicroseconds) - if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" { - // When running as a Windows service, use eventlog by default if nothing - // else is configured. Otherwise, we'll simply lose the log output. - ls.File = configSyslog - } - - // logs are written to stdout (default) + // Write logs to stdout by default. if ls.File == "" { - return + return nil } if ls.File == configSyslog { - // Use syslog where it is possible and eventlog on Windows - err := aghos.ConfigureSyslog(serviceName) + // Use syslog where it is possible and eventlog on Windows. + err = aghos.ConfigureSyslog(serviceName) if err != nil { - log.Fatalf("cannot initialize syslog: %s", err) - } - } else { - logFilePath := ls.File - if !filepath.IsAbs(logFilePath) { - logFilePath = filepath.Join(Context.workDir, logFilePath) + return fmt.Errorf("cannot initialize syslog: %w", err) } - log.SetOutput(&lumberjack.Logger{ - Filename: logFilePath, - Compress: ls.Compress, // disabled by default - LocalTime: ls.LocalTime, - MaxBackups: ls.MaxBackups, - MaxSize: ls.MaxSize, // megabytes - MaxAge: ls.MaxAge, // days - }) + return nil } + + logFilePath := ls.File + if !filepath.IsAbs(logFilePath) { + logFilePath = filepath.Join(Context.workDir, logFilePath) + } + + log.SetOutput(&lumberjack.Logger{ + Filename: logFilePath, + Compress: ls.Compress, + LocalTime: ls.LocalTime, + MaxBackups: ls.MaxBackups, + MaxSize: ls.MaxSize, + MaxAge: ls.MaxAge, + }) + + return nil +} + +// getLogSettings returns a log settings object properly initialized from opts. +func getLogSettings(opts options) (ls *logSettings) { + ls = readLogSettings() + + // Command-line arguments can override config settings. + if opts.verbose || config.Verbose { + ls.Verbose = true + } + + ls.File = stringutil.Coalesce(opts.logFile, config.File, ls.File) + + // Handle default log settings overrides. + ls.Compress = config.Compress + ls.LocalTime = config.LocalTime + ls.MaxBackups = config.MaxBackups + ls.MaxSize = config.MaxSize + ls.MaxAge = config.MaxAge + + if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" { + // When running as a Windows service, use eventlog by default if + // nothing else is configured. Otherwise, we'll lose the log output. + ls.File = configSyslog + } + + return ls } // cleanup stops and resets all the modules. diff --git a/internal/home/service.go b/internal/home/service.go index c0fe845f..3ec44138 100644 --- a/internal/home/service.go +++ b/internal/home/service.go @@ -4,7 +4,6 @@ import ( "fmt" "io/fs" "os" - "path/filepath" "runtime" "strconv" "strings" @@ -84,14 +83,9 @@ func svcStatus(s service.Service) (status service.Status, err error) { // On OpenWrt, the service utility may not exist. We use our service script // directly in this case. func svcAction(s service.Service, action string) (err error) { - if runtime.GOOS == "darwin" && action == "start" { - var exe string - if exe, err = os.Executable(); err != nil { - log.Error("starting service: getting executable path: %s", err) - } else if exe, err = filepath.EvalSymlinks(exe); err != nil { - log.Error("starting service: evaluating executable symlinks: %s", err) - } else if !strings.HasPrefix(exe, "/Applications/") { - log.Info("warning: service must be started from within the /Applications directory") + if action == "start" { + if err = aghos.PreCheckActionStart(); err != nil { + log.Error("starting service: %s", err) } } @@ -99,8 +93,6 @@ func svcAction(s service.Service, action string) (err error) { if err != nil && service.Platform() == "unix-systemv" && (action == "start" || action == "stop" || action == "restart") { _, err = runInitdCommand(action) - - return err } return err @@ -224,6 +216,7 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) { runOpts := opts runOpts.serviceControlAction = "run" + svcConfig := &service.Config{ Name: serviceName, DisplayName: serviceDisplayName, @@ -233,35 +226,48 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) { } configureService(svcConfig) - prg := &program{ - clientBuildFS: clientBuildFS, - opts: runOpts, - } - var s service.Service - if s, err = service.New(prg, svcConfig); err != nil { + s, err := service.New(&program{clientBuildFS: clientBuildFS, opts: runOpts}, svcConfig) + if err != nil { log.Fatalf("service: initializing service: %s", err) } + err = handleServiceCommand(s, action, opts) + if err != nil { + log.Fatalf("service: %s", err) + } + + log.Printf( + "service: action %s has been done successfully on %s", + action, + service.ChosenSystem(), + ) +} + +// handleServiceCommand handles service command. +func handleServiceCommand(s service.Service, action string, opts options) (err error) { switch action { case "status": handleServiceStatusCommand(s) case "run": if err = s.Run(); err != nil { - log.Fatalf("service: failed to run service: %s", err) + return fmt.Errorf("failed to run service: %w", err) } case "install": initConfigFilename(opts) - initWorkingDir(opts) + if err = initWorkingDir(opts); err != nil { + return fmt.Errorf("failed to init working dir: %w", err) + } + handleServiceInstallCommand(s) case "uninstall": handleServiceUninstallCommand(s) default: if err = svcAction(s, action); err != nil { - log.Fatalf("service: executing action %q: %s", action, err) + return fmt.Errorf("executing action %q: %w", action, err) } } - log.Printf("service: action %s has been done successfully on %s", action, service.ChosenSystem()) + return nil } // handleServiceStatusCommand handles service "status" command. diff --git a/internal/home/tls.go b/internal/home/tls.go index b9b04eeb..84af6eae 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -172,9 +172,32 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error } }() - tlsConf.CertificateChainData = []byte(tlsConf.CertificateChain) - tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey) + err = loadCertificateChainData(tlsConf, status) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + err = loadPrivateKeyData(tlsConf, status) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + + err = validateCertificates( + status, + tlsConf.CertificateChainData, + tlsConf.PrivateKeyData, + tlsConf.ServerName, + ) + + return errors.Annotate(err, "validating certificate pair: %w") +} + +// loadCertificateChainData loads PEM-encoded certificates chain data to the +// TLS configuration. +func loadCertificateChainData(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) { + tlsConf.CertificateChainData = []byte(tlsConf.CertificateChain) if tlsConf.CertificatePath != "" { if tlsConf.CertificateChain != "" { return errors.Error("certificate data and file can't be set together") @@ -190,6 +213,13 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error status.ValidCert = true } + return nil +} + +// loadPrivateKeyData loads PEM-encoded private key data to the TLS +// configuration. +func loadPrivateKeyData(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) { + tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey) if tlsConf.PrivateKeyPath != "" { if tlsConf.PrivateKey != "" { return errors.Error("private key data and file can't be set together") @@ -203,16 +233,6 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error status.ValidKey = true } - err = validateCertificates( - status, - tlsConf.CertificateChainData, - tlsConf.PrivateKeyData, - tlsConf.ServerName, - ) - if err != nil { - return fmt.Errorf("validating certificate pair: %w", err) - } - return nil } diff --git a/internal/home/upgrade.go b/internal/home/upgrade.go index 81168561..e429eb41 100644 --- a/internal/home/upgrade.go +++ b/internal/home/upgrade.go @@ -294,71 +294,61 @@ func upgradeSchema4to5(diskConf yobj) error { return nil } -// clients: -// ... +// upgradeSchema5to6 performs the following changes: // -// ip: 127.0.0.1 -// mac: ... +// # BEFORE: +// 'clients': +// ... +// 'ip': 127.0.0.1 +// 'mac': ... // -// -> -// -// clients: -// ... -// -// ids: -// - 127.0.0.1 -// - ... +// # AFTER: +// 'clients': +// ... +// 'ids': +// - 127.0.0.1 +// - ... func upgradeSchema5to6(diskConf yobj) error { - log.Printf("%s(): called", funcName()) - + log.Printf("Upgrade yaml: 5 to 6") diskConf["schema_version"] = 6 - clients, ok := diskConf["clients"] + clientsVal, ok := diskConf["clients"] if !ok { return nil } - switch arr := clients.(type) { - case []any: - for i := range arr { - switch c := arr[i].(type) { - case map[any]any: - var ipVal any - ipVal, ok = c["ip"] - ids := []string{} - if ok { - var ip string - ip, ok = ipVal.(string) - if !ok { - log.Fatalf("client.ip is not a string: %v", ipVal) - return nil - } - if len(ip) != 0 { - ids = append(ids, ip) - } - } + clients, ok := clientsVal.([]yobj) + if !ok { + return fmt.Errorf("unexpected type of clients: %T", clientsVal) + } - var macVal any - macVal, ok = c["mac"] - if ok { - var mac string - mac, ok = macVal.(string) - if !ok { - log.Fatalf("client.mac is not a string: %v", macVal) - return nil - } - if len(mac) != 0 { - ids = append(ids, mac) - } - } + for i := range clients { + c := clients[i] + var ids []string - c["ids"] = ids - default: - continue + if ipVal, hasIP := c["ip"]; hasIP { + var ip string + if ip, ok = ipVal.(string); !ok { + return fmt.Errorf("client.ip is not a string: %v", ipVal) + } + + if ip != "" { + ids = append(ids, ip) } } - default: - return nil + + if macVal, hasMac := c["mac"]; hasMac { + var mac string + if mac, ok = macVal.(string); !ok { + return fmt.Errorf("client.mac is not a string: %v", macVal) + } + + if mac != "" { + ids = append(ids, mac) + } + } + + c["ids"] = ids } return nil diff --git a/internal/home/upgrade_test.go b/internal/home/upgrade_test.go index f4091e84..11820be0 100644 --- a/internal/home/upgrade_test.go +++ b/internal/home/upgrade_test.go @@ -68,6 +68,95 @@ func TestUpgradeSchema2to3(t *testing.T) { assertEqualExcept(t, oldDiskConf, diskConf, excludedEntries, excludedEntries) } +func TestUpgradeSchema5to6(t *testing.T) { + const newSchemaVer = 6 + + testCases := []struct { + in yobj + want yobj + wantErr string + name string + }{{ + in: yobj{ + "clients": []yobj{}, + }, + want: yobj{ + "clients": []yobj{}, + "schema_version": newSchemaVer, + }, + wantErr: "", + name: "no_clients", + }, { + in: yobj{ + "clients": []yobj{{"ip": "127.0.0.1"}}, + }, + want: yobj{ + "clients": []yobj{{ + "ids": []string{"127.0.0.1"}, + "ip": "127.0.0.1", + }}, + "schema_version": newSchemaVer, + }, + wantErr: "", + name: "client_ip", + }, { + in: yobj{ + "clients": []yobj{{"mac": "mac"}}, + }, + want: yobj{ + "clients": []yobj{{ + "ids": []string{"mac"}, + "mac": "mac", + }}, + "schema_version": newSchemaVer, + }, + wantErr: "", + name: "client_mac", + }, { + in: yobj{ + "clients": []yobj{{"ip": "127.0.0.1", "mac": "mac"}}, + }, + want: yobj{ + "clients": []yobj{{ + "ids": []string{"127.0.0.1", "mac"}, + "ip": "127.0.0.1", + "mac": "mac", + }}, + "schema_version": newSchemaVer, + }, + wantErr: "", + name: "client_ip_mac", + }, { + in: yobj{ + "clients": []yobj{{"ip": 1, "mac": "mac"}}, + }, + want: yobj{ + "clients": []yobj{{"ip": 1, "mac": "mac"}}, + "schema_version": newSchemaVer, + }, + wantErr: "client.ip is not a string: 1", + name: "inv_client_ip", + }, { + in: yobj{ + "clients": []yobj{{"ip": "127.0.0.1", "mac": 1}}, + }, + want: yobj{ + "clients": []yobj{{"ip": "127.0.0.1", "mac": 1}}, + "schema_version": newSchemaVer, + }, + wantErr: "client.mac is not a string: 1", + name: "inv_client_mac", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := upgradeSchema5to6(tc.in) + testutil.AssertErrorMsg(t, tc.wantErr, err) + assert.Equal(t, tc.want, tc.in) + }) + } +} + func TestUpgradeSchema7to8(t *testing.T) { const host = "1.2.3.4" oldConf := yobj{ diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index 78a598e1..aba2521f 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -161,11 +161,8 @@ run_linter "$GO" vet ./... run_linter govulncheck ./... # Apply more lax standards to the code we haven't properly refactored yet. -run_linter gocyclo --over 13\ - ./internal/dhcpd\ - ./internal/home/\ - ./internal/querylog/\ - ; +run_linter gocyclo --over 13 ./internal/querylog +run_linter gocyclo --over 12 ./internal/dhcpd # Apply the normal standards to new or somewhat refactored code. run_linter gocyclo --over 10\ @@ -175,6 +172,7 @@ run_linter gocyclo --over 10\ ./internal/aghtest/\ ./internal/dnsforward/\ ./internal/filtering/\ + ./internal/home/\ ./internal/stats/\ ./internal/tools/\ ./internal/updater/\