Pull request 1814: AG-21291-web-races

Merge in DNS/adguard-home from AG-21291-web-races to master

Squashed commit of the following:

commit 1134013f928aa5e186db3b6d0450e425cb053e9c
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Apr 11 16:52:52 2023 +0300

    home: fix web api races
This commit is contained in:
Ainar Garipov 2023-04-11 17:22:51 +03:00
parent 230d7b8c17
commit 0376afb38e
5 changed files with 45 additions and 38 deletions

View File

@ -332,13 +332,17 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
return false
}
var serveHTTP3 bool
var portHTTPS int
var (
forceHTTPS bool
serveHTTP3 bool
portHTTPS int
)
func() {
config.RLock()
defer config.RUnlock()
serveHTTP3, portHTTPS = config.DNS.ServeHTTP3, config.TLS.PortHTTPS
forceHTTPS = config.TLS.ForceHTTPS && config.TLS.Enabled && config.TLS.PortHTTPS != 0
}()
respHdr := w.Header()
@ -354,10 +358,10 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
respHdr.Set(httphdr.AltSvc, altSvc)
}
if r.TLS == nil && web.forceHTTPS {
if r.TLS == nil && forceHTTPS {
hostPort := host
if port := web.conf.PortHTTPS; port != defaultPortHTTPS {
hostPort = netutil.JoinHostPort(host, port)
if portHTTPS != defaultPortHTTPS {
hostPort = netutil.JoinHostPort(host, portHTTPS)
}
httpsURL := &url.URL{

View File

@ -39,7 +39,7 @@ type getAddrsResponse struct {
}
// handleInstallGetAddresses is the handler for /install/get_addresses endpoint.
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
data := getAddrsResponse{
Version: version.Version(),
@ -167,7 +167,7 @@ func (req *checkConfReq) validateDNS(
}
// handleInstallCheckConfig handles the /check_config endpoint.
func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
req := &checkConfReq{}
err := json.NewDecoder(r.Body).Decode(req)
@ -375,7 +375,7 @@ func shutdownSrv3(srv *http3.Server) {
const PasswordMinRunes = 8
// Apply new configuration, start DNS server, restart Web server
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
req, restartHTTP, err := decodeApplyConfigReq(r.Body)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@ -503,7 +503,7 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
return req, restartHTTP, err
}
func (web *Web) registerInstallHandlers() {
func (web *webAPI) registerInstallHandlers() {
Context.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
Context.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
Context.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))

View File

@ -58,7 +58,7 @@ type homeContext struct {
dhcpServer dhcpd.Interface // DHCP module
auth *Auth // HTTP authentication module
filters *filtering.DNSFilter // DNS filtering module
web *Web // Web (HTTP, HTTPS) module
web *webAPI // Web (HTTP, HTTPS) module
tls *tlsManager // TLS module
// etcHosts contains IP-hostname mappings taken from the OS-specific hosts
@ -387,7 +387,7 @@ func checkPorts() (err error) {
return nil
}
func initWeb(opts options, clientBuildFS fs.FS) (web *Web, err error) {
func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
var clientFS fs.FS
if opts.localFrontend {
log.Info("warning: using local frontend files")
@ -414,7 +414,7 @@ func initWeb(opts options, clientBuildFS fs.FS) (web *Web, err error) {
serveHTTP3: config.DNS.ServeHTTP3,
}
web = newWeb(&webConf)
web = newWebAPI(&webConf)
if web == nil {
return nil, fmt.Errorf("initializing web: %w", err)
}
@ -533,7 +533,7 @@ func run(opts options, clientBuildFS fs.FS) {
}
}
Context.web.Start()
Context.web.start()
// wait indefinitely for other go-routines to complete their job
select {}
@ -713,7 +713,7 @@ func cleanup(ctx context.Context) {
log.Info("stopping AdGuard Home")
if Context.web != nil {
Context.web.Close(ctx)
Context.web.close(ctx)
Context.web = nil
}
if Context.auth != nil {

View File

@ -108,7 +108,7 @@ func (m *tlsManager) start() {
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
Context.web.TLSConfigChanged(context.Background(), tlsConf)
Context.web.tlsConfigChanged(context.Background(), tlsConf)
}
// reload updates the configuration and restarts t.
@ -156,7 +156,7 @@ func (m *tlsManager) reload() {
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
Context.web.TLSConfigChanged(context.Background(), tlsConf)
Context.web.tlsConfigChanged(context.Background(), tlsConf)
}
// loadTLSConf loads and validates the TLS configuration. The returned error is
@ -454,7 +454,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
// same reason.
if restartHTTPS {
go func() {
Context.web.TLSConfigChanged(context.Background(), req.tlsConfigSettings)
Context.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
}()
}
}

View File

@ -35,9 +35,8 @@ const (
type webConfig struct {
clientFS fs.FS
BindHost netip.Addr
BindPort int
PortHTTPS int
BindHost netip.Addr
BindPort int
// ReadTimeout is an option to pass to http.Server for setting an
// appropriate field.
@ -72,8 +71,8 @@ type httpsServer struct {
enabled bool
}
// Web is the web UI and API server.
type Web struct {
// webAPI is the web UI and API server.
type webAPI struct {
conf *webConfig
// TODO(a.garipov): Refactor all these servers.
@ -82,15 +81,13 @@ type Web struct {
// httpsServer is the server that handles HTTPS traffic. If it is not nil,
// [Web.http3Server] must also not be nil.
httpsServer httpsServer
forceHTTPS bool
}
// newWeb creates a new instance of the web UI and API server.
func newWeb(conf *webConfig) (w *Web) {
// newWebAPI creates a new instance of the web UI and API server.
func newWebAPI(conf *webConfig) (w *webAPI) {
log.Info("web: initializing")
w = &Web{
w = &webAPI{
conf: conf,
}
@ -125,12 +122,10 @@ func webCheckPortAvailable(port int) (ok bool) {
return aghnet.CheckPort("tcp", netip.AddrPortFrom(config.BindHost, uint16(port))) == nil
}
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
// tlsConfigChanged updates the TLS configuration and restarts the HTTPS server
// if necessary.
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
log.Debug("web: applying new tls configuration")
web.conf.PortHTTPS = tlsConf.PortHTTPS
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
enabled := tlsConf.Enabled &&
tlsConf.PortHTTPS != 0 &&
@ -161,8 +156,8 @@ func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings)
web.httpsServer.cond.L.Unlock()
}
// Start - start serving HTTP requests
func (web *Web) Start() {
// start - start serving HTTP requests
func (web *webAPI) start() {
log.Println("AdGuard Home is available at the following addresses:")
// for https, we have a separate goroutine loop
@ -203,8 +198,8 @@ func (web *Web) Start() {
}
}
// Close gracefully shuts down the HTTP servers.
func (web *Web) Close(ctx context.Context) {
// close gracefully shuts down the HTTP servers.
func (web *webAPI) close(ctx context.Context) {
log.Info("stopping http server...")
web.httpsServer.cond.L.Lock()
@ -222,7 +217,7 @@ func (web *Web) Close(ctx context.Context) {
log.Info("stopped http server")
}
func (web *Web) tlsServerLoop() {
func (web *webAPI) tlsServerLoop() {
for {
web.httpsServer.cond.L.Lock()
if web.httpsServer.inShutdown {
@ -241,7 +236,15 @@ func (web *Web) tlsServerLoop() {
web.httpsServer.cond.L.Unlock()
addr := netutil.JoinHostPort(web.conf.BindHost.String(), web.conf.PortHTTPS)
var portHTTPS int
func() {
config.RLock()
defer config.RUnlock()
portHTTPS = config.TLS.PortHTTPS
}()
addr := netutil.JoinHostPort(web.conf.BindHost.String(), portHTTPS)
web.httpsServer.server = &http.Server{
ErrorLog: log.StdLog("web: https", log.DEBUG),
Addr: addr,
@ -272,7 +275,7 @@ func (web *Web) tlsServerLoop() {
}
}
func (web *Web) mustStartHTTP3(address string) {
func (web *webAPI) mustStartHTTP3(address string) {
defer log.OnPanic("web: http3")
web.httpsServer.server3 = &http3.Server{