Pull request 2303: AGDNS-2505-upd-next
Squashed commit of the following: commit586b0eb180
Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Nov 12 19:58:56 2024 +0300 next: upd more commitd729aa150f
Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Nov 12 16:53:15 2024 +0300 next/websvc: upd more commit0c64e6cfc6
Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Mon Nov 11 21:08:51 2024 +0300 next: upd more commit05eec75222
Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Nov 8 19:20:02 2024 +0300 next: upd code
This commit is contained in:
parent
ac5a96fada
commit
1d6d85cff4
3
go.mod
3
go.mod
|
@ -4,14 +4,13 @@ go 1.23.3
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/AdguardTeam/dnsproxy v0.73.3
|
github.com/AdguardTeam/dnsproxy v0.73.3
|
||||||
github.com/AdguardTeam/golibs v0.30.2
|
github.com/AdguardTeam/golibs v0.30.3
|
||||||
github.com/AdguardTeam/urlfilter v0.20.0
|
github.com/AdguardTeam/urlfilter v0.20.0
|
||||||
github.com/NYTimes/gziphandler v1.1.1
|
github.com/NYTimes/gziphandler v1.1.1
|
||||||
github.com/ameshkov/dnscrypt/v2 v2.3.0
|
github.com/ameshkov/dnscrypt/v2 v2.3.0
|
||||||
github.com/bluele/gcache v0.0.2
|
github.com/bluele/gcache v0.0.2
|
||||||
github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500
|
github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500
|
||||||
github.com/digineo/go-ipset/v2 v2.2.1
|
github.com/digineo/go-ipset/v2 v2.2.1
|
||||||
github.com/dimfeld/httptreemux/v5 v5.5.0
|
|
||||||
github.com/fsnotify/fsnotify v1.8.0
|
github.com/fsnotify/fsnotify v1.8.0
|
||||||
github.com/go-ping/ping v1.1.0
|
github.com/go-ping/ping v1.1.0
|
||||||
github.com/google/go-cmp v0.6.0
|
github.com/google/go-cmp v0.6.0
|
||||||
|
|
6
go.sum
6
go.sum
|
@ -1,7 +1,7 @@
|
||||||
github.com/AdguardTeam/dnsproxy v0.73.3 h1:aacr6Wu0ed94DDD+gSB6EwF8nvyq0+DAc7oFOgtgUpA=
|
github.com/AdguardTeam/dnsproxy v0.73.3 h1:aacr6Wu0ed94DDD+gSB6EwF8nvyq0+DAc7oFOgtgUpA=
|
||||||
github.com/AdguardTeam/dnsproxy v0.73.3/go.mod h1:18ssqhDgOCiVIwYmmVuXVM05wSwrzkO2yjKhVRWJX/g=
|
github.com/AdguardTeam/dnsproxy v0.73.3/go.mod h1:18ssqhDgOCiVIwYmmVuXVM05wSwrzkO2yjKhVRWJX/g=
|
||||||
github.com/AdguardTeam/golibs v0.30.2 h1:urU/NAyIvQOeArBqDmKCDpaRkfTCJ26uSiSuDMKQfuY=
|
github.com/AdguardTeam/golibs v0.30.3 h1:pRxLjMCJ1cZccjZWMMuKxzQQGEpFbmtyj4Tg7nk5rY0=
|
||||||
github.com/AdguardTeam/golibs v0.30.2/go.mod h1:FkwcNQEJoGsgDGXcalrVa/4gWbE68KsmE2guXWtBQUE=
|
github.com/AdguardTeam/golibs v0.30.3/go.mod h1:Ir9dlHfb8nRQsG3Qgo1zoGL+k1qMbcBtb8tcnsvzdAE=
|
||||||
github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs=
|
github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs=
|
||||||
github.com/AdguardTeam/urlfilter v0.20.0/go.mod h1:gjrywLTxfJh6JOkwi9SU+frhP7kVVEZ5exFGkR99qpk=
|
github.com/AdguardTeam/urlfilter v0.20.0/go.mod h1:gjrywLTxfJh6JOkwi9SU+frhP7kVVEZ5exFGkR99qpk=
|
||||||
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
|
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
|
||||||
|
@ -25,8 +25,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/digineo/go-ipset/v2 v2.2.1 h1:k6skY+0fMqeUjjeWO/m5OuWPSZUAn7AucHMnQ1MX77g=
|
github.com/digineo/go-ipset/v2 v2.2.1 h1:k6skY+0fMqeUjjeWO/m5OuWPSZUAn7AucHMnQ1MX77g=
|
||||||
github.com/digineo/go-ipset/v2 v2.2.1/go.mod h1:wBsNzJlZlABHUITkesrggFnZQtgW5wkqw1uo8Qxe0VU=
|
github.com/digineo/go-ipset/v2 v2.2.1/go.mod h1:wBsNzJlZlABHUITkesrggFnZQtgW5wkqw1uo8Qxe0VU=
|
||||||
github.com/dimfeld/httptreemux/v5 v5.5.0 h1:p8jkiMrCuZ0CmhwYLcbNbl7DDo21fozhKHQ2PccwOFQ=
|
|
||||||
github.com/dimfeld/httptreemux/v5 v5.5.0/go.mod h1:QeEylH57C0v3VO0tkKraVz9oD3Uu93CKPnTLbsidvSw=
|
|
||||||
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
|
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
|
||||||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||||
|
|
|
@ -146,16 +146,6 @@ func IsOpenWrt() (ok bool) {
|
||||||
return isOpenWrt()
|
return isOpenWrt()
|
||||||
}
|
}
|
||||||
|
|
||||||
// NotifyReconfigureSignal notifies c on receiving reconfigure signals.
|
|
||||||
func NotifyReconfigureSignal(c chan<- os.Signal) {
|
|
||||||
notifyReconfigureSignal(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsReconfigureSignal returns true if sig is a reconfigure signal.
|
|
||||||
func IsReconfigureSignal(sig os.Signal) (ok bool) {
|
|
||||||
return isReconfigureSignal(sig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendShutdownSignal sends the shutdown signal to the channel.
|
// SendShutdownSignal sends the shutdown signal to the channel.
|
||||||
func SendShutdownSignal(c chan<- os.Signal) {
|
func SendShutdownSignal(c chan<- os.Signal) {
|
||||||
sendShutdownSignal(c)
|
sendShutdownSignal(c)
|
||||||
|
|
|
@ -1,22 +1,11 @@
|
||||||
//go:build darwin || freebsd || linux || openbsd
|
//go:build unix
|
||||||
|
|
||||||
package aghos
|
package aghos
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func notifyReconfigureSignal(c chan<- os.Signal) {
|
|
||||||
signal.Notify(c, unix.SIGHUP)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isReconfigureSignal(sig os.Signal) (ok bool) {
|
|
||||||
return sig == unix.SIGHUP
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendShutdownSignal(_ chan<- os.Signal) {
|
func sendShutdownSignal(_ chan<- os.Signal) {
|
||||||
// On Unix we are already notified by the system.
|
// On Unix we are already notified by the system.
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,12 +4,11 @@ package aghos
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setRlimit(val uint64) (err error) {
|
func setRlimit(_ uint64) (err error) {
|
||||||
return Unsupported("setrlimit")
|
return Unsupported("setrlimit")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,14 +37,6 @@ func isOpenWrt() (ok bool) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func notifyReconfigureSignal(c chan<- os.Signal) {
|
|
||||||
signal.Notify(c, windows.SIGHUP)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isReconfigureSignal(sig os.Signal) (ok bool) {
|
|
||||||
return sig == windows.SIGHUP
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendShutdownSignal(c chan<- os.Signal) {
|
func sendShutdownSignal(c chan<- os.Signal) {
|
||||||
c <- os.Interrupt
|
c <- os.Interrupt
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,7 +58,7 @@ func (w *FSWatcher) Add(name string) (err error) {
|
||||||
|
|
||||||
// ServiceWithConfig is a fake [agh.ServiceWithConfig] implementation for tests.
|
// ServiceWithConfig is a fake [agh.ServiceWithConfig] implementation for tests.
|
||||||
type ServiceWithConfig[ConfigType any] struct {
|
type ServiceWithConfig[ConfigType any] struct {
|
||||||
OnStart func() (err error)
|
OnStart func(ctx context.Context) (err error)
|
||||||
OnShutdown func(ctx context.Context) (err error)
|
OnShutdown func(ctx context.Context) (err error)
|
||||||
OnConfig func() (c ConfigType)
|
OnConfig func() (c ConfigType)
|
||||||
}
|
}
|
||||||
|
@ -68,8 +68,8 @@ var _ agh.ServiceWithConfig[struct{}] = (*ServiceWithConfig[struct{}])(nil)
|
||||||
|
|
||||||
// Start implements the [agh.ServiceWithConfig] interface for
|
// Start implements the [agh.ServiceWithConfig] interface for
|
||||||
// *ServiceWithConfig.
|
// *ServiceWithConfig.
|
||||||
func (s *ServiceWithConfig[_]) Start() (err error) {
|
func (s *ServiceWithConfig[_]) Start(ctx context.Context) (err error) {
|
||||||
return s.OnStart()
|
return s.OnStart(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown implements the [agh.ServiceWithConfig] interface for
|
// Shutdown implements the [agh.ServiceWithConfig] interface for
|
||||||
|
|
|
@ -82,7 +82,7 @@ type Empty struct{}
|
||||||
var _ agh.ServiceWithConfig[*Config] = Empty{}
|
var _ agh.ServiceWithConfig[*Config] = Empty{}
|
||||||
|
|
||||||
// Start implements the [Service] interface for Empty.
|
// Start implements the [Service] interface for Empty.
|
||||||
func (Empty) Start() (err error) { return nil }
|
func (Empty) Start(_ context.Context) (err error) { return nil }
|
||||||
|
|
||||||
// Shutdown implements the [Service] interface for Empty.
|
// Shutdown implements the [Service] interface for Empty.
|
||||||
func (Empty) Shutdown(_ context.Context) (err error) { return nil }
|
func (Empty) Shutdown(_ context.Context) (err error) { return nil }
|
||||||
|
|
|
@ -1,36 +1,9 @@
|
||||||
// Package agh contains common entities and interfaces of AdGuard Home.
|
// Package agh contains common entities and interfaces of AdGuard Home.
|
||||||
package agh
|
package agh
|
||||||
|
|
||||||
import "context"
|
import (
|
||||||
|
"github.com/AdguardTeam/golibs/service"
|
||||||
// Service is the interface for API servers.
|
)
|
||||||
//
|
|
||||||
// TODO(a.garipov): Consider adding a context to Start.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Consider adding a Wait method or making an extension
|
|
||||||
// interface for that.
|
|
||||||
type Service interface {
|
|
||||||
// Start starts the service. It does not block.
|
|
||||||
Start() (err error)
|
|
||||||
|
|
||||||
// Shutdown gracefully stops the service. ctx is used to determine
|
|
||||||
// a timeout before trying to stop the service less gracefully.
|
|
||||||
Shutdown(ctx context.Context) (err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// type check
|
|
||||||
var _ Service = EmptyService{}
|
|
||||||
|
|
||||||
// EmptyService is a [Service] that does nothing.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Remove if unnecessary.
|
|
||||||
type EmptyService struct{}
|
|
||||||
|
|
||||||
// Start implements the [Service] interface for EmptyService.
|
|
||||||
func (EmptyService) Start() (err error) { return nil }
|
|
||||||
|
|
||||||
// Shutdown implements the [Service] interface for EmptyService.
|
|
||||||
func (EmptyService) Shutdown(_ context.Context) (err error) { return nil }
|
|
||||||
|
|
||||||
// ServiceWithConfig is an extension of the [Service] interface for services
|
// ServiceWithConfig is an extension of the [Service] interface for services
|
||||||
// that can return their configuration.
|
// that can return their configuration.
|
||||||
|
@ -38,7 +11,7 @@ func (EmptyService) Shutdown(_ context.Context) (err error) { return nil }
|
||||||
// TODO(a.garipov): Consider removing this generic interface if we figure out
|
// TODO(a.garipov): Consider removing this generic interface if we figure out
|
||||||
// how to make it testable in a better way.
|
// how to make it testable in a better way.
|
||||||
type ServiceWithConfig[ConfigType any] interface {
|
type ServiceWithConfig[ConfigType any] interface {
|
||||||
Service
|
service.Interface
|
||||||
|
|
||||||
Config() (c ConfigType)
|
Config() (c ConfigType)
|
||||||
}
|
}
|
||||||
|
@ -51,7 +24,7 @@ var _ ServiceWithConfig[struct{}] = (*EmptyServiceWithConfig[struct{}])(nil)
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Remove if unnecessary.
|
// TODO(a.garipov): Remove if unnecessary.
|
||||||
type EmptyServiceWithConfig[ConfigType any] struct {
|
type EmptyServiceWithConfig[ConfigType any] struct {
|
||||||
EmptyService
|
service.Empty
|
||||||
|
|
||||||
Conf ConfigType
|
Conf ConfigType
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,11 +12,15 @@ import (
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
|
"github.com/AdguardTeam/golibs/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Main is the entry point of AdGuard Home.
|
// Main is the entry point of AdGuard Home.
|
||||||
func Main(embeddedFrontend fs.FS) {
|
func Main(embeddedFrontend fs.FS) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
cmdName := os.Args[0]
|
cmdName := os.Args[0]
|
||||||
|
@ -26,70 +30,69 @@ func Main(embeddedFrontend fs.FS) {
|
||||||
os.Exit(exitCode)
|
os.Exit(exitCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = setLog(opts)
|
baseLogger := newBaseLogger(opts)
|
||||||
check(err)
|
|
||||||
|
|
||||||
log.Info("starting adguard home, version %s, pid %d", version.Version(), os.Getpid())
|
baseLogger.InfoContext(
|
||||||
|
ctx,
|
||||||
|
"starting adguard home",
|
||||||
|
"version", version.Version(),
|
||||||
|
"pid", os.Getpid(),
|
||||||
|
)
|
||||||
|
|
||||||
if opts.workDir != "" {
|
if opts.workDir != "" {
|
||||||
log.Info("changing working directory to %q", opts.workDir)
|
baseLogger.InfoContext(ctx, "changing working directory", "dir", opts.workDir)
|
||||||
|
|
||||||
err = os.Chdir(opts.workDir)
|
err = os.Chdir(opts.workDir)
|
||||||
check(err)
|
errors.Check(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
frontend, err := frontendFromOpts(opts, embeddedFrontend)
|
frontend, err := frontendFromOpts(ctx, baseLogger, opts, embeddedFrontend)
|
||||||
check(err)
|
errors.Check(err)
|
||||||
|
|
||||||
|
startCtx, startCancel := context.WithTimeout(ctx, defaultTimeoutStart)
|
||||||
|
defer startCancel()
|
||||||
|
|
||||||
confMgrConf := &configmgr.Config{
|
confMgrConf := &configmgr.Config{
|
||||||
Frontend: frontend,
|
BaseLogger: baseLogger,
|
||||||
WebAddr: opts.webAddr,
|
Logger: baseLogger.With(slogutil.KeyPrefix, "configmgr"),
|
||||||
Start: start,
|
Frontend: frontend,
|
||||||
FileName: opts.confFile,
|
WebAddr: opts.webAddr,
|
||||||
|
Start: start,
|
||||||
|
FileName: opts.confFile,
|
||||||
}
|
}
|
||||||
|
|
||||||
confMgr, err := newConfigMgr(confMgrConf)
|
confMgr, err := configmgr.New(startCtx, confMgrConf)
|
||||||
check(err)
|
errors.Check(err)
|
||||||
|
|
||||||
web := confMgr.Web()
|
web := confMgr.Web()
|
||||||
err = web.Start()
|
err = web.Start(startCtx)
|
||||||
check(err)
|
errors.Check(err)
|
||||||
|
|
||||||
dns := confMgr.DNS()
|
dns := confMgr.DNS()
|
||||||
err = dns.Start()
|
err = dns.Start(startCtx)
|
||||||
check(err)
|
errors.Check(err)
|
||||||
|
|
||||||
sigHdlr := newSignalHandler(
|
sigHdlr := newSignalHandler(
|
||||||
|
baseLogger.With(slogutil.KeyPrefix, service.SignalHandlerPrefix),
|
||||||
confMgrConf,
|
confMgrConf,
|
||||||
opts.pidFile,
|
opts.pidFile,
|
||||||
web,
|
web,
|
||||||
dns,
|
dns,
|
||||||
)
|
)
|
||||||
|
|
||||||
sigHdlr.handle()
|
os.Exit(sigHdlr.handle(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultTimeout is the timeout used for some operations where another timeout
|
// Default timeouts.
|
||||||
// hasn't been defined yet.
|
//
|
||||||
const defaultTimeout = 5 * time.Second
|
// TODO(a.garipov): Make configurable.
|
||||||
|
const (
|
||||||
// ctxWithDefaultTimeout is a helper function that returns a context with
|
defaultTimeoutStart = 1 * time.Minute
|
||||||
// timeout set to defaultTimeout.
|
defaultTimeoutShutdown = 5 * time.Second
|
||||||
func ctxWithDefaultTimeout() (ctx context.Context, cancel context.CancelFunc) {
|
)
|
||||||
return context.WithTimeout(context.Background(), defaultTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newConfigMgr returns a new configuration manager using defaultTimeout as the
|
// newConfigMgr returns a new configuration manager using defaultTimeout as the
|
||||||
// context timeout.
|
// context timeout.
|
||||||
func newConfigMgr(c *configmgr.Config) (m *configmgr.Manager, err error) {
|
func newConfigMgr(ctx context.Context, c *configmgr.Config) (m *configmgr.Manager, err error) {
|
||||||
ctx, cancel := ctxWithDefaultTimeout()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
return configmgr.New(ctx, c)
|
return configmgr.New(ctx, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check is a simple error-checking helper. It must only be used within Main.
|
|
||||||
func check(err error) {
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,39 +1,39 @@
|
||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// syslogServiceName is the name of the AdGuard Home service used for writing
|
// newBaseLogger constructs a base logger based on the command-line options.
|
||||||
// logs to the system log.
|
// opts must not be nil.
|
||||||
const syslogServiceName = "AdGuardHome"
|
func newBaseLogger(opts *options) (baseLogger *slog.Logger) {
|
||||||
|
var output io.Writer
|
||||||
// setLog sets up the text logging.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Add parameters from configuration file.
|
|
||||||
func setLog(opts *options) (err error) {
|
|
||||||
switch opts.confFile {
|
switch opts.confFile {
|
||||||
case "stdout":
|
case "stdout":
|
||||||
log.SetOutput(os.Stdout)
|
output = os.Stdout
|
||||||
case "stderr":
|
case "stderr":
|
||||||
log.SetOutput(os.Stderr)
|
output = os.Stderr
|
||||||
case "syslog":
|
case "syslog":
|
||||||
err = aghos.ConfigureSyslog(syslogServiceName)
|
// TODO(a.garipov): Add a syslog handler to golibs.
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("initializing syslog: %w", err)
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
// TODO(a.garipov): Use the path.
|
// TODO(a.garipov): Use the path.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lvl := slog.LevelInfo
|
||||||
if opts.verbose {
|
if opts.verbose {
|
||||||
log.SetLevel(log.DEBUG)
|
lvl = slog.LevelDebug
|
||||||
log.Debug("verbose logging enabled")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return slogutil.New(&slogutil.Config{
|
||||||
|
Output: output,
|
||||||
|
// TODO(a.garipov): Get from config?
|
||||||
|
Format: slogutil.FormatText,
|
||||||
|
Level: lvl,
|
||||||
|
// TODO(a.garipov): Get from config.
|
||||||
|
AddTimestamp: true,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding"
|
"encoding"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
|
@ -14,7 +16,7 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/configmigrate"
|
"github.com/AdguardTeam/AdGuardHome/internal/configmigrate"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/osutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// options contains all command-line options for the AdGuardHome(.exe) binary.
|
// options contains all command-line options for the AdGuardHome(.exe) binary.
|
||||||
|
@ -372,13 +374,13 @@ func processOptions(
|
||||||
) (exitCode int, needExit bool) {
|
) (exitCode int, needExit bool) {
|
||||||
if parseErr != nil {
|
if parseErr != nil {
|
||||||
// Assume that usage has already been printed.
|
// Assume that usage has already been printed.
|
||||||
return statusArgumentError, true
|
return osutil.ExitCodeArgumentError, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.help {
|
if opts.help {
|
||||||
usage(cmdName, os.Stdout)
|
usage(cmdName, os.Stdout)
|
||||||
|
|
||||||
return statusSuccess, true
|
return osutil.ExitCodeSuccess, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.version {
|
if opts.version {
|
||||||
|
@ -388,7 +390,7 @@ func processOptions(
|
||||||
fmt.Printf("AdGuard Home %s\n", version.Version())
|
fmt.Printf("AdGuard Home %s\n", version.Version())
|
||||||
}
|
}
|
||||||
|
|
||||||
return statusSuccess, true
|
return osutil.ExitCodeSuccess, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.checkConfig {
|
if opts.checkConfig {
|
||||||
|
@ -396,21 +398,26 @@ func processOptions(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, _ = io.WriteString(os.Stdout, err.Error()+"\n")
|
_, _ = io.WriteString(os.Stdout, err.Error()+"\n")
|
||||||
|
|
||||||
return statusError, true
|
return osutil.ExitCodeFailure, true
|
||||||
}
|
}
|
||||||
|
|
||||||
return statusSuccess, true
|
return osutil.ExitCodeSuccess, true
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// frontendFromOpts returns the frontend to use based on the options.
|
// frontendFromOpts returns the frontend to use based on the options.
|
||||||
func frontendFromOpts(opts *options, embeddedFrontend fs.FS) (frontend fs.FS, err error) {
|
func frontendFromOpts(
|
||||||
|
ctx context.Context,
|
||||||
|
logger *slog.Logger,
|
||||||
|
opts *options,
|
||||||
|
embeddedFrontend fs.FS,
|
||||||
|
) (frontend fs.FS, err error) {
|
||||||
const frontendSubdir = "build/static"
|
const frontendSubdir = "build/static"
|
||||||
|
|
||||||
if opts.localFrontend {
|
if opts.localFrontend {
|
||||||
log.Info("warning: using local frontend files")
|
logger.WarnContext(ctx, "using local frontend files")
|
||||||
|
|
||||||
return os.DirFS(frontendSubdir), nil
|
return os.DirFS(frontendSubdir), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,18 +1,26 @@
|
||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/osutil"
|
"github.com/AdguardTeam/golibs/osutil"
|
||||||
|
"github.com/AdguardTeam/golibs/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
// signalHandler processes incoming signals and shuts services down.
|
// signalHandler processes incoming signals and shuts services down.
|
||||||
type signalHandler struct {
|
type signalHandler struct {
|
||||||
|
// logger is used for logging the operation of the signal handler.
|
||||||
|
logger *slog.Logger
|
||||||
|
|
||||||
// confMgrConf contains the configuration parameters for the configuration
|
// confMgrConf contains the configuration parameters for the configuration
|
||||||
// manager.
|
// manager.
|
||||||
confMgrConf *configmgr.Config
|
confMgrConf *configmgr.Config
|
||||||
|
@ -24,145 +32,172 @@ type signalHandler struct {
|
||||||
pidFile string
|
pidFile string
|
||||||
|
|
||||||
// services are the services that are shut down before application exiting.
|
// services are the services that are shut down before application exiting.
|
||||||
services []agh.Service
|
services []service.Interface
|
||||||
|
|
||||||
|
// shutdownTimeout is the timeout for the shutdown operation.
|
||||||
|
shutdownTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle processes OS signals.
|
// handle processes OS signals. It blocks until a termination or a
|
||||||
func (h *signalHandler) handle() {
|
// reconfiguration signal is received, after which it either shuts down all
|
||||||
defer log.OnPanic("signalHandler.handle")
|
// services or reconfigures them. ctx is used for logging and serves as the
|
||||||
|
// base for the shutdown timeout. status is [osutil.ExitCodeSuccess] on success
|
||||||
|
// and [osutil.ExitCodeFailure] on error.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Add reconfiguration logic to golibs.
|
||||||
|
func (h *signalHandler) handle(ctx context.Context) (status osutil.ExitCode) {
|
||||||
|
defer slogutil.RecoverAndLog(ctx, h.logger)
|
||||||
|
|
||||||
h.writePID()
|
h.writePID(ctx)
|
||||||
|
|
||||||
for sig := range h.signal {
|
for sig := range h.signal {
|
||||||
log.Info("sighdlr: received signal %q", sig)
|
h.logger.InfoContext(ctx, "received", "signal", sig)
|
||||||
|
|
||||||
if aghos.IsReconfigureSignal(sig) {
|
if osutil.IsReconfigureSignal(sig) {
|
||||||
h.reconfigure()
|
err := h.reconfigure(ctx)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.ErrorContext(ctx, "reconfiguration error", slogutil.KeyError, err)
|
||||||
|
|
||||||
|
return osutil.ExitCodeFailure
|
||||||
|
}
|
||||||
} else if osutil.IsShutdownSignal(sig) {
|
} else if osutil.IsShutdownSignal(sig) {
|
||||||
status := h.shutdown()
|
status = h.shutdown(ctx)
|
||||||
h.removePID()
|
|
||||||
|
|
||||||
log.Info("sighdlr: exiting with status %d", status)
|
h.removePID(ctx)
|
||||||
|
|
||||||
os.Exit(status)
|
return status
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shouldn't happen, since h.signal is currently never closed.
|
||||||
|
panic("unexpected close of h.signal")
|
||||||
|
}
|
||||||
|
|
||||||
|
// writePID writes the PID to the file, if needed. Any errors are reported to
|
||||||
|
// log.
|
||||||
|
func (h *signalHandler) writePID(ctx context.Context) {
|
||||||
|
if h.pidFile == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pid := os.Getpid()
|
||||||
|
data := strconv.AppendInt(nil, int64(pid), 10)
|
||||||
|
data = append(data, '\n')
|
||||||
|
|
||||||
|
err := aghos.WriteFile(h.pidFile, data, 0o644)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.ErrorContext(ctx, "writing pidfile", slogutil.KeyError, err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.DebugContext(ctx, "wrote pid", "file", h.pidFile, "pid", pid)
|
||||||
}
|
}
|
||||||
|
|
||||||
// reconfigure rereads the configuration file and updates and restarts services.
|
// reconfigure rereads the configuration file and updates and restarts services.
|
||||||
func (h *signalHandler) reconfigure() {
|
func (h *signalHandler) reconfigure(ctx context.Context) (err error) {
|
||||||
log.Info("sighdlr: reconfiguring adguard home")
|
h.logger.InfoContext(ctx, "reconfiguring started")
|
||||||
|
|
||||||
status := h.shutdown()
|
status := h.shutdown(ctx)
|
||||||
if status != statusSuccess {
|
if status != osutil.ExitCodeSuccess {
|
||||||
log.Info("sighdlr: reconfiguring: exiting with status %d", status)
|
return errors.Error("shutdown failed")
|
||||||
|
|
||||||
os.Exit(status)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(a.garipov): This is a very rough way to do it. Some services can be
|
// TODO(a.garipov): This is a very rough way to do it. Some services can
|
||||||
// reconfigured without the full shutdown, and the error handling is
|
// be reconfigured without the full shutdown, and the error handling is
|
||||||
// currently not the best.
|
// currently not the best.
|
||||||
|
|
||||||
confMgr, err := newConfigMgr(h.confMgrConf)
|
var errs []error
|
||||||
check(err)
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, defaultTimeoutStart)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
confMgr, err := newConfigMgr(ctx, h.confMgrConf)
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("configuration manager: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
web := confMgr.Web()
|
web := confMgr.Web()
|
||||||
err = web.Start()
|
err = web.Start(ctx)
|
||||||
check(err)
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("starting web: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
dns := confMgr.DNS()
|
dns := confMgr.DNS()
|
||||||
err = dns.Start()
|
err = dns.Start(ctx)
|
||||||
check(err)
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("starting dns: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
h.services = []agh.Service{
|
if len(errs) > 0 {
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.services = []service.Interface{
|
||||||
dns,
|
dns,
|
||||||
web,
|
web,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("sighdlr: successfully reconfigured adguard home")
|
h.logger.InfoContext(ctx, "reconfiguring finished")
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exit status constants.
|
|
||||||
const (
|
|
||||||
statusSuccess = 0
|
|
||||||
statusError = 1
|
|
||||||
statusArgumentError = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
// shutdown gracefully shuts down all services.
|
// shutdown gracefully shuts down all services.
|
||||||
func (h *signalHandler) shutdown() (status int) {
|
func (h *signalHandler) shutdown(ctx context.Context) (status int) {
|
||||||
ctx, cancel := ctxWithDefaultTimeout()
|
ctx, cancel := context.WithTimeout(ctx, h.shutdownTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
status = statusSuccess
|
status = osutil.ExitCodeSuccess
|
||||||
|
|
||||||
log.Info("sighdlr: shutting down services")
|
h.logger.InfoContext(ctx, "shutting down")
|
||||||
for i, service := range h.services {
|
for i, svc := range h.services {
|
||||||
err := service.Shutdown(ctx)
|
err := svc.Shutdown(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("sighdlr: shutting down service at index %d: %s", i, err)
|
h.logger.ErrorContext(ctx, "shutting down service", "idx", i, slogutil.KeyError, err)
|
||||||
status = statusError
|
status = osutil.ExitCodeFailure
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return status
|
return status
|
||||||
}
|
}
|
||||||
|
|
||||||
// newSignalHandler returns a new signalHandler that shuts down svcs.
|
// newSignalHandler returns a new signalHandler that shuts down svcs. logger
|
||||||
|
// and confMgrConf must not be nil.
|
||||||
func newSignalHandler(
|
func newSignalHandler(
|
||||||
|
logger *slog.Logger,
|
||||||
confMgrConf *configmgr.Config,
|
confMgrConf *configmgr.Config,
|
||||||
pidFile string,
|
pidFile string,
|
||||||
svcs ...agh.Service,
|
svcs ...service.Interface,
|
||||||
) (h *signalHandler) {
|
) (h *signalHandler) {
|
||||||
h = &signalHandler{
|
h = &signalHandler{
|
||||||
confMgrConf: confMgrConf,
|
logger: logger,
|
||||||
signal: make(chan os.Signal, 1),
|
confMgrConf: confMgrConf,
|
||||||
pidFile: pidFile,
|
signal: make(chan os.Signal, 1),
|
||||||
services: svcs,
|
pidFile: pidFile,
|
||||||
|
services: svcs,
|
||||||
|
shutdownTimeout: defaultTimeoutShutdown,
|
||||||
}
|
}
|
||||||
|
|
||||||
notifier := osutil.DefaultSignalNotifier{}
|
notifier := osutil.DefaultSignalNotifier{}
|
||||||
osutil.NotifyShutdownSignal(notifier, h.signal)
|
osutil.NotifyShutdownSignal(notifier, h.signal)
|
||||||
aghos.NotifyReconfigureSignal(h.signal)
|
osutil.NotifyReconfigureSignal(notifier, h.signal)
|
||||||
|
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
// writePID writes the PID to the file, if needed. Any errors are reported to
|
|
||||||
// log.
|
|
||||||
func (h *signalHandler) writePID() {
|
|
||||||
if h.pidFile == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use 8, since most PIDs will fit.
|
|
||||||
data := make([]byte, 0, 8)
|
|
||||||
data = strconv.AppendInt(data, int64(os.Getpid()), 10)
|
|
||||||
data = append(data, '\n')
|
|
||||||
|
|
||||||
err := aghos.WriteFile(h.pidFile, data, 0o644)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("sighdlr: writing pidfile: %s", err)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("sighdlr: wrote pid to %q", h.pidFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
// removePID removes the PID file, if any.
|
// removePID removes the PID file, if any.
|
||||||
func (h *signalHandler) removePID() {
|
func (h *signalHandler) removePID(ctx context.Context) {
|
||||||
if h.pidFile == "" {
|
if h.pidFile == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := os.Remove(h.pidFile)
|
err := os.Remove(h.pidFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("sighdlr: removing pidfile: %s", err)
|
h.logger.ErrorContext(ctx, "removing pidfile", slogutil.KeyError, err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("sighdlr: removed pid at %q", h.pidFile)
|
h.logger.DebugContext(ctx, "removed pidfile", "file", h.pidFile)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,12 +4,11 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/container"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/timeutil"
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Configuration Structures
|
|
||||||
|
|
||||||
// config is the top-level on-disk configuration structure.
|
// config is the top-level on-disk configuration structure.
|
||||||
type config struct {
|
type config struct {
|
||||||
DNS *dnsConfig `yaml:"dns"`
|
DNS *dnsConfig `yaml:"dns"`
|
||||||
|
@ -19,35 +18,33 @@ type config struct {
|
||||||
SchemaVersion int `yaml:"schema_version"`
|
SchemaVersion int `yaml:"schema_version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const errNoConf errors.Error = "configuration not found"
|
// type check
|
||||||
|
var _ validator = (*config)(nil)
|
||||||
|
|
||||||
// validate returns an error if the configuration structure is invalid.
|
// validate implements the [validator] interface for *config.
|
||||||
func (c *config) validate() (err error) {
|
func (c *config) validate() (err error) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return errNoConf
|
return errors.ErrNoValue
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(a.garipov): Add more validations.
|
// TODO(a.garipov): Add more validations.
|
||||||
|
|
||||||
// Keep this in the same order as the fields in the config.
|
// Keep this in the same order as the fields in the config.
|
||||||
validators := []struct {
|
validators := container.KeyValues[string, validator]{{
|
||||||
validate func() (err error)
|
Key: "dns",
|
||||||
name string
|
Value: c.DNS,
|
||||||
}{{
|
|
||||||
validate: c.DNS.validate,
|
|
||||||
name: "dns",
|
|
||||||
}, {
|
}, {
|
||||||
validate: c.HTTP.validate,
|
Key: "http",
|
||||||
name: "http",
|
Value: c.HTTP,
|
||||||
}, {
|
}, {
|
||||||
validate: c.Log.validate,
|
Key: "log",
|
||||||
name: "log",
|
Value: c.Log,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, v := range validators {
|
for _, kv := range validators {
|
||||||
err = v.validate()
|
err = kv.Value.validate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s: %w", v.name, err)
|
return fmt.Errorf("%s: %w", kv.Key, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,16 +62,19 @@ type dnsConfig struct {
|
||||||
UseDNS64 bool `yaml:"use_dns64"`
|
UseDNS64 bool `yaml:"use_dns64"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate returns an error if the DNS configuration structure is invalid.
|
// type check
|
||||||
|
var _ validator = (*dnsConfig)(nil)
|
||||||
|
|
||||||
|
// validate implements the [validator] interface for *dnsConfig.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Add more validations.
|
// TODO(a.garipov): Add more validations.
|
||||||
func (c *dnsConfig) validate() (err error) {
|
func (c *dnsConfig) validate() (err error) {
|
||||||
// TODO(a.garipov): Add more validations.
|
// TODO(a.garipov): Add more validations.
|
||||||
switch {
|
switch {
|
||||||
case c == nil:
|
case c == nil:
|
||||||
return errNoConf
|
return errors.ErrNoValue
|
||||||
case c.UpstreamTimeout.Duration <= 0:
|
case c.UpstreamTimeout.Duration <= 0:
|
||||||
return newMustBePositiveError("upstream_timeout", c.UpstreamTimeout)
|
return newErrNotPositive("upstream_timeout", c.UpstreamTimeout)
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -91,15 +91,18 @@ type httpConfig struct {
|
||||||
ForceHTTPS bool `yaml:"force_https"`
|
ForceHTTPS bool `yaml:"force_https"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate returns an error if the HTTP configuration structure is invalid.
|
// type check
|
||||||
|
var _ validator = (*httpConfig)(nil)
|
||||||
|
|
||||||
|
// validate implements the [validator] interface for *httpConfig.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Add more validations.
|
// TODO(a.garipov): Add more validations.
|
||||||
func (c *httpConfig) validate() (err error) {
|
func (c *httpConfig) validate() (err error) {
|
||||||
switch {
|
switch {
|
||||||
case c == nil:
|
case c == nil:
|
||||||
return errNoConf
|
return errors.ErrNoValue
|
||||||
case c.Timeout.Duration <= 0:
|
case c.Timeout.Duration <= 0:
|
||||||
return newMustBePositiveError("timeout", c.Timeout)
|
return newErrNotPositive("timeout", c.Timeout)
|
||||||
default:
|
default:
|
||||||
return c.Pprof.validate()
|
return c.Pprof.validate()
|
||||||
}
|
}
|
||||||
|
@ -111,10 +114,13 @@ type httpPprofConfig struct {
|
||||||
Enabled bool `yaml:"enabled"`
|
Enabled bool `yaml:"enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate returns an error if the pprof configuration structure is invalid.
|
// type check
|
||||||
|
var _ validator = (*httpPprofConfig)(nil)
|
||||||
|
|
||||||
|
// validate implements the [validator] interface for *httpPprofConfig.
|
||||||
func (c *httpPprofConfig) validate() (err error) {
|
func (c *httpPprofConfig) validate() (err error) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return errNoConf
|
return errors.ErrNoValue
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -126,12 +132,15 @@ type logConfig struct {
|
||||||
Verbose bool `yaml:"verbose"`
|
Verbose bool `yaml:"verbose"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate returns an error if the HTTP configuration structure is invalid.
|
// type check
|
||||||
|
var _ validator = (*logConfig)(nil)
|
||||||
|
|
||||||
|
// validate implements the [validator] interface for *logConfig.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Add more validations.
|
// TODO(a.garipov): Add more validations.
|
||||||
func (c *logConfig) validate() (err error) {
|
func (c *logConfig) validate() (err error) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return errNoConf
|
return errors.ErrNoValue
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
|
@ -19,18 +20,22 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/timeutil"
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Configuration Manager
|
|
||||||
|
|
||||||
// Manager handles full and partial changes in the configuration, persisting
|
// Manager handles full and partial changes in the configuration, persisting
|
||||||
// them to disk if necessary.
|
// them to disk if necessary.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Support missing configs and default values.
|
// TODO(a.garipov): Support missing configs and default values.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
|
// baseLogger is used to create loggers for other entities.
|
||||||
|
baseLogger *slog.Logger
|
||||||
|
|
||||||
|
// logger is used for logging the operation of the configuration manager.
|
||||||
|
logger *slog.Logger
|
||||||
|
|
||||||
// updMu makes sure that at most one reconfiguration is performed at a time.
|
// updMu makes sure that at most one reconfiguration is performed at a time.
|
||||||
// updMu protects all fields below.
|
// updMu protects all fields below.
|
||||||
updMu *sync.RWMutex
|
updMu *sync.RWMutex
|
||||||
|
@ -57,12 +62,24 @@ func Validate(fileName string) (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't wrap the error, because it's informative enough as is.
|
err = conf.validate()
|
||||||
return conf.validate()
|
if err != nil {
|
||||||
|
return fmt.Errorf("validating config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config contains the configuration parameters for the configuration manager.
|
// Config contains the configuration parameters for the configuration manager.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
// BaseLogger is used to create loggers for other entities. It must not be
|
||||||
|
// nil.
|
||||||
|
BaseLogger *slog.Logger
|
||||||
|
|
||||||
|
// Logger is used for logging the operation of the configuration manager.
|
||||||
|
// It must not be nil.
|
||||||
|
Logger *slog.Logger
|
||||||
|
|
||||||
// Frontend is the filesystem with the frontend files.
|
// Frontend is the filesystem with the frontend files.
|
||||||
Frontend fs.FS
|
Frontend fs.FS
|
||||||
|
|
||||||
|
@ -93,9 +110,11 @@ func New(ctx context.Context, c *Config) (m *Manager, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
m = &Manager{
|
m = &Manager{
|
||||||
updMu: &sync.RWMutex{},
|
baseLogger: c.BaseLogger,
|
||||||
current: conf,
|
logger: c.Logger,
|
||||||
fileName: c.FileName,
|
updMu: &sync.RWMutex{},
|
||||||
|
current: conf,
|
||||||
|
fileName: c.FileName,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.assemble(ctx, conf, c.Frontend, c.WebAddr, c.Start)
|
err = m.assemble(ctx, conf, c.Frontend, c.WebAddr, c.Start)
|
||||||
|
@ -137,6 +156,7 @@ func (m *Manager) assemble(
|
||||||
start time.Time,
|
start time.Time,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
dnsConf := &dnssvc.Config{
|
dnsConf := &dnssvc.Config{
|
||||||
|
Logger: m.baseLogger.With(slogutil.KeyPrefix, "dnssvc"),
|
||||||
Addresses: conf.DNS.Addresses,
|
Addresses: conf.DNS.Addresses,
|
||||||
BootstrapServers: conf.DNS.BootstrapDNS,
|
BootstrapServers: conf.DNS.BootstrapDNS,
|
||||||
UpstreamServers: conf.DNS.UpstreamDNS,
|
UpstreamServers: conf.DNS.UpstreamDNS,
|
||||||
|
@ -151,6 +171,7 @@ func (m *Manager) assemble(
|
||||||
}
|
}
|
||||||
|
|
||||||
webSvcConf := &websvc.Config{
|
webSvcConf := &websvc.Config{
|
||||||
|
Logger: m.baseLogger.With(slogutil.KeyPrefix, "websvc"),
|
||||||
Pprof: &websvc.PprofConfig{
|
Pprof: &websvc.PprofConfig{
|
||||||
Port: conf.HTTP.Pprof.Port,
|
Port: conf.HTTP.Pprof.Port,
|
||||||
Enabled: conf.HTTP.Pprof.Enabled,
|
Enabled: conf.HTTP.Pprof.Enabled,
|
||||||
|
@ -176,7 +197,7 @@ func (m *Manager) assemble(
|
||||||
}
|
}
|
||||||
|
|
||||||
// write writes the current configuration to disk.
|
// write writes the current configuration to disk.
|
||||||
func (m *Manager) write() (err error) {
|
func (m *Manager) write(ctx context.Context) (err error) {
|
||||||
b, err := yaml.Marshal(m.current)
|
b, err := yaml.Marshal(m.current)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("encoding: %w", err)
|
return fmt.Errorf("encoding: %w", err)
|
||||||
|
@ -187,7 +208,7 @@ func (m *Manager) write() (err error) {
|
||||||
return fmt.Errorf("writing: %w", err)
|
return fmt.Errorf("writing: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("configmgr: written to %q", m.fileName)
|
m.logger.InfoContext(ctx, "config file written", "path", m.fileName)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -216,7 +237,7 @@ func (m *Manager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
|
||||||
|
|
||||||
m.updateCurrentDNS(c)
|
m.updateCurrentDNS(c)
|
||||||
|
|
||||||
return m.write()
|
return m.write(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateDNS recreates the DNS service. m.updMu is expected to be locked.
|
// updateDNS recreates the DNS service. m.updMu is expected to be locked.
|
||||||
|
@ -270,7 +291,7 @@ func (m *Manager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) {
|
||||||
|
|
||||||
m.updateCurrentWeb(c)
|
m.updateCurrentWeb(c)
|
||||||
|
|
||||||
return m.write()
|
return m.write(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateWeb recreates the web service. m.upd is expected to be locked.
|
// updateWeb recreates the web service. m.upd is expected to be locked.
|
||||||
|
|
|
@ -3,25 +3,29 @@ package configmgr
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/timeutil"
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
"golang.org/x/exp/constraints"
|
"golang.org/x/exp/constraints"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// validator is the interface for configuration entities that can validate
|
||||||
|
// themselves.
|
||||||
|
type validator interface {
|
||||||
|
// validate returns an error if the entity isn't valid.
|
||||||
|
validate() (err error)
|
||||||
|
}
|
||||||
|
|
||||||
// numberOrDuration is the constraint for integer types along with
|
// numberOrDuration is the constraint for integer types along with
|
||||||
// timeutil.Duration.
|
// timeutil.Duration.
|
||||||
type numberOrDuration interface {
|
type numberOrDuration interface {
|
||||||
constraints.Integer | timeutil.Duration
|
constraints.Integer | timeutil.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// newMustBePositiveError returns an error about the value that must be positive
|
// newErrNotPositive returns an error about the value that must be positive but
|
||||||
// but isn't. prop is the name of the property to mention in the error message.
|
// isn't. prop is the name of the property to mention in the error message.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Consider moving such helpers to golibs and use in AdGuardDNS
|
// TODO(a.garipov): Consider moving such helpers to golibs and use in AdGuardDNS
|
||||||
// as well.
|
// as well.
|
||||||
func newMustBePositiveError[T numberOrDuration](prop string, v T) (err error) {
|
func newErrNotPositive[T numberOrDuration](prop string, v T) (err error) {
|
||||||
if s, ok := any(v).(fmt.Stringer); ok {
|
return fmt.Errorf("%s: %w, got %v", prop, errors.ErrNotPositive, v)
|
||||||
return fmt.Errorf("%s must be positive, got %s", prop, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("%s must be positive, got %d", prop, v)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package dnssvc
|
package dnssvc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -9,6 +10,10 @@ import (
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Add timeout for incoming requests.
|
// TODO(a.garipov): Add timeout for incoming requests.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
// Logger is used for logging the operation of the web API service. It must
|
||||||
|
// not be nil.
|
||||||
|
Logger *slog.Logger
|
||||||
|
|
||||||
// Addresses are the addresses on which to serve plain DNS queries.
|
// Addresses are the addresses on which to serve plain DNS queries.
|
||||||
Addresses []netip.AddrPort
|
Addresses []netip.AddrPort
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ package dnssvc
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -28,6 +29,7 @@ import (
|
||||||
// TODO(a.garipov): Consider saving a [*proxy.Config] instance for those
|
// TODO(a.garipov): Consider saving a [*proxy.Config] instance for those
|
||||||
// fields that are only used in [New] and [Service.Config].
|
// fields that are only used in [New] and [Service.Config].
|
||||||
type Service struct {
|
type Service struct {
|
||||||
|
logger *slog.Logger
|
||||||
proxy *proxy.Proxy
|
proxy *proxy.Proxy
|
||||||
bootstraps []string
|
bootstraps []string
|
||||||
bootstrapResolvers []*upstream.UpstreamResolver
|
bootstrapResolvers []*upstream.UpstreamResolver
|
||||||
|
@ -48,6 +50,7 @@ func New(c *Config) (svc *Service, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
svc = &Service{
|
svc = &Service{
|
||||||
|
logger: c.Logger,
|
||||||
bootstraps: c.BootstrapServers,
|
bootstraps: c.BootstrapServers,
|
||||||
upstreams: c.UpstreamServers,
|
upstreams: c.UpstreamServers,
|
||||||
dns64Prefixes: c.DNS64Prefixes,
|
dns64Prefixes: c.DNS64Prefixes,
|
||||||
|
@ -68,6 +71,7 @@ func New(c *Config) (svc *Service, err error) {
|
||||||
|
|
||||||
svc.bootstrapResolvers = resolvers
|
svc.bootstrapResolvers = resolvers
|
||||||
svc.proxy, err = proxy.New(&proxy.Config{
|
svc.proxy, err = proxy.New(&proxy.Config{
|
||||||
|
Logger: svc.logger,
|
||||||
UDPListenAddr: udpAddrs(c.Addresses),
|
UDPListenAddr: udpAddrs(c.Addresses),
|
||||||
TCPListenAddr: tcpAddrs(c.Addresses),
|
TCPListenAddr: tcpAddrs(c.Addresses),
|
||||||
UpstreamConfig: &proxy.UpstreamConfig{
|
UpstreamConfig: &proxy.UpstreamConfig{
|
||||||
|
@ -153,12 +157,12 @@ func udpAddrs(addrPorts []netip.AddrPort) (udpAddrs []*net.UDPAddr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// type check
|
// type check
|
||||||
var _ agh.Service = (*Service)(nil)
|
var _ agh.ServiceWithConfig[*Config] = (*Service)(nil)
|
||||||
|
|
||||||
// Start implements the [agh.Service] interface for *Service. svc may be nil.
|
// Start implements the [agh.Service] interface for *Service. svc may be nil.
|
||||||
// After Start exits, all DNS servers have tried to start, but there is no
|
// After Start exits, all DNS servers have tried to start, but there is no
|
||||||
// guarantee that they did. Errors from the servers are written to the log.
|
// guarantee that they did. Errors from the servers are written to the log.
|
||||||
func (svc *Service) Start() (err error) {
|
func (svc *Service) Start(ctx context.Context) (err error) {
|
||||||
if svc == nil {
|
if svc == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -170,7 +174,7 @@ func (svc *Service) Start() (err error) {
|
||||||
svc.running.Store(err == nil)
|
svc.running.Store(err == nil)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return svc.proxy.Start(context.Background())
|
return svc.proxy.Start(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
||||||
|
@ -215,6 +219,7 @@ func (svc *Service) Config() (c *Config) {
|
||||||
}
|
}
|
||||||
|
|
||||||
c = &Config{
|
c = &Config{
|
||||||
|
Logger: svc.logger,
|
||||||
Addresses: addrs,
|
Addresses: addrs,
|
||||||
BootstrapServers: svc.bootstraps,
|
BootstrapServers: svc.bootstraps,
|
||||||
UpstreamServers: svc.upstreams,
|
UpstreamServers: svc.upstreams,
|
||||||
|
|
|
@ -6,16 +6,13 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
|
||||||
testutil.DiscardLogOutput(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
// testTimeout is the common timeout for tests.
|
// testTimeout is the common timeout for tests.
|
||||||
const testTimeout = 1 * time.Second
|
const testTimeout = 1 * time.Second
|
||||||
|
|
||||||
|
@ -59,6 +56,7 @@ func TestService(t *testing.T) {
|
||||||
_, _ = testutil.RequireReceive(t, upstreamStartedCh, testTimeout)
|
_, _ = testutil.RequireReceive(t, upstreamStartedCh, testTimeout)
|
||||||
|
|
||||||
c := &dnssvc.Config{
|
c := &dnssvc.Config{
|
||||||
|
Logger: slogutil.NewDiscardLogger(),
|
||||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort(listenAddr)},
|
Addresses: []netip.AddrPort{netip.MustParseAddrPort(listenAddr)},
|
||||||
BootstrapServers: []string{upstreamSrv.PacketConn.LocalAddr().String()},
|
BootstrapServers: []string{upstreamSrv.PacketConn.LocalAddr().String()},
|
||||||
UpstreamServers: []string{upstreamAddr},
|
UpstreamServers: []string{upstreamAddr},
|
||||||
|
@ -71,7 +69,7 @@ func TestService(t *testing.T) {
|
||||||
svc, err := dnssvc.New(c)
|
svc, err := dnssvc.New(c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = svc.Start()
|
err = svc.Start(testutil.ContextWithTimeout(t, testTimeout))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
gotConf := svc.Config()
|
gotConf := svc.Config()
|
||||||
|
|
|
@ -3,12 +3,17 @@ package websvc
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config is the AdGuard Home web service configuration structure.
|
// Config is the AdGuard Home web service configuration structure.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
// Logger is used for logging the operation of the web API service. It must
|
||||||
|
// not be nil.
|
||||||
|
Logger *slog.Logger
|
||||||
|
|
||||||
// Pprof is the configuration for the pprof debug API. It must not be nil.
|
// Pprof is the configuration for the pprof debug API. It must not be nil.
|
||||||
Pprof *PprofConfig
|
Pprof *PprofConfig
|
||||||
|
|
||||||
|
@ -60,17 +65,20 @@ type PprofConfig struct {
|
||||||
// finished.
|
// finished.
|
||||||
func (svc *Service) Config() (c *Config) {
|
func (svc *Service) Config() (c *Config) {
|
||||||
c = &Config{
|
c = &Config{
|
||||||
|
Logger: svc.logger,
|
||||||
Pprof: &PprofConfig{
|
Pprof: &PprofConfig{
|
||||||
Port: svc.pprofPort,
|
Port: svc.pprofPort,
|
||||||
Enabled: svc.pprof != nil,
|
Enabled: svc.pprof != nil,
|
||||||
},
|
},
|
||||||
ConfigManager: svc.confMgr,
|
ConfigManager: svc.confMgr,
|
||||||
|
Frontend: svc.frontend,
|
||||||
TLS: svc.tls,
|
TLS: svc.tls,
|
||||||
// Leave Addresses and SecureAddresses empty and get the actual
|
// Leave Addresses and SecureAddresses empty and get the actual
|
||||||
// addresses that include the :0 ones later.
|
// addresses that include the :0 ones later.
|
||||||
Start: svc.start,
|
Start: svc.start,
|
||||||
Timeout: svc.timeout,
|
OverrideAddress: svc.overrideAddr,
|
||||||
ForceHTTPS: svc.forceHTTPS,
|
Timeout: svc.timeout,
|
||||||
|
ForceHTTPS: svc.forceHTTPS,
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Addresses, c.SecureAddresses = svc.addrs()
|
c.Addresses, c.SecureAddresses = svc.addrs()
|
||||||
|
|
|
@ -11,8 +11,6 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNS Settings Handlers
|
|
||||||
|
|
||||||
// ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns
|
// ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns
|
||||||
// HTTP API.
|
// HTTP API.
|
||||||
type ReqPatchSettingsDNS struct {
|
type ReqPatchSettingsDNS struct {
|
||||||
|
@ -60,6 +58,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
|
||||||
}
|
}
|
||||||
|
|
||||||
newConf := &dnssvc.Config{
|
newConf := &dnssvc.Config{
|
||||||
|
Logger: svc.logger,
|
||||||
Addresses: req.Addresses,
|
Addresses: req.Addresses,
|
||||||
BootstrapServers: req.BootstrapServers,
|
BootstrapServers: req.BootstrapServers,
|
||||||
UpstreamServers: req.UpstreamServers,
|
UpstreamServers: req.UpstreamServers,
|
||||||
|
@ -78,7 +77,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
|
||||||
}
|
}
|
||||||
|
|
||||||
newSvc := svc.confMgr.DNS()
|
newSvc := svc.confMgr.DNS()
|
||||||
err = newSvc.Start()
|
err = newSvc.Start(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("starting new service: %w", err))
|
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("starting new service: %w", err))
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ func TestService_HandlePatchSettingsDNS(t *testing.T) {
|
||||||
confMgr := newConfigManager()
|
confMgr := newConfigManager()
|
||||||
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
|
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
|
||||||
return &aghtest.ServiceWithConfig[*dnssvc.Config]{
|
return &aghtest.ServiceWithConfig[*dnssvc.Config]{
|
||||||
OnStart: func() (err error) {
|
OnStart: func(_ context.Context) (err error) {
|
||||||
started.Store(true)
|
started.Store(true)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -52,7 +52,7 @@ func TestService_HandlePatchSettingsDNS(t *testing.T) {
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Scheme: urlutil.SchemeHTTP,
|
Scheme: urlutil.SchemeHTTP,
|
||||||
Host: addr.String(),
|
Host: addr.String(),
|
||||||
Path: websvc.PathV1SettingsDNS,
|
Path: websvc.PathPatternV1SettingsDNS,
|
||||||
}
|
}
|
||||||
|
|
||||||
req := jobj{
|
req := jobj{
|
||||||
|
|
|
@ -10,11 +10,9 @@ import (
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTP Settings Handlers
|
|
||||||
|
|
||||||
// ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http
|
// ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http
|
||||||
// HTTP API.
|
// HTTP API.
|
||||||
type ReqPatchSettingsHTTP struct {
|
type ReqPatchSettingsHTTP struct {
|
||||||
|
@ -53,6 +51,7 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
|
||||||
}
|
}
|
||||||
|
|
||||||
newConf := &Config{
|
newConf := &Config{
|
||||||
|
Logger: svc.logger,
|
||||||
Pprof: &PprofConfig{
|
Pprof: &PprofConfig{
|
||||||
Port: svc.pprofPort,
|
Port: svc.pprofPort,
|
||||||
Enabled: svc.pprof != nil,
|
Enabled: svc.pprof != nil,
|
||||||
|
@ -89,13 +88,13 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
|
||||||
// relaunch updates the web service in the configuration manager and starts it.
|
// relaunch updates the web service in the configuration manager and starts it.
|
||||||
// It is intended to be used as a goroutine.
|
// It is intended to be used as a goroutine.
|
||||||
func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, newConf *Config) {
|
func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, newConf *Config) {
|
||||||
defer log.OnPanic("websvc: relaunching")
|
defer slogutil.RecoverAndLog(ctx, svc.logger)
|
||||||
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
err := svc.confMgr.UpdateWeb(ctx, newConf)
|
err := svc.confMgr.UpdateWeb(ctx, newConf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("websvc: updating web: %s", err)
|
svc.logger.ErrorContext(ctx, "updating web", slogutil.KeyError, err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -106,18 +105,18 @@ func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, new
|
||||||
var newSvc agh.ServiceWithConfig[*Config]
|
var newSvc agh.ServiceWithConfig[*Config]
|
||||||
for newSvc = svc.confMgr.Web(); newSvc == svc; {
|
for newSvc = svc.confMgr.Web(); newSvc == svc; {
|
||||||
if time.Since(updStart) >= maxUpdDur {
|
if time.Since(updStart) >= maxUpdDur {
|
||||||
log.Error("websvc: failed to update svc after %s", maxUpdDur)
|
svc.logger.ErrorContext(ctx, "failed to update service on time", "duration", maxUpdDur)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("websvc: waiting for new websvc to be configured")
|
svc.logger.DebugContext(ctx, "waiting for new service")
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = newSvc.Start()
|
err = newSvc.Start(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("websvc: new svc failed to start with error: %s", err)
|
svc.logger.ErrorContext(ctx, "new service failed", slogutil.KeyError, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -27,14 +28,15 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
svc, err := websvc.New(&websvc.Config{
|
svc, err := websvc.New(&websvc.Config{
|
||||||
|
Logger: slogutil.NewDiscardLogger(),
|
||||||
Pprof: &websvc.PprofConfig{
|
Pprof: &websvc.PprofConfig{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
TLS: &tls.Config{
|
TLS: &tls.Config{
|
||||||
Certificates: []tls.Certificate{{}},
|
Certificates: []tls.Certificate{{}},
|
||||||
},
|
},
|
||||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
|
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
|
||||||
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
|
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
|
||||||
Timeout: 5 * time.Second,
|
Timeout: 5 * time.Second,
|
||||||
ForceHTTPS: true,
|
ForceHTTPS: true,
|
||||||
})
|
})
|
||||||
|
@ -48,7 +50,7 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) {
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Scheme: urlutil.SchemeHTTP,
|
Scheme: urlutil.SchemeHTTP,
|
||||||
Host: addr.String(),
|
Host: addr.String(),
|
||||||
Path: websvc.PathV1SettingsHTTP,
|
Path: websvc.PathPatternV1SettingsHTTP,
|
||||||
}
|
}
|
||||||
|
|
||||||
req := jobj{
|
req := jobj{
|
||||||
|
|
|
@ -2,15 +2,11 @@ package websvc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/golibs/httphdr"
|
"github.com/AdguardTeam/golibs/httphdr"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Middlewares
|
|
||||||
|
|
||||||
// jsonMw sets the content type of the response to application/json.
|
// jsonMw sets the content type of the response to application/json.
|
||||||
func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
|
func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
|
||||||
f := func(w http.ResponseWriter, r *http.Request) {
|
f := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -21,18 +17,3 @@ func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
|
||||||
|
|
||||||
return http.HandlerFunc(f)
|
return http.HandlerFunc(f)
|
||||||
}
|
}
|
||||||
|
|
||||||
// logMw logs the queries with level debug.
|
|
||||||
func logMw(h http.Handler) (wrapped http.HandlerFunc) {
|
|
||||||
f := func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
start := time.Now()
|
|
||||||
m, u := r.Method, r.RequestURI
|
|
||||||
|
|
||||||
log.Debug("websvc: %s %s started", m, u)
|
|
||||||
defer func() { log.Debug("websvc: %s %s finished in %s", m, u, time.Since(start)) }()
|
|
||||||
|
|
||||||
h.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
return http.HandlerFunc(f)
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,14 +0,0 @@
|
||||||
package websvc
|
|
||||||
|
|
||||||
// Path constants
|
|
||||||
const (
|
|
||||||
PathRoot = "/"
|
|
||||||
PathFrontend = "/*filepath"
|
|
||||||
|
|
||||||
PathHealthCheck = "/health-check"
|
|
||||||
|
|
||||||
PathV1SettingsAll = "/api/v1/settings/all"
|
|
||||||
PathV1SettingsDNS = "/api/v1/settings/dns"
|
|
||||||
PathV1SettingsHTTP = "/api/v1/settings/http"
|
|
||||||
PathV1SystemInfo = "/api/v1/system/info"
|
|
||||||
)
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
package websvc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Path pattern constants.
|
||||||
|
const (
|
||||||
|
PathPatternFrontend = "/"
|
||||||
|
PathPatternHealthCheck = "/health-check"
|
||||||
|
PathPatternV1SettingsAll = "/api/v1/settings/all"
|
||||||
|
PathPatternV1SettingsDNS = "/api/v1/settings/dns"
|
||||||
|
PathPatternV1SettingsHTTP = "/api/v1/settings/http"
|
||||||
|
PathPatternV1SystemInfo = "/api/v1/system/info"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Route pattern constants.
|
||||||
|
const (
|
||||||
|
routePatternFrontend = http.MethodGet + " " + PathPatternFrontend
|
||||||
|
routePatternGetV1SettingsAll = http.MethodGet + " " + PathPatternV1SettingsAll
|
||||||
|
routePatternGetV1SystemInfo = http.MethodGet + " " + PathPatternV1SystemInfo
|
||||||
|
routePatternHealthCheck = http.MethodGet + " " + PathPatternHealthCheck
|
||||||
|
routePatternPatchV1SettingsDNS = http.MethodPatch + " " + PathPatternV1SettingsDNS
|
||||||
|
routePatternPatchV1SettingsHTTP = http.MethodPatch + " " + PathPatternV1SettingsHTTP
|
||||||
|
)
|
||||||
|
|
||||||
|
// route registers all necessary handlers in mux.
|
||||||
|
func (svc *Service) route(mux *http.ServeMux) {
|
||||||
|
routes := []struct {
|
||||||
|
handler http.Handler
|
||||||
|
pattern string
|
||||||
|
isJSON bool
|
||||||
|
}{{
|
||||||
|
handler: httputil.HealthCheckHandler,
|
||||||
|
pattern: routePatternHealthCheck,
|
||||||
|
isJSON: false,
|
||||||
|
}, {
|
||||||
|
handler: http.FileServer(http.FS(svc.frontend)),
|
||||||
|
pattern: routePatternFrontend,
|
||||||
|
isJSON: false,
|
||||||
|
}, {
|
||||||
|
handler: http.HandlerFunc(svc.handleGetSettingsAll),
|
||||||
|
pattern: routePatternGetV1SettingsAll,
|
||||||
|
isJSON: true,
|
||||||
|
}, {
|
||||||
|
handler: http.HandlerFunc(svc.handlePatchSettingsDNS),
|
||||||
|
pattern: routePatternPatchV1SettingsDNS,
|
||||||
|
isJSON: true,
|
||||||
|
}, {
|
||||||
|
handler: http.HandlerFunc(svc.handlePatchSettingsHTTP),
|
||||||
|
pattern: routePatternPatchV1SettingsHTTP,
|
||||||
|
isJSON: true,
|
||||||
|
}, {
|
||||||
|
handler: http.HandlerFunc(svc.handleGetV1SystemInfo),
|
||||||
|
pattern: routePatternGetV1SystemInfo,
|
||||||
|
isJSON: true,
|
||||||
|
}}
|
||||||
|
|
||||||
|
logMw := httputil.NewLogMiddleware(svc.logger, slog.LevelDebug)
|
||||||
|
for _, r := range routes {
|
||||||
|
var hdlr http.Handler
|
||||||
|
if r.isJSON {
|
||||||
|
hdlr = jsonMw(r.handler)
|
||||||
|
} else {
|
||||||
|
hdlr = r.handler
|
||||||
|
}
|
||||||
|
|
||||||
|
mux.Handle(r.pattern, logMw.Wrap(hdlr))
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,156 @@
|
||||||
|
package websvc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// server contains an *http.Server as well as entities and data associated with
|
||||||
|
// it.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Join with similar structs in other projects and move to
|
||||||
|
// golibs/netutil/httputil.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Once the above standardization is complete, consider
|
||||||
|
// merging debugsvc and websvc into a single httpsvc.
|
||||||
|
type server struct {
|
||||||
|
// mu protects http, logger, tcpListener, and url.
|
||||||
|
mu *sync.Mutex
|
||||||
|
http *http.Server
|
||||||
|
logger *slog.Logger
|
||||||
|
tcpListener *net.TCPListener
|
||||||
|
url *url.URL
|
||||||
|
|
||||||
|
tlsConf *tls.Config
|
||||||
|
initialAddr netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// loggerKeyServer is the key used by [server] to identify itself.
|
||||||
|
const loggerKeyServer = "server"
|
||||||
|
|
||||||
|
// newServer returns a *server that is ready to serve HTTP queries. The TCP
|
||||||
|
// listener is not started. handler must not be nil.
|
||||||
|
func newServer(
|
||||||
|
baseLogger *slog.Logger,
|
||||||
|
initialAddr netip.AddrPort,
|
||||||
|
tlsConf *tls.Config,
|
||||||
|
handler http.Handler,
|
||||||
|
timeout time.Duration,
|
||||||
|
) (s *server) {
|
||||||
|
u := &url.URL{
|
||||||
|
Scheme: urlutil.SchemeHTTP,
|
||||||
|
Host: initialAddr.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConf != nil {
|
||||||
|
u.Scheme = urlutil.SchemeHTTPS
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := baseLogger.With(loggerKeyServer, u)
|
||||||
|
|
||||||
|
return &server{
|
||||||
|
mu: &sync.Mutex{},
|
||||||
|
http: &http.Server{
|
||||||
|
Handler: handler,
|
||||||
|
ReadTimeout: timeout,
|
||||||
|
ReadHeaderTimeout: timeout,
|
||||||
|
WriteTimeout: timeout,
|
||||||
|
IdleTimeout: timeout,
|
||||||
|
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
|
||||||
|
},
|
||||||
|
logger: logger,
|
||||||
|
url: u,
|
||||||
|
|
||||||
|
tlsConf: tlsConf,
|
||||||
|
initialAddr: initialAddr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// localAddr returns the local address of the server if the server has started
|
||||||
|
// listening; otherwise, it returns nil.
|
||||||
|
func (s *server) localAddr() (addr net.Addr) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if l := s.tcpListener; l != nil {
|
||||||
|
return l.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// serve starts s. baseLogger is used as a base logger for s. If s fails to
|
||||||
|
// serve with anything other than [http.ErrServerClosed], it causes an unhandled
|
||||||
|
// panic. It is intended to be used as a goroutine.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Improve error handling.
|
||||||
|
func (s *server) serve(ctx context.Context, baseLogger *slog.Logger) {
|
||||||
|
l, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(s.initialAddr))
|
||||||
|
if err != nil {
|
||||||
|
s.logger.ErrorContext(ctx, "listening tcp", slogutil.KeyError, err)
|
||||||
|
|
||||||
|
panic(fmt.Errorf("websvc: listening tcp: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.tcpListener = l
|
||||||
|
|
||||||
|
// Reassign the address in case the port was zero.
|
||||||
|
s.url.Host = l.Addr().String()
|
||||||
|
s.logger = baseLogger.With(loggerKeyServer, s.url)
|
||||||
|
s.http.ErrorLog = slog.NewLogLogger(s.logger.Handler(), slog.LevelError)
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.logger.InfoContext(ctx, "starting")
|
||||||
|
defer s.logger.InfoContext(ctx, "started")
|
||||||
|
|
||||||
|
err = s.http.Serve(l)
|
||||||
|
if err == nil || errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.ErrorContext(ctx, "serving", slogutil.KeyError, err)
|
||||||
|
|
||||||
|
panic(fmt.Errorf("websvc: serving: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdown shuts s down.
|
||||||
|
func (s *server) shutdown(ctx context.Context) (err error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
err = s.http.Shutdown(ctx)
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("shutting down server %s: %w", s.url, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the listener separately, as it might not have been closed if the
|
||||||
|
// context has been canceled.
|
||||||
|
//
|
||||||
|
// NOTE: The listener could remain uninitialized if [net.ListenTCP] failed
|
||||||
|
// in [s.serve].
|
||||||
|
if l := s.tcpListener; l != nil {
|
||||||
|
err = l.Close()
|
||||||
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
errs = append(errs, fmt.Errorf("closing listener for server %s: %w", s.url, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
|
@ -1,7 +1,6 @@
|
||||||
package websvc_test
|
package websvc_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
@ -13,6 +12,7 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -29,16 +29,10 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
|
||||||
BootstrapPreferIPv6: true,
|
BootstrapPreferIPv6: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
wantWeb := &websvc.HTTPAPIHTTPSettings{
|
|
||||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
|
|
||||||
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
|
|
||||||
Timeout: aghhttp.JSONDuration(5 * time.Second),
|
|
||||||
ForceHTTPS: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
confMgr := newConfigManager()
|
confMgr := newConfigManager()
|
||||||
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
|
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
|
||||||
c, err := dnssvc.New(&dnssvc.Config{
|
c, err := dnssvc.New(&dnssvc.Config{
|
||||||
|
Logger: slogutil.NewDiscardLogger(),
|
||||||
Addresses: wantDNS.Addresses,
|
Addresses: wantDNS.Addresses,
|
||||||
UpstreamServers: wantDNS.UpstreamServers,
|
UpstreamServers: wantDNS.UpstreamServers,
|
||||||
BootstrapServers: wantDNS.BootstrapServers,
|
BootstrapServers: wantDNS.BootstrapServers,
|
||||||
|
@ -50,34 +44,27 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
svc, err := websvc.New(&websvc.Config{
|
svc, addr := newTestServer(t, confMgr)
|
||||||
Pprof: &websvc.PprofConfig{
|
u := &url.URL{
|
||||||
Enabled: false,
|
Scheme: urlutil.SchemeHTTP,
|
||||||
},
|
Host: addr.String(),
|
||||||
TLS: &tls.Config{
|
Path: websvc.PathPatternV1SettingsAll,
|
||||||
Certificates: []tls.Certificate{{}},
|
}
|
||||||
},
|
|
||||||
Addresses: wantWeb.Addresses,
|
|
||||||
SecureAddresses: wantWeb.SecureAddresses,
|
|
||||||
Timeout: time.Duration(wantWeb.Timeout),
|
|
||||||
ForceHTTPS: true,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) {
|
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) {
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
_, addr := newTestServer(t, confMgr)
|
wantWeb := &websvc.HTTPAPIHTTPSettings{
|
||||||
u := &url.URL{
|
Addresses: []netip.AddrPort{addr},
|
||||||
Scheme: urlutil.SchemeHTTP,
|
SecureAddresses: nil,
|
||||||
Host: addr.String(),
|
Timeout: aghhttp.JSONDuration(testTimeout),
|
||||||
Path: websvc.PathV1SettingsAll,
|
ForceHTTPS: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
body := httpGet(t, u, http.StatusOK)
|
body := httpGet(t, u, http.StatusOK)
|
||||||
resp := &websvc.RespGetV1SettingsAll{}
|
resp := &websvc.RespGetV1SettingsAll{}
|
||||||
err = json.Unmarshal(body, resp)
|
err := json.Unmarshal(body, resp)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, wantDNS, resp.DNS)
|
assert.Equal(t, wantDNS, resp.DNS)
|
||||||
|
|
|
@ -20,7 +20,7 @@ func TestService_handleGetV1SystemInfo(t *testing.T) {
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Scheme: urlutil.SchemeHTTP,
|
Scheme: urlutil.SchemeHTTP,
|
||||||
Host: addr.String(),
|
Host: addr.String(),
|
||||||
Path: websvc.PathV1SystemInfo,
|
Path: websvc.PathPatternV1SystemInfo,
|
||||||
}
|
}
|
||||||
|
|
||||||
body := httpGet(t, u, http.StatusOK)
|
body := httpGet(t, u, http.StatusOK)
|
||||||
|
|
|
@ -1,31 +0,0 @@
|
||||||
package websvc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Wait Listener
|
|
||||||
|
|
||||||
// waitListener is a wrapper around a listener that also calls wg.Done() on the
|
|
||||||
// first call to Accept. It is useful in situations where it is important to
|
|
||||||
// catch the precise moment of the first call to Accept, for example when
|
|
||||||
// starting an HTTP server.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Move to aghnet?
|
|
||||||
type waitListener struct {
|
|
||||||
net.Listener
|
|
||||||
|
|
||||||
firstAcceptWG *sync.WaitGroup
|
|
||||||
firstAcceptOnce sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
// type check
|
|
||||||
var _ net.Listener = (*waitListener)(nil)
|
|
||||||
|
|
||||||
// Accept implements the [net.Listener] interface for *waitListener.
|
|
||||||
func (l *waitListener) Accept() (conn net.Conn, err error) {
|
|
||||||
l.firstAcceptOnce.Do(l.firstAcceptWG.Done)
|
|
||||||
|
|
||||||
return l.Listener.Accept()
|
|
||||||
}
|
|
|
@ -1,40 +0,0 @@
|
||||||
package websvc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/testutil/fakenet"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestWaitListener_Accept(t *testing.T) {
|
|
||||||
var accepted atomic.Bool
|
|
||||||
var l net.Listener = &fakenet.Listener{
|
|
||||||
OnAccept: func() (conn net.Conn, err error) {
|
|
||||||
accepted.Store(true)
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
},
|
|
||||||
OnAddr: func() (addr net.Addr) { panic("not implemented") },
|
|
||||||
OnClose: func() (err error) { panic("not implemented") },
|
|
||||||
}
|
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
wg.Add(1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
var wrapper net.Listener = &waitListener{
|
|
||||||
Listener: l,
|
|
||||||
firstAcceptWG: wg,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = wrapper.Accept()
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
assert.Eventually(t, accepted.Load, testTimeout, testTimeout/10)
|
|
||||||
}
|
|
|
@ -10,22 +10,18 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/mathutil"
|
|
||||||
"github.com/AdguardTeam/golibs/netutil/httputil"
|
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||||
httptreemux "github.com/dimfeld/httptreemux/v5"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConfigManager is the configuration manager interface.
|
// ConfigManager is the configuration manager interface.
|
||||||
|
@ -40,13 +36,14 @@ type ConfigManager interface {
|
||||||
// Service is the AdGuard Home web service. A nil *Service is a valid
|
// Service is the AdGuard Home web service. A nil *Service is a valid
|
||||||
// [agh.Service] that does nothing.
|
// [agh.Service] that does nothing.
|
||||||
type Service struct {
|
type Service struct {
|
||||||
|
logger *slog.Logger
|
||||||
confMgr ConfigManager
|
confMgr ConfigManager
|
||||||
frontend fs.FS
|
frontend fs.FS
|
||||||
tls *tls.Config
|
tls *tls.Config
|
||||||
pprof *http.Server
|
pprof *server
|
||||||
start time.Time
|
start time.Time
|
||||||
overrideAddr netip.AddrPort
|
overrideAddr netip.AddrPort
|
||||||
servers []*http.Server
|
servers []*server
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
pprofPort uint16
|
pprofPort uint16
|
||||||
forceHTTPS bool
|
forceHTTPS bool
|
||||||
|
@ -64,6 +61,7 @@ func New(c *Config) (svc *Service, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
svc = &Service{
|
svc = &Service{
|
||||||
|
logger: c.Logger,
|
||||||
confMgr: c.ConfigManager,
|
confMgr: c.ConfigManager,
|
||||||
frontend: c.Frontend,
|
frontend: c.Frontend,
|
||||||
tls: c.TLS,
|
tls: c.TLS,
|
||||||
|
@ -73,17 +71,18 @@ func New(c *Config) (svc *Service, err error) {
|
||||||
forceHTTPS: c.ForceHTTPS,
|
forceHTTPS: c.ForceHTTPS,
|
||||||
}
|
}
|
||||||
|
|
||||||
mux := newMux(svc)
|
mux := http.NewServeMux()
|
||||||
|
svc.route(mux)
|
||||||
|
|
||||||
if svc.overrideAddr != (netip.AddrPort{}) {
|
if svc.overrideAddr != (netip.AddrPort{}) {
|
||||||
svc.servers = []*http.Server{newSrv(svc.overrideAddr, nil, mux, c.Timeout)}
|
svc.servers = []*server{newServer(svc.logger, svc.overrideAddr, nil, mux, c.Timeout)}
|
||||||
} else {
|
} else {
|
||||||
for _, a := range c.Addresses {
|
for _, a := range c.Addresses {
|
||||||
svc.servers = append(svc.servers, newSrv(a, nil, mux, c.Timeout))
|
svc.servers = append(svc.servers, newServer(svc.logger, a, nil, mux, c.Timeout))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, a := range c.SecureAddresses {
|
for _, a := range c.SecureAddresses {
|
||||||
svc.servers = append(svc.servers, newSrv(a, c.TLS, mux, c.Timeout))
|
svc.servers = append(svc.servers, newServer(svc.logger, a, c.TLS, mux, c.Timeout))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,96 +111,7 @@ func (svc *Service) setupPprof(c *PprofConfig) {
|
||||||
svc.pprofPort = c.Port
|
svc.pprofPort = c.Port
|
||||||
addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port)
|
addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port)
|
||||||
|
|
||||||
// TODO(a.garipov): Consider making pprof timeout configurable.
|
svc.pprof = newServer(svc.logger, addr, nil, pprofMux, 10*time.Minute)
|
||||||
svc.pprof = newSrv(addr, nil, pprofMux, 10*time.Minute)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newSrv returns a new *http.Server with the given parameters.
|
|
||||||
func newSrv(
|
|
||||||
addr netip.AddrPort,
|
|
||||||
tlsConf *tls.Config,
|
|
||||||
h http.Handler,
|
|
||||||
timeout time.Duration,
|
|
||||||
) (srv *http.Server) {
|
|
||||||
addrStr := addr.String()
|
|
||||||
srv = &http.Server{
|
|
||||||
Addr: addrStr,
|
|
||||||
Handler: h,
|
|
||||||
TLSConfig: tlsConf,
|
|
||||||
ReadTimeout: timeout,
|
|
||||||
WriteTimeout: timeout,
|
|
||||||
IdleTimeout: timeout,
|
|
||||||
ReadHeaderTimeout: timeout,
|
|
||||||
}
|
|
||||||
|
|
||||||
if tlsConf == nil {
|
|
||||||
srv.ErrorLog = log.StdLog("websvc: plain http: "+addrStr, log.ERROR)
|
|
||||||
} else {
|
|
||||||
srv.ErrorLog = log.StdLog("websvc: https: "+addrStr, log.ERROR)
|
|
||||||
}
|
|
||||||
|
|
||||||
return srv
|
|
||||||
}
|
|
||||||
|
|
||||||
// newMux returns a new HTTP request multiplexer for the AdGuard Home web
|
|
||||||
// service.
|
|
||||||
func newMux(svc *Service) (mux *httptreemux.ContextMux) {
|
|
||||||
mux = httptreemux.NewContextMux()
|
|
||||||
|
|
||||||
routes := []struct {
|
|
||||||
handler http.HandlerFunc
|
|
||||||
method string
|
|
||||||
pattern string
|
|
||||||
isJSON bool
|
|
||||||
}{{
|
|
||||||
handler: svc.handleGetHealthCheck,
|
|
||||||
method: http.MethodGet,
|
|
||||||
pattern: PathHealthCheck,
|
|
||||||
isJSON: false,
|
|
||||||
}, {
|
|
||||||
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
|
|
||||||
method: http.MethodGet,
|
|
||||||
pattern: PathFrontend,
|
|
||||||
isJSON: false,
|
|
||||||
}, {
|
|
||||||
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
|
|
||||||
method: http.MethodGet,
|
|
||||||
pattern: PathRoot,
|
|
||||||
isJSON: false,
|
|
||||||
}, {
|
|
||||||
handler: svc.handleGetSettingsAll,
|
|
||||||
method: http.MethodGet,
|
|
||||||
pattern: PathV1SettingsAll,
|
|
||||||
isJSON: true,
|
|
||||||
}, {
|
|
||||||
handler: svc.handlePatchSettingsDNS,
|
|
||||||
method: http.MethodPatch,
|
|
||||||
pattern: PathV1SettingsDNS,
|
|
||||||
isJSON: true,
|
|
||||||
}, {
|
|
||||||
handler: svc.handlePatchSettingsHTTP,
|
|
||||||
method: http.MethodPatch,
|
|
||||||
pattern: PathV1SettingsHTTP,
|
|
||||||
isJSON: true,
|
|
||||||
}, {
|
|
||||||
handler: svc.handleGetV1SystemInfo,
|
|
||||||
method: http.MethodGet,
|
|
||||||
pattern: PathV1SystemInfo,
|
|
||||||
isJSON: true,
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, r := range routes {
|
|
||||||
var hdlr http.Handler
|
|
||||||
if r.isJSON {
|
|
||||||
hdlr = jsonMw(r.handler)
|
|
||||||
} else {
|
|
||||||
hdlr = r.handler
|
|
||||||
}
|
|
||||||
|
|
||||||
mux.Handle(r.method, r.pattern, logMw(hdlr))
|
|
||||||
}
|
|
||||||
|
|
||||||
return mux
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addrs returns all addresses on which this server serves the HTTP API. addrs
|
// addrs returns all addresses on which this server serves the HTTP API. addrs
|
||||||
|
@ -214,14 +124,12 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, srv := range svc.servers {
|
for _, srv := range svc.servers {
|
||||||
// Use MustParseAddrPort, since no errors should technically happen
|
addrPort := netutil.NetAddrToAddrPort(srv.localAddr())
|
||||||
// here, because all servers must have a valid address.
|
if addrPort == (netip.AddrPort{}) {
|
||||||
addrPort := netip.MustParseAddrPort(srv.Addr)
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// [srv.Serve] will set TLSConfig to an almost empty value, so, instead
|
if srv.tlsConf == nil {
|
||||||
// of relying only on the nilness of TLSConfig, check the length of the
|
|
||||||
// certificates field as well.
|
|
||||||
if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 {
|
|
||||||
addrs = append(addrs, addrPort)
|
addrs = append(addrs, addrPort)
|
||||||
} else {
|
} else {
|
||||||
secureAddrs = append(secureAddrs, addrPort)
|
secureAddrs = append(secureAddrs, addrPort)
|
||||||
|
@ -231,74 +139,60 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
|
||||||
return addrs, secureAddrs
|
return addrs, secureAddrs
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleGetHealthCheck is the handler for the GET /health-check HTTP API.
|
|
||||||
func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
_, _ = io.WriteString(w, "OK")
|
|
||||||
}
|
|
||||||
|
|
||||||
// type check
|
// type check
|
||||||
var _ agh.Service = (*Service)(nil)
|
var _ agh.ServiceWithConfig[*Config] = (*Service)(nil)
|
||||||
|
|
||||||
// Start implements the [agh.Service] interface for *Service. svc may be nil.
|
// Start implements the [agh.Service] interface for *Service. svc may be nil.
|
||||||
// After Start exits, all HTTP servers have tried to start, possibly failing and
|
// After Start exits, all HTTP servers have tried to start, possibly failing and
|
||||||
// writing error messages to the log.
|
// writing error messages to the log.
|
||||||
func (svc *Service) Start() (err error) {
|
//
|
||||||
|
// TODO(a.garipov): Use the context for cancelation as well.
|
||||||
|
func (svc *Service) Start(ctx context.Context) (err error) {
|
||||||
if svc == nil {
|
if svc == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
pprofEnabled := svc.pprof != nil
|
svc.logger.InfoContext(ctx, "starting")
|
||||||
srvNum := len(svc.servers) + mathutil.BoolToNumber[int](pprofEnabled)
|
defer svc.logger.InfoContext(ctx, "started")
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
wg.Add(srvNum)
|
|
||||||
for _, srv := range svc.servers {
|
for _, srv := range svc.servers {
|
||||||
go serve(srv, wg)
|
go srv.serve(ctx, svc.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pprofEnabled {
|
if svc.pprof != nil {
|
||||||
go serve(svc.pprof, wg)
|
go svc.pprof.serve(ctx, svc.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
return svc.wait(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait waits until either the context is canceled or all servers have started.
|
||||||
|
func (svc *Service) wait(ctx context.Context) (err error) {
|
||||||
|
for !svc.serversHaveStarted() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
// Wait and let the other goroutines do their job.
|
||||||
|
runtime.Gosched()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// serve starts and runs srv and writes all errors into its log.
|
// serversHaveStarted returns true if all servers have started serving.
|
||||||
func serve(srv *http.Server, wg *sync.WaitGroup) {
|
func (svc *Service) serversHaveStarted() (started bool) {
|
||||||
addr := srv.Addr
|
started = len(svc.servers) != 0
|
||||||
defer log.OnPanic(addr)
|
for _, srv := range svc.servers {
|
||||||
|
started = started && srv.localAddr() != nil
|
||||||
var proto string
|
|
||||||
var l net.Listener
|
|
||||||
var err error
|
|
||||||
if srv.TLSConfig == nil {
|
|
||||||
proto = "http"
|
|
||||||
l, err = net.Listen("tcp", addr)
|
|
||||||
} else {
|
|
||||||
proto = "https"
|
|
||||||
l, err = tls.Listen("tcp", addr, srv.TLSConfig)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
srv.ErrorLog.Printf("starting srv %s: binding: %s", addr, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the server's address in case the address had the port zero, which
|
if svc.pprof != nil {
|
||||||
// would mean that a random available port was automatically chosen.
|
started = started && svc.pprof.localAddr() != nil
|
||||||
srv.Addr = l.Addr().String()
|
|
||||||
|
|
||||||
log.Info("websvc: starting srv %s://%s", proto, srv.Addr)
|
|
||||||
|
|
||||||
l = &waitListener{
|
|
||||||
Listener: l,
|
|
||||||
firstAcceptWG: wg,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = srv.Serve(l)
|
return started
|
||||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
srv.ErrorLog.Printf("starting srv %s: %s", addr, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
||||||
|
@ -308,20 +202,24 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
svc.logger.InfoContext(ctx, "shutting down")
|
||||||
|
defer svc.logger.InfoContext(ctx, "shut down")
|
||||||
|
|
||||||
defer func() { err = errors.Annotate(err, "shutting down: %w") }()
|
defer func() { err = errors.Annotate(err, "shutting down: %w") }()
|
||||||
|
|
||||||
var errs []error
|
var errs []error
|
||||||
for _, srv := range svc.servers {
|
for _, srv := range svc.servers {
|
||||||
shutdownErr := srv.Shutdown(ctx)
|
shutdownErr := srv.shutdown(ctx)
|
||||||
if shutdownErr != nil {
|
if shutdownErr != nil {
|
||||||
errs = append(errs, fmt.Errorf("srv %s: %w", srv.Addr, shutdownErr))
|
// Don't wrap the error, because it's informative enough as is.
|
||||||
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if svc.pprof != nil {
|
if svc.pprof != nil {
|
||||||
shutdownErr := svc.pprof.Shutdown(ctx)
|
shutdownErr := svc.pprof.shutdown(ctx)
|
||||||
if shutdownErr != nil {
|
if shutdownErr != nil {
|
||||||
errs = append(errs, fmt.Errorf("pprof srv %s: %w", svc.pprof.Addr, shutdownErr))
|
errs = append(errs, fmt.Errorf("pprof: %w", shutdownErr))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +0,0 @@
|
||||||
package websvc
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
// testTimeout is the common timeout for tests.
|
|
||||||
const testTimeout = 1 * time.Second
|
|
|
@ -15,6 +15,8 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil/fakefs"
|
"github.com/AdguardTeam/golibs/testutil/fakefs"
|
||||||
|
@ -22,10 +24,6 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
|
||||||
testutil.DiscardLogOutput(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
// testTimeout is the common timeout for tests.
|
// testTimeout is the common timeout for tests.
|
||||||
const testTimeout = 1 * time.Second
|
const testTimeout = 1 * time.Second
|
||||||
|
|
||||||
|
@ -81,8 +79,6 @@ func newConfigManager() (m *configManager) {
|
||||||
// newTestServer creates and starts a new web service instance as well as its
|
// newTestServer creates and starts a new web service instance as well as its
|
||||||
// sole address. It also registers a cleanup procedure, which shuts the
|
// sole address. It also registers a cleanup procedure, which shuts the
|
||||||
// instance down.
|
// instance down.
|
||||||
//
|
|
||||||
// TODO(a.garipov): Use svc or remove it.
|
|
||||||
func newTestServer(
|
func newTestServer(
|
||||||
t testing.TB,
|
t testing.TB,
|
||||||
confMgr websvc.ConfigManager,
|
confMgr websvc.ConfigManager,
|
||||||
|
@ -90,6 +86,7 @@ func newTestServer(
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
c := &websvc.Config{
|
c := &websvc.Config{
|
||||||
|
Logger: slogutil.NewDiscardLogger(),
|
||||||
Pprof: &websvc.PprofConfig{
|
Pprof: &websvc.PprofConfig{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
},
|
},
|
||||||
|
@ -108,7 +105,7 @@ func newTestServer(
|
||||||
svc, err := websvc.New(c)
|
svc, err := websvc.New(c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = svc.Start()
|
err = svc.Start(testutil.ContextWithTimeout(t, testTimeout))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||||
return svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
|
return svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
|
||||||
|
@ -184,10 +181,10 @@ func TestService_Start_getHealthCheck(t *testing.T) {
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Scheme: urlutil.SchemeHTTP,
|
Scheme: urlutil.SchemeHTTP,
|
||||||
Host: addr.String(),
|
Host: addr.String(),
|
||||||
Path: websvc.PathHealthCheck,
|
Path: websvc.PathPatternHealthCheck,
|
||||||
}
|
}
|
||||||
|
|
||||||
body := httpGet(t, u, http.StatusOK)
|
body := httpGet(t, u, http.StatusOK)
|
||||||
|
|
||||||
assert.Equal(t, []byte("OK"), body)
|
assert.Equal(t, []byte(httputil.HealthCheckHandler), body)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue