Pull request: 2552 rm context.TODO() instances

Merge in DNS/adguard-home from 2552-context to master

Closes #2552.

Squashed commit of the following:

commit 3d1cef33da529f4611869c4a0f2f294a3c8afcaf
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 26 19:28:23 2021 +0300

    all: fix docs

commit d08c78cf4b96419b928e73c497768f40c9e47bc2
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 26 19:22:00 2021 +0300

    all: doc changes

commit c2814f4d0025be74f38299e7e66e7c0193b6c15f
Merge: 100a1a09 44c7221a
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 26 19:12:55 2021 +0300

    Merge branch 'master' into 2552-context

commit 100a1a0957bc22bfaccb1693e6b9b1c5cb53ed13
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 26 19:10:03 2021 +0300

    home: imp docs, fix naming

commit 22717abe6c0e4c1016a53ff2fac1689d0762c462
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 26 18:14:07 2021 +0300

    home: improve code quality

commit 5c96f77a2b315e2c1ad4a11cc7a64f61bdba52a3
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Jan 25 20:28:51 2021 +0300

    home: add docs

commit 323fc013a57a5c06ec391003133b12f4eb2721cd
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Jan 25 14:50:11 2021 +0300

    home: rm context.TODO() instances
This commit is contained in:
Eugene Burkov 2021-01-26 19:44:19 +03:00
parent 44c7221ae9
commit c215b82004
9 changed files with 106 additions and 58 deletions

View File

@ -47,7 +47,7 @@ and this project adheres to
- Improved HTTP requests handling and timeouts ([#2343]). - Improved HTTP requests handling and timeouts ([#2343]).
- Our snap package now uses the `core20` image as its base ([#2306]). - Our snap package now uses the `core20` image as its base ([#2306]).
- New build system and various internal improvements ([#2271], [#2276], [#2297], - New build system and various internal improvements ([#2271], [#2276], [#2297],
[#2509]). [#2509], [#2552]).
[#2231]: https://github.com/AdguardTeam/AdGuardHome/issues/2231 [#2231]: https://github.com/AdguardTeam/AdGuardHome/issues/2231
[#2271]: https://github.com/AdguardTeam/AdGuardHome/issues/2271 [#2271]: https://github.com/AdguardTeam/AdGuardHome/issues/2271
@ -59,6 +59,7 @@ and this project adheres to
[#2391]: https://github.com/AdguardTeam/AdGuardHome/issues/2391 [#2391]: https://github.com/AdguardTeam/AdGuardHome/issues/2391
[#2394]: https://github.com/AdguardTeam/AdGuardHome/issues/2394 [#2394]: https://github.com/AdguardTeam/AdGuardHome/issues/2394
[#2509]: https://github.com/AdguardTeam/AdGuardHome/issues/2509 [#2509]: https://github.com/AdguardTeam/AdGuardHome/issues/2509
[#2552]: https://github.com/AdguardTeam/AdGuardHome/issues/2552
[#2589]: https://github.com/AdguardTeam/AdGuardHome/issues/2589 [#2589]: https://github.com/AdguardTeam/AdGuardHome/issues/2589
### Deprecated ### Deprecated

View File

@ -13,6 +13,7 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/AdGuardHome/internal/util"
@ -268,6 +269,9 @@ func copyInstallSettings(dst, src *configuration) {
dst.DNS.Port = src.DNS.Port dst.DNS.Port = src.DNS.Port
} }
// shutdownTimeout is the timeout for shutting HTTP server down operation.
const shutdownTimeout = 5 * time.Second
// Apply new configuration, start DNS server, restart Web server // Apply new configuration, start DNS server, restart Web server
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
newSettings := applyConfigReq{} newSettings := applyConfigReq{}
@ -320,6 +324,10 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
config.DNS.BindHost = newSettings.DNS.IP config.DNS.BindHost = newSettings.DNS.IP
config.DNS.Port = newSettings.DNS.Port config.DNS.Port = newSettings.DNS.Port
// TODO(e.burkov): StartMods() should be put in a separate goroutine at
// the moment we'll allow setting up TLS in the initial configuration or
// the configuration itself will use HTTPS protocol, because the
// underlying functions potentially restart the HTTPS server.
err = StartMods() err = StartMods()
if err != nil { if err != nil {
Context.firstRun = true Context.firstRun = true
@ -351,16 +359,22 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
f.Flush() f.Flush()
} }
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block // The Shutdown() method of (*http.Server) needs to be called in a
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely // separate goroutine, because it waits until all requests are handled
// and will be blocked by it's own caller.
if restartHTTP { if restartHTTP {
go func() { ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
_ = web.httpServer.Shutdown(context.TODO())
}() shut := func(srv *http.Server) {
defer cancel()
err := srv.Shutdown(ctx)
if err != nil {
log.Debug("error while shutting down HTTP server: %s", err)
}
}
go shut(web.httpServer)
if web.httpServerBeta != nil { if web.httpServerBeta != nil {
go func() { go shut(web.httpServerBeta)
_ = web.httpServerBeta.Shutdown(context.TODO())
}()
} }
} }
} }

View File

@ -1,6 +1,7 @@
package home package home
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"net/http" "net/http"
@ -90,7 +91,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
} }
} }
// Perform an update procedure to the latest available version // handleUpdate performs an update to the latest available version procedure.
func handleUpdate(w http.ResponseWriter, _ *http.Request) { func handleUpdate(w http.ResponseWriter, _ *http.Request) {
if Context.updater.NewVersion() == "" { if Context.updater.NewVersion() == "" {
httpError(w, http.StatusBadRequest, "/update request isn't allowed now") httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
@ -108,7 +109,13 @@ func handleUpdate(w http.ResponseWriter, _ *http.Request) {
f.Flush() f.Flush()
} }
go finishUpdate() // The background context is used because the underlying functions wrap
// it with timeout and shut down the server, which handles current
// request. It also should be done in a separate goroutine due to the
// same reason.
go func() {
finishUpdate(context.Background())
}()
} }
// versionResponse is the response for /control/version.json endpoint. // versionResponse is the response for /control/version.json endpoint.
@ -140,10 +147,10 @@ func (vr *versionResponse) confirmAutoUpdate() {
} }
} }
// Complete an update procedure // finishUpdate completes an update procedure.
func finishUpdate() { func finishUpdate(ctx context.Context) {
log.Info("Stopping all tasks") log.Info("Stopping all tasks")
cleanup() cleanup(ctx)
cleanupAlways() cleanupAlways()
exeName := "AdGuardHome" exeName := "AdGuardHome"

View File

@ -108,7 +108,7 @@ func Main() {
Context.tls.Reload() Context.tls.Reload()
default: default:
cleanup() cleanup(context.Background())
cleanupAlways() cleanupAlways()
os.Exit(0) os.Exit(0)
} }
@ -334,7 +334,7 @@ func run(args options) {
select {} select {}
} }
// StartMods - initialize and start DNS after installation // StartMods initializes and starts the DNS server after installation.
func StartMods() error { func StartMods() error {
err := initDNSServer() err := initDNSServer()
if err != nil { if err != nil {
@ -501,11 +501,12 @@ func configureLogger(args options) {
} }
} }
func cleanup() { // cleanup stops and resets all the modules.
func cleanup(ctx context.Context) {
log.Info("Stopping AdGuard Home") log.Info("Stopping AdGuard Home")
if Context.web != nil { if Context.web != nil {
Context.web.Close() Context.web.Close(ctx)
Context.web = nil Context.web = nil
} }
if Context.auth != nil { if Context.auth != nil {

View File

@ -186,6 +186,6 @@ func TestHome(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
cleanup() cleanup(context.Background())
cleanupAlways() cleanupAlways()
} }

View File

@ -1,6 +1,7 @@
package home package home
import ( import (
"context"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rsa" "crypto/rsa"
@ -92,7 +93,7 @@ func (t *TLSMod) setCertFileTime() {
t.certLastMod = fi.ModTime().UTC() t.certLastMod = fi.ModTime().UTC()
} }
// Start - start the module // Start updates the configuration of TLSMod and starts it.
func (t *TLSMod) Start() { func (t *TLSMod) Start() {
if !tlsWebHandlersRegistered { if !tlsWebHandlersRegistered {
tlsWebHandlersRegistered = true tlsWebHandlersRegistered = true
@ -102,10 +103,14 @@ func (t *TLSMod) Start() {
t.confLock.Lock() t.confLock.Lock()
tlsConf := t.conf tlsConf := t.conf
t.confLock.Unlock() t.confLock.Unlock()
Context.web.TLSConfigChanged(tlsConf)
// The background context is used because the TLSConfigChanged wraps
// context with timeout on its own and shuts down the server, which
// handles current request.
Context.web.TLSConfigChanged(context.Background(), tlsConf)
} }
// Reload - reload certificate file // Reload updates the configuration of TLSMod and restarts it.
func (t *TLSMod) Reload() { func (t *TLSMod) Reload() {
t.confLock.Lock() t.confLock.Lock()
tlsConf := t.conf tlsConf := t.conf
@ -139,7 +144,10 @@ func (t *TLSMod) Reload() {
t.confLock.Lock() t.confLock.Lock()
tlsConf = t.conf tlsConf = t.conf
t.confLock.Unlock() t.confLock.Unlock()
Context.web.TLSConfigChanged(tlsConf) // The background context is used because the TLSConfigChanged wraps
// context with timeout on its own and shuts down the server, which
// handles current request.
Context.web.TLSConfigChanged(context.Background(), tlsConf)
} }
// Set certificate and private key data // Set certificate and private key data
@ -296,11 +304,13 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
f.Flush() f.Flush()
} }
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block // The background context is used because the TLSConfigChanged wraps
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely // context with timeout on its own and shuts down the server, which
// handles current request. It is also should be done in a separate
// goroutine due to the same reason.
if restartHTTPS { if restartHTTPS {
go func() { go func() {
Context.web.TLSConfigChanged(data) Context.web.TLSConfigChanged(context.Background(), data)
}() }()
} }
} }

View File

@ -122,8 +122,9 @@ func WebCheckPortAvailable(port int) bool {
return true return true
} }
// TLSConfigChanged - called when TLS configuration has changed // TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) { // if necessary.
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
log.Debug("Web: applying new TLS configuration") log.Debug("Web: applying new TLS configuration")
web.conf.PortHTTPS = tlsConf.PortHTTPS web.conf.PortHTTPS = tlsConf.PortHTTPS
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0) web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
@ -143,7 +144,12 @@ func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
web.httpsServer.cond.L.Lock() web.httpsServer.cond.L.Lock()
if web.httpsServer.server != nil { if web.httpsServer.server != nil {
_ = web.httpsServer.server.Shutdown(context.TODO()) ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
err = web.httpsServer.server.Shutdown(ctx)
cancel()
if err != nil {
log.Debug("error while shutting down HTTP server: %s", err)
}
} }
web.httpsServer.enabled = enabled web.httpsServer.enabled = enabled
web.httpsServer.cert = cert web.httpsServer.cert = cert
@ -198,22 +204,28 @@ func (web *Web) Start() {
} }
} }
// Close - stop HTTP server, possibly waiting for all active connections to be closed // Close gracefully shuts down the HTTP servers.
func (web *Web) Close() { func (web *Web) Close(ctx context.Context) {
log.Info("Stopping HTTP server...") log.Info("Stopping HTTP server...")
web.httpsServer.cond.L.Lock() web.httpsServer.cond.L.Lock()
web.httpsServer.shutdown = true web.httpsServer.shutdown = true
web.httpsServer.cond.L.Unlock() web.httpsServer.cond.L.Unlock()
if web.httpsServer.server != nil {
_ = web.httpsServer.server.Shutdown(context.TODO()) shut := func(srv *http.Server) {
if srv == nil {
return
} }
if web.httpServer != nil { ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
_ = web.httpServer.Shutdown(context.TODO()) defer cancel()
if err := srv.Shutdown(ctx); err != nil {
log.Debug("error while shutting down HTTP server: %s", err)
} }
if web.httpServerBeta != nil {
_ = web.httpServerBeta.Shutdown(context.TODO())
} }
shut(web.httpsServer.server)
shut(web.httpServer)
shut(web.httpServerBeta)
log.Info("Stopped HTTP server") log.Info("Stopped HTTP server")
} }

View File

@ -35,19 +35,20 @@ type Whois struct {
ipAddrs cache.Cache ipAddrs cache.Cache
} }
// Create module context // initWhois creates the Whois module context.
func initWhois(clients *clientsContainer) *Whois { func initWhois(clients *clientsContainer) *Whois {
w := Whois{} w := Whois{
w.timeoutMsec = 5000 timeoutMsec: 5000,
w.clients = clients clients: clients,
ipAddrs: cache.New(cache.Config{
EnableLRU: true,
MaxCount: 10000,
}),
ipChan: make(chan net.IP, 255),
}
cconf := cache.Config{}
cconf.EnableLRU = true
cconf.MaxCount = 10000
w.ipAddrs = cache.New(cconf)
w.ipChan = make(chan net.IP, 255)
go w.workerLoop() go w.workerLoop()
return &w return &w
} }
@ -120,12 +121,12 @@ func whoisParse(data string) map[string]string {
const MaxConnReadSize = 64 * 1024 const MaxConnReadSize = 64 * 1024
// Send request to a server and receive the response // Send request to a server and receive the response
func (w *Whois) query(target, serverAddr string) (string, error) { func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, error) {
addr, _, _ := net.SplitHostPort(serverAddr) addr, _, _ := net.SplitHostPort(serverAddr)
if addr == "whois.arin.net" { if addr == "whois.arin.net" {
target = "n + " + target target = "n + " + target
} }
conn, err := customDialContext(context.TODO(), "tcp", serverAddr) conn, err := customDialContext(ctx, "tcp", serverAddr)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -153,11 +154,11 @@ func (w *Whois) query(target, serverAddr string) (string, error) {
} }
// Query WHOIS servers (handle redirects) // Query WHOIS servers (handle redirects)
func (w *Whois) queryAll(target string) (string, error) { func (w *Whois) queryAll(ctx context.Context, target string) (string, error) {
server := net.JoinHostPort(defaultServer, defaultPort) server := net.JoinHostPort(defaultServer, defaultPort)
const maxRedirects = 5 const maxRedirects = 5
for i := 0; i != maxRedirects; i++ { for i := 0; i != maxRedirects; i++ {
resp, err := w.query(target, server) resp, err := w.query(ctx, target, server)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -183,9 +184,9 @@ func (w *Whois) queryAll(target string) (string, error) {
} }
// Request WHOIS information // Request WHOIS information
func (w *Whois) process(ip net.IP) [][]string { func (w *Whois) process(ctx context.Context, ip net.IP) [][]string {
data := [][]string{} data := [][]string{}
resp, err := w.queryAll(ip.String()) resp, err := w.queryAll(ctx, ip.String())
if err != nil { if err != nil {
log.Debug("Whois: error: %s IP:%s", err, ip) log.Debug("Whois: error: %s IP:%s", err, ip)
return data return data
@ -232,12 +233,13 @@ func (w *Whois) Begin(ip net.IP) {
} }
} }
// Get IP address from channel; get WHOIS info; associate info with a client // workerLoop processes the IP addresses it got from the channel and associates
// the retrieving WHOIS info with a client.
func (w *Whois) workerLoop() { func (w *Whois) workerLoop() {
for { for {
ip := <-w.ipChan ip := <-w.ipChan
info := w.process(ip) info := w.process(context.Background(), ip)
if len(info) == 0 { if len(info) == 0 {
continue continue
} }

View File

@ -1,6 +1,7 @@
package home package home
import ( import (
"context"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
@ -19,7 +20,7 @@ func TestWhois(t *testing.T) {
assert.Nil(t, prepareTestDNSServer()) assert.Nil(t, prepareTestDNSServer())
w := Whois{timeoutMsec: 5000} w := Whois{timeoutMsec: 5000}
resp, err := w.queryAll("8.8.8.8") resp, err := w.queryAll(context.Background(), "8.8.8.8")
assert.Nil(t, err) assert.Nil(t, err)
m := whoisParse(resp) m := whoisParse(resp)
assert.Equal(t, "Google LLC", m["orgname"]) assert.Equal(t, "Google LLC", m["orgname"])