Pull request 1899: nextapi-pidfile-webaddr

Squashed commit of the following:

commit 73b97b638016dd3992376c2cd7d11b2e85b2c3a4
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Jun 29 18:43:05 2023 +0300

    next: use maybe; sync conf

commit 99e18b8fbfad11343a1e66f746085d54be7aafea
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Jun 29 18:13:13 2023 +0300

    next: add local frontend, pidfile, webaddr
This commit is contained in:
Ainar Garipov 2023-06-29 19:10:39 +03:00
parent 39f5c50acd
commit ee8eb1d8a6
7 changed files with 230 additions and 133 deletions

View File

@ -23,4 +23,5 @@ http:
secure_addresses: [] secure_addresses: []
timeout: 5s timeout: 5s
force_https: true force_https: true
verbose: true log:
verbose: true

View File

@ -16,7 +16,7 @@ import (
) )
// Main is the entry point of AdGuard Home. // Main is the entry point of AdGuard Home.
func Main(frontend fs.FS) { func Main(embeddedFrontend fs.FS) {
start := time.Now() start := time.Now()
cmdName := os.Args[0] cmdName := os.Args[0]
@ -37,7 +37,17 @@ func Main(frontend fs.FS) {
check(err) check(err)
} }
confMgr, err := newConfigMgr(opts.confFile, frontend, start) frontend, err := frontendFromOpts(opts, embeddedFrontend)
check(err)
confMgrConf := &configmgr.Config{
Frontend: frontend,
WebAddr: opts.webAddr,
Start: start,
FileName: opts.confFile,
}
confMgr, err := newConfigMgr(confMgrConf)
check(err) check(err)
web := confMgr.Web() web := confMgr.Web()
@ -49,9 +59,8 @@ func Main(frontend fs.FS) {
check(err) check(err)
sigHdlr := newSignalHandler( sigHdlr := newSignalHandler(
opts.confFile, confMgrConf,
frontend, opts.pidFile,
start,
web, web,
dns, dns,
) )
@ -71,11 +80,11 @@ func ctxWithDefaultTimeout() (ctx context.Context, cancel context.CancelFunc) {
// newConfigMgr returns a new configuration manager using defaultTimeout as the // newConfigMgr returns a new configuration manager using defaultTimeout as the
// context timeout. // context timeout.
func newConfigMgr(confFile string, frontend fs.FS, start time.Time) (m *configmgr.Manager, err error) { func newConfigMgr(c *configmgr.Config) (m *configmgr.Manager, err error) {
ctx, cancel := ctxWithDefaultTimeout() ctx, cancel := ctxWithDefaultTimeout()
defer cancel() defer cancel()
return configmgr.New(ctx, confFile, frontend, start) return configmgr.New(ctx, c)
} }
// check is a simple error-checking helper. It must only be used within Main. // check is a simple error-checking helper. It must only be used within Main.

View File

@ -1,15 +1,18 @@
package cmd package cmd
import ( import (
"encoding"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/fs"
"net/netip" "net/netip"
"os" "os"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
@ -26,8 +29,6 @@ type options struct {
logFile string logFile string
// pidFile is the path to the file where to store the PID. // pidFile is the path to the file where to store the PID.
//
// TODO(a.garipov): Use.
pidFile string pidFile string
// serviceAction is the service control action to perform: // serviceAction is the service control action to perform:
@ -50,10 +51,8 @@ type options struct {
// other configuration is read, so all relative paths are relative to it. // other configuration is read, so all relative paths are relative to it.
workDir string workDir string
// webAddrs contains the addresses on which to serve the web UI. // webAddr contains the address on which to serve the web UI.
// webAddr netip.AddrPort
// TODO(a.garipov): Use.
webAddrs []netip.AddrPort
// checkConfig, if true, instructs AdGuard Home to check the configuration // checkConfig, if true, instructs AdGuard Home to check the configuration
// file, optionally print an error message to stdout, and exit with a // file, optionally print an error message to stdout, and exit with a
@ -103,7 +102,7 @@ const (
pidFileIdx pidFileIdx
serviceActionIdx serviceActionIdx
workDirIdx workDirIdx
webAddrsIdx webAddrIdx
checkConfigIdx checkConfigIdx
disableUpdateIdx disableUpdateIdx
glinetModeIdx glinetModeIdx
@ -172,13 +171,12 @@ var commandLineOptions = []*commandLineOption{
valueType: "path", valueType: "path",
}, },
webAddrsIdx: { webAddrIdx: {
defaultValue: []netip.AddrPort(nil), defaultValue: netip.AddrPort{},
description: `Address(es) to serve the web UI on, in the host:port format. ` + description: `Address to serve the web UI on, in the host:port format.`,
`Can be used multiple times.`, long: "web-addr",
long: "web-addr", short: "",
short: "", valueType: "host:port",
valueType: "host:port",
}, },
checkConfigIdx: { checkConfigIdx: {
@ -258,7 +256,7 @@ func parseOptions(cmdName string, args []string) (opts *options, err error) {
pidFileIdx: &opts.pidFile, pidFileIdx: &opts.pidFile,
serviceActionIdx: &opts.serviceAction, serviceActionIdx: &opts.serviceAction,
workDirIdx: &opts.workDir, workDirIdx: &opts.workDir,
webAddrsIdx: &opts.webAddrs, webAddrIdx: &opts.webAddr,
checkConfigIdx: &opts.checkConfig, checkConfigIdx: &opts.checkConfig,
disableUpdateIdx: &opts.disableUpdate, disableUpdateIdx: &opts.disableUpdate,
glinetModeIdx: &opts.glinetMode, glinetModeIdx: &opts.glinetMode,
@ -291,23 +289,16 @@ func addOption(flags *flag.FlagSet, fieldPtr any, o *commandLineOption) {
if o.short != "" { if o.short != "" {
flags.StringVar(fieldPtr, o.short, o.defaultValue.(string), o.description) flags.StringVar(fieldPtr, o.short, o.defaultValue.(string), o.description)
} }
case *[]netip.AddrPort:
flags.Func(o.long, o.description, func(s string) (err error) {
addr, err := netip.ParseAddrPort(s)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
*fieldPtr = append(*fieldPtr, addr)
return nil
})
case *bool: case *bool:
flags.BoolVar(fieldPtr, o.long, o.defaultValue.(bool), o.description) flags.BoolVar(fieldPtr, o.long, o.defaultValue.(bool), o.description)
if o.short != "" { if o.short != "" {
flags.BoolVar(fieldPtr, o.short, o.defaultValue.(bool), o.description) flags.BoolVar(fieldPtr, o.short, o.defaultValue.(bool), o.description)
} }
case encoding.TextUnmarshaler:
flags.TextVar(fieldPtr, o.long, o.defaultValue.(encoding.TextMarshaler), o.description)
if o.short != "" {
flags.TextVar(fieldPtr, o.short, o.defaultValue.(encoding.TextMarshaler), o.description)
}
default: default:
panic(fmt.Errorf("unexpected field pointer type %T", fieldPtr)) panic(fmt.Errorf("unexpected field pointer type %T", fieldPtr))
} }
@ -380,13 +371,13 @@ func processOptions(
) (exitCode int, needExit bool) { ) (exitCode int, needExit bool) {
if parseErr != nil { if parseErr != nil {
// Assume that usage has already been printed. // Assume that usage has already been printed.
return 2, true return statusArgumentError, true
} }
if opts.help { if opts.help {
usage(cmdName, os.Stdout) usage(cmdName, os.Stdout)
return 0, true return statusSuccess, true
} }
if opts.version { if opts.version {
@ -396,7 +387,7 @@ func processOptions(
fmt.Printf("AdGuard Home %s\n", version.Version()) fmt.Printf("AdGuard Home %s\n", version.Version())
} }
return 0, true return statusSuccess, true
} }
if opts.checkConfig { if opts.checkConfig {
@ -404,11 +395,24 @@ func processOptions(
if err != nil { if err != nil {
_, _ = io.WriteString(os.Stdout, err.Error()+"\n") _, _ = io.WriteString(os.Stdout, err.Error()+"\n")
return 1, true return statusError, true
} }
return 0, true return statusSuccess, true
} }
return 0, false return 0, false
} }
// frontendFromOpts returns the frontend to use based on the options.
func frontendFromOpts(opts *options, embeddedFrontend fs.FS) (frontend fs.FS, err error) {
const frontendSubdir = "build/static"
if opts.localFrontend {
log.Info("warning: using local frontend files")
return os.DirFS(frontendSubdir), nil
}
return fs.Sub(embeddedFrontend, frontendSubdir)
}

View File

@ -1,29 +1,27 @@
package cmd package cmd
import ( import (
"io/fs"
"os" "os"
"time" "strconv"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/google/renameio/maybe"
) )
// signalHandler processes incoming signals and shuts services down. // signalHandler processes incoming signals and shuts services down.
type signalHandler struct { type signalHandler struct {
// confMgrConf contains the configuration parameters for the configuration
// manager.
confMgrConf *configmgr.Config
// signal is the channel to which OS signals are sent. // signal is the channel to which OS signals are sent.
signal chan os.Signal signal chan os.Signal
// confFile is the path to the configuration file. // pidFile is the path to the file where to store the PID, if any.
confFile string pidFile string
// frontend is the filesystem with the frontend and other statically
// compiled files.
frontend fs.FS
// start is the time at which AdGuard Home has been started.
start time.Time
// services are the services that are shut down before application exiting. // services are the services that are shut down before application exiting.
services []agh.Service services []agh.Service
@ -33,6 +31,8 @@ type signalHandler struct {
func (h *signalHandler) handle() { func (h *signalHandler) handle() {
defer log.OnPanic("signalHandler.handle") defer log.OnPanic("signalHandler.handle")
h.writePID()
for sig := range h.signal { for sig := range h.signal {
log.Info("sighdlr: received signal %q", sig) log.Info("sighdlr: received signal %q", sig)
@ -40,6 +40,8 @@ func (h *signalHandler) handle() {
h.reconfigure() h.reconfigure()
} else if aghos.IsShutdownSignal(sig) { } else if aghos.IsShutdownSignal(sig) {
status := h.shutdown() status := h.shutdown()
h.removePID()
log.Info("sighdlr: exiting with status %d", status) log.Info("sighdlr: exiting with status %d", status)
os.Exit(status) os.Exit(status)
@ -62,7 +64,7 @@ func (h *signalHandler) reconfigure() {
// reconfigured without the full shutdown, and the error handling is // reconfigured without the full shutdown, and the error handling is
// currently not the best. // currently not the best.
confMgr, err := newConfigMgr(h.confFile, h.frontend, h.start) confMgr, err := newConfigMgr(h.confMgrConf)
check(err) check(err)
web := confMgr.Web() web := confMgr.Web()
@ -83,8 +85,9 @@ func (h *signalHandler) reconfigure() {
// Exit status constants. // Exit status constants.
const ( const (
statusSuccess = 0 statusSuccess = 0
statusError = 1 statusError = 1
statusArgumentError = 2
) )
// shutdown gracefully shuts down all services. // shutdown gracefully shuts down all services.
@ -108,17 +111,15 @@ func (h *signalHandler) shutdown() (status int) {
// newSignalHandler returns a new signalHandler that shuts down svcs. // newSignalHandler returns a new signalHandler that shuts down svcs.
func newSignalHandler( func newSignalHandler(
confFile string, confMgrConf *configmgr.Config,
frontend fs.FS, pidFile string,
start time.Time,
svcs ...agh.Service, svcs ...agh.Service,
) (h *signalHandler) { ) (h *signalHandler) {
h = &signalHandler{ h = &signalHandler{
signal: make(chan os.Signal, 1), confMgrConf: confMgrConf,
confFile: confFile, signal: make(chan os.Signal, 1),
frontend: frontend, pidFile: pidFile,
start: start, services: svcs,
services: svcs,
} }
aghos.NotifyShutdownSignal(h.signal) aghos.NotifyShutdownSignal(h.signal)
@ -126,3 +127,41 @@ func newSignalHandler(
return h return h
} }
// writePID writes the PID to the file, if needed. Any errors are reported to
// log.
func (h *signalHandler) writePID() {
if h.pidFile == "" {
return
}
// Use 8, since most PIDs will fit.
data := make([]byte, 0, 8)
data = strconv.AppendInt(data, int64(os.Getpid()), 10)
data = append(data, '\n')
err := maybe.WriteFile(h.pidFile, data, 0o644)
if err != nil {
log.Error("sighdlr: writing pidfile: %s", err)
return
}
log.Debug("sighdlr: wrote pid to %q", h.pidFile)
}
// removePID removes the PID file, if any.
func (h *signalHandler) removePID() {
if h.pidFile == "" {
return
}
err := os.Remove(h.pidFile)
if err != nil {
log.Error("sighdlr: removing pidfile: %s", err)
return
}
log.Debug("sighdlr: removed pid at %q", h.pidFile)
}

View File

@ -14,11 +14,11 @@ import (
type config struct { type config struct {
DNS *dnsConfig `yaml:"dns"` DNS *dnsConfig `yaml:"dns"`
HTTP *httpConfig `yaml:"http"` HTTP *httpConfig `yaml:"http"`
Log *logConfig `yaml:"log"`
// TODO(a.garipov): Use. // TODO(a.garipov): Use.
SchemaVersion int `yaml:"schema_version"` SchemaVersion int `yaml:"schema_version"`
// TODO(a.garipov): Use. // TODO(a.garipov): Use.
DebugPprof bool `yaml:"debug_pprof"` DebugPprof bool `yaml:"debug_pprof"`
Verbose bool `yaml:"verbose"`
} }
const errNoConf errors.Error = "configuration not found" const errNoConf errors.Error = "configuration not found"
@ -41,6 +41,9 @@ func (c *config) validate() (err error) {
}, { }, {
validate: c.HTTP.validate, validate: c.HTTP.validate,
name: "http", name: "http",
}, {
validate: c.Log.validate,
name: "log",
}} }}
for _, v := range validators { for _, v := range validators {
@ -54,8 +57,6 @@ func (c *config) validate() (err error) {
} }
// dnsConfig is the on-disk DNS configuration. // dnsConfig is the on-disk DNS configuration.
//
// TODO(a.garipov): Validate.
type dnsConfig struct { type dnsConfig struct {
Addresses []netip.AddrPort `yaml:"addresses"` Addresses []netip.AddrPort `yaml:"addresses"`
BootstrapDNS []string `yaml:"bootstrap_dns"` BootstrapDNS []string `yaml:"bootstrap_dns"`
@ -82,9 +83,8 @@ func (c *dnsConfig) validate() (err error) {
} }
// httpConfig is the on-disk web API configuration. // httpConfig is the on-disk web API configuration.
//
// TODO(a.garipov): Validate.
type httpConfig struct { type httpConfig struct {
// TODO(a.garipov): Document the configuration change.
Addresses []netip.AddrPort `yaml:"addresses"` Addresses []netip.AddrPort `yaml:"addresses"`
SecureAddresses []netip.AddrPort `yaml:"secure_addresses"` SecureAddresses []netip.AddrPort `yaml:"secure_addresses"`
Timeout timeutil.Duration `yaml:"timeout"` Timeout timeutil.Duration `yaml:"timeout"`
@ -104,3 +104,20 @@ func (c *httpConfig) validate() (err error) {
return nil return nil
} }
} }
// logConfig is the on-disk web API configuration.
type logConfig struct {
// TODO(a.garipov): Use.
Verbose bool `yaml:"verbose"`
}
// validate returns an error if the HTTP configuration structure is invalid.
//
// TODO(a.garipov): Add more validations.
func (c *logConfig) validate() (err error) {
if c == nil {
return errNoConf
}
return nil
}

View File

@ -8,6 +8,7 @@ import (
"context" "context"
"fmt" "fmt"
"io/fs" "io/fs"
"net/netip"
"os" "os"
"sync" "sync"
"time" "time"
@ -27,6 +28,8 @@ import (
// Manager handles full and partial changes in the configuration, persisting // Manager handles full and partial changes in the configuration, persisting
// them to disk if necessary. // them to disk if necessary.
//
// TODO(a.garipov): Support missing configs and default values.
type Manager struct { type Manager struct {
// updMu makes sure that at most one reconfiguration is performed at a time. // updMu makes sure that at most one reconfiguration is performed at a time.
// updMu protects all fields below. // updMu protects all fields below.
@ -58,16 +61,27 @@ func Validate(fileName string) (err error) {
return conf.validate() return conf.validate()
} }
// Config contains the configuration parameters for the configuration manager.
type Config struct {
// Frontend is the filesystem with the frontend files.
Frontend fs.FS
// WebAddr is the initial or override address for the Web UI. It is not
// written to the configuration file.
WebAddr netip.AddrPort
// Start is the time of start of AdGuard Home.
Start time.Time
// FileName is the path to the configuration file.
FileName string
}
// New creates a new *Manager that persists changes to the file pointed to by // New creates a new *Manager that persists changes to the file pointed to by
// fileName. It reads the configuration file and populates the service fields. // c.FileName. It reads the configuration file and populates the service
// start is the startup time of AdGuard Home. // fields. c must not be nil.
func New( func New(ctx context.Context, c *Config) (m *Manager, err error) {
ctx context.Context, conf, err := read(c.FileName)
fileName string,
frontend fs.FS,
start time.Time,
) (m *Manager, err error) {
conf, err := read(fileName)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return nil, err return nil, err
@ -81,10 +95,10 @@ func New(
m = &Manager{ m = &Manager{
updMu: &sync.RWMutex{}, updMu: &sync.RWMutex{},
current: conf, current: conf,
fileName: fileName, fileName: c.FileName,
} }
err = m.assemble(ctx, conf, frontend, start) err = m.assemble(ctx, conf, c.Frontend, c.WebAddr, c.Start)
if err != nil { if err != nil {
return nil, fmt.Errorf("creating config manager: %w", err) return nil, fmt.Errorf("creating config manager: %w", err)
} }
@ -119,6 +133,7 @@ func (m *Manager) assemble(
ctx context.Context, ctx context.Context,
conf *config, conf *config,
frontend fs.FS, frontend fs.FS,
webAddr netip.AddrPort,
start time.Time, start time.Time,
) (err error) { ) (err error) {
dnsConf := &dnssvc.Config{ dnsConf := &dnssvc.Config{
@ -143,6 +158,7 @@ func (m *Manager) assemble(
Start: start, Start: start,
Addresses: conf.HTTP.Addresses, Addresses: conf.HTTP.Addresses,
SecureAddresses: conf.HTTP.SecureAddresses, SecureAddresses: conf.HTTP.SecureAddresses,
OverrideAddress: webAddr,
Timeout: conf.HTTP.Timeout.Duration, Timeout: conf.HTTP.Timeout.Duration,
ForceHTTPS: conf.HTTP.ForceHTTPS, ForceHTTPS: conf.HTTP.ForceHTTPS,
} }
@ -162,7 +178,7 @@ func (m *Manager) write() (err error) {
return fmt.Errorf("encoding: %w", err) return fmt.Errorf("encoding: %w", err)
} }
err = maybe.WriteFile(m.fileName, b, 0o755) err = maybe.WriteFile(m.fileName, b, 0o644)
if err != nil { if err != nil {
return fmt.Errorf("writing: %w", err) return fmt.Errorf("writing: %w", err)
} }

View File

@ -51,6 +51,10 @@ type Config struct {
// Start is the time of start of AdGuard Home. // Start is the time of start of AdGuard Home.
Start time.Time Start time.Time
// OverrideAddress is the initial or override address for the HTTP API. If
// set, it is used instead of [Addresses] and [SecureAddresses].
OverrideAddress netip.AddrPort
// Addresses are the addresses on which to serve the plain HTTP API. // Addresses are the addresses on which to serve the plain HTTP API.
Addresses []netip.AddrPort Addresses []netip.AddrPort
@ -71,13 +75,14 @@ type Config struct {
// Service is the AdGuard Home web service. A nil *Service is a valid // Service is the AdGuard Home web service. A nil *Service is a valid
// [agh.Service] that does nothing. // [agh.Service] that does nothing.
type Service struct { type Service struct {
confMgr ConfigManager confMgr ConfigManager
frontend fs.FS frontend fs.FS
tls *tls.Config tls *tls.Config
start time.Time start time.Time
servers []*http.Server overrideAddr netip.AddrPort
timeout time.Duration servers []*http.Server
forceHTTPS bool timeout time.Duration
forceHTTPS bool
} }
// New returns a new properly initialized *Service. If c is nil, svc is a nil // New returns a new properly initialized *Service. If c is nil, svc is a nil
@ -91,54 +96,60 @@ func New(c *Config) (svc *Service, err error) {
return nil, nil return nil, nil
} }
frontend, err := fs.Sub(c.Frontend, "build/static")
if err != nil {
return nil, fmt.Errorf("frontend fs: %w", err)
}
svc = &Service{ svc = &Service{
confMgr: c.ConfigManager, confMgr: c.ConfigManager,
frontend: frontend, frontend: c.Frontend,
tls: c.TLS, tls: c.TLS,
start: c.Start, start: c.Start,
timeout: c.Timeout, overrideAddr: c.OverrideAddress,
forceHTTPS: c.ForceHTTPS, timeout: c.Timeout,
forceHTTPS: c.ForceHTTPS,
} }
mux := newMux(svc) mux := newMux(svc)
for _, a := range c.Addresses { if svc.overrideAddr != (netip.AddrPort{}) {
addr := a.String() svc.servers = []*http.Server{newSrv(svc.overrideAddr, nil, mux, c.Timeout)}
errLog := log.StdLog("websvc: plain http: "+addr, log.ERROR) } else {
svc.servers = append(svc.servers, &http.Server{ for _, a := range c.Addresses {
Addr: addr, svc.servers = append(svc.servers, newSrv(a, nil, mux, c.Timeout))
Handler: mux, }
ErrorLog: errLog,
ReadTimeout: c.Timeout,
WriteTimeout: c.Timeout,
IdleTimeout: c.Timeout,
ReadHeaderTimeout: c.Timeout,
})
}
for _, a := range c.SecureAddresses { for _, a := range c.SecureAddresses {
addr := a.String() svc.servers = append(svc.servers, newSrv(a, c.TLS, mux, c.Timeout))
errLog := log.StdLog("websvc: https: "+addr, log.ERROR) }
svc.servers = append(svc.servers, &http.Server{
Addr: addr,
Handler: mux,
TLSConfig: c.TLS,
ErrorLog: errLog,
ReadTimeout: c.Timeout,
WriteTimeout: c.Timeout,
IdleTimeout: c.Timeout,
ReadHeaderTimeout: c.Timeout,
})
} }
return svc, nil return svc, nil
} }
// newSrv returns a new *http.Server with the given parameters.
func newSrv(
addr netip.AddrPort,
tlsConf *tls.Config,
h http.Handler,
timeout time.Duration,
) (srv *http.Server) {
addrStr := addr.String()
srv = &http.Server{
Addr: addrStr,
Handler: h,
TLSConfig: tlsConf,
ReadTimeout: timeout,
WriteTimeout: timeout,
IdleTimeout: timeout,
ReadHeaderTimeout: timeout,
}
if tlsConf == nil {
srv.ErrorLog = log.StdLog("websvc: plain http: "+addrStr, log.ERROR)
} else {
srv.ErrorLog = log.StdLog("websvc: https: "+addrStr, log.ERROR)
}
return srv
}
// newMux returns a new HTTP request multiplexer for the AdGuard Home web // newMux returns a new HTTP request multiplexer for the AdGuard Home web
// service. // service.
func newMux(svc *Service) (mux *httptreemux.ContextMux) { func newMux(svc *Service) (mux *httptreemux.ContextMux) {
@ -205,23 +216,23 @@ func newMux(svc *Service) (mux *httptreemux.ContextMux) {
// ":0" addresses, addrs will not return the actual bound ports until Start is // ":0" addresses, addrs will not return the actual bound ports until Start is
// finished. // finished.
func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) { func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
for _, srv := range svc.servers { if svc.overrideAddr != (netip.AddrPort{}) {
addrPort, err := netip.ParseAddrPort(srv.Addr) return []netip.AddrPort{svc.overrideAddr}, nil
if err != nil { }
// Technically shouldn't happen, since all servers must have a valid
// address.
panic(fmt.Errorf("websvc: server %q: bad address: %w", srv.Addr, err))
}
// srv.Serve will set TLSConfig to an almost empty value, so, instead of for _, srv := range svc.servers {
// relying only on the nilness of TLSConfig, check the length of the // Use MustParseAddrPort, since no errors should technically happen
// here, because all servers must have a valid address.
addrPort := netip.MustParseAddrPort(srv.Addr)
// [srv.Serve] will set TLSConfig to an almost empty value, so, instead
// of relying only on the nilness of TLSConfig, check the length of the
// certificates field as well. // certificates field as well.
if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 { if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 {
addrs = append(addrs, addrPort) addrs = append(addrs, addrPort)
} else { } else {
secureAddrs = append(secureAddrs, addrPort) secureAddrs = append(secureAddrs, addrPort)
} }
} }
return addrs, secureAddrs return addrs, secureAddrs