diff --git a/CHANGELOG.md b/CHANGELOG.md index 95bfb899..abd261d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ See also the [v0.107.20 GitHub milestone][ms-v0.107.20]. ### Fixed +- Slow upstream checks making the API unresponsive ([#5193]). - The TLS initialization errors preventing AdGuard Home from starting ([#5189]). Instead, AdGuard Home disables encryption and shows an error message on the encryption settings page in the UI, which was the intended previous behavior. @@ -44,6 +45,7 @@ See also the [v0.107.20 GitHub milestone][ms-v0.107.20]. [#4944]: https://github.com/AdguardTeam/AdGuardHome/issues/4944 [#5189]: https://github.com/AdguardTeam/AdGuardHome/issues/5189 [#5190]: https://github.com/AdguardTeam/AdGuardHome/issues/5190 +[#5193]: https://github.com/AdguardTeam/AdGuardHome/issues/5193 diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 1baf14b2..5668573b 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -566,6 +566,11 @@ type domainSpecificTestError struct { error } +// Error implements the [error] interface for domainSpecificTestError. +func (err domainSpecificTestError) Error() (msg string) { + return fmt.Sprintf("WARNING: %s", err.error) +} + // checkDNS checks the upstream server defined by upstreamConfigStr using // healthCheck for actually exchange messages. It uses bootstrap to resolve the // upstream's address. @@ -632,41 +637,45 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { result := map[string]string{} bootstraps := req.BootstrapDNS - timeout := s.conf.UpstreamTimeout - for _, host := range req.Upstreams { - err = checkDNS(host, bootstraps, timeout, checkDNSUpstreamExc) - if err != nil { - log.Info("%v", err) - result[host] = err.Error() - if _, ok := err.(domainSpecificTestError); ok { - result[host] = fmt.Sprintf("WARNING: %s", result[host]) - } - continue - } - - result[host] = "OK" + type upsCheckResult = struct { + res string + host string } - for _, host := range req.PrivateUpstreams { - err = checkDNS(host, bootstraps, timeout, checkPrivateUpstreamExc) - if err != nil { - log.Info("%v", err) - // TODO(e.burkov): If passed upstream have already written an error - // above, we rewriting the error for it. These cases should be - // handled properly instead. - result[host] = err.Error() - if _, ok := err.(domainSpecificTestError); ok { - result[host] = fmt.Sprintf("WARNING: %s", result[host]) - } + upsNum := len(req.Upstreams) + len(req.PrivateUpstreams) + resCh := make(chan upsCheckResult, upsNum) - continue + checkUps := func(ups string, healthCheck healthCheckFunc) { + res := upsCheckResult{ + host: ups, } + defer func() { resCh <- res }() - result[host] = "OK" + checkErr := checkDNS(ups, bootstraps, timeout, healthCheck) + if checkErr != nil { + res.res = checkErr.Error() + } else { + res.res = "OK" + } } + for _, ups := range req.Upstreams { + go checkUps(ups, checkDNSUpstreamExc) + } + for _, ups := range req.PrivateUpstreams { + go checkUps(ups, checkPrivateUpstreamExc) + } + + for i := 0; i < upsNum; i++ { + pair := <-resCh + // TODO(e.burkov): The upstreams used for both common and private + // resolving should be reported separately. + result[pair.host] = pair.res + } + close(resCh) + _ = aghhttp.WriteJSONResponse(w, r, result) } diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index 64ba21c4..5e0b8018 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -7,16 +7,20 @@ import ( "net" "net/http" "net/http/httptest" + "net/netip" + "net/url" "os" "path/filepath" "strings" "testing" + "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -392,3 +396,141 @@ func TestValidateUpstreamsPrivate(t *testing.T) { }) } } + +func newLocalUpstreamListener(t *testing.T, port int, handler dns.Handler) (real net.Addr) { + startCh := make(chan struct{}) + upsSrv := &dns.Server{ + Addr: netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(port)).String(), + Net: "tcp", + Handler: handler, + NotifyStartedFunc: func() { close(startCh) }, + } + go func() { + t := testutil.PanicT{} + + err := upsSrv.ListenAndServe() + require.NoError(t, err) + }() + <-startCh + testutil.CleanupAndRequireSuccess(t, upsSrv.Shutdown) + + return upsSrv.Listener.Addr() +} + +func TestServer_handleTestUpstreaDNS(t *testing.T) { + goodHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { + err := w.WriteMsg(new(dns.Msg).SetReply(m)) + require.NoError(testutil.PanicT{}, err) + }) + badHandler := dns.HandlerFunc(func(w dns.ResponseWriter, _ *dns.Msg) { + err := w.WriteMsg(new(dns.Msg)) + require.NoError(testutil.PanicT{}, err) + }) + + goodUps := (&url.URL{ + Scheme: "tcp", + Host: newLocalUpstreamListener(t, 0, goodHandler).String(), + }).String() + badUps := (&url.URL{ + Scheme: "tcp", + Host: newLocalUpstreamListener(t, 0, badHandler).String(), + }).String() + + const upsTimeout = 100 * time.Millisecond + + srv := createTestServer(t, &filtering.Config{}, ServerConfig{ + UDPListenAddrs: []*net.UDPAddr{{}}, + TCPListenAddrs: []*net.TCPAddr{{}}, + UpstreamTimeout: upsTimeout, + }, nil) + startDeferStop(t, srv) + + testCases := []struct { + body map[string]any + wantResp map[string]any + name string + }{{ + body: map[string]any{ + "upstream_dns": []string{goodUps}, + }, + wantResp: map[string]any{ + goodUps: "OK", + }, + name: "success", + }, { + body: map[string]any{ + "upstream_dns": []string{badUps}, + }, + wantResp: map[string]any{ + badUps: `upstream "` + badUps + `" fails to exchange: ` + + `couldn't communicate with upstream: dns: id mismatch`, + }, + name: "broken", + }, { + body: map[string]any{ + "upstream_dns": []string{goodUps, badUps}, + }, + wantResp: map[string]any{ + goodUps: "OK", + badUps: `upstream "` + badUps + `" fails to exchange: ` + + `couldn't communicate with upstream: dns: id mismatch`, + }, + name: "both", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + reqBody, err := json.Marshal(tc.body) + require.NoError(t, err) + + w := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody)) + require.NoError(t, err) + + srv.handleTestUpstreamDNS(w, r) + require.Equal(t, http.StatusOK, w.Code) + + resp := map[string]any{} + err = json.NewDecoder(w.Body).Decode(&resp) + require.NoError(t, err) + + assert.Equal(t, tc.wantResp, resp) + }) + } + + t.Run("timeout", func(t *testing.T) { + slowHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { + time.Sleep(upsTimeout * 2) + writeErr := w.WriteMsg(new(dns.Msg).SetReply(m)) + require.NoError(testutil.PanicT{}, writeErr) + }) + sleepyUps := (&url.URL{ + Scheme: "tcp", + Host: newLocalUpstreamListener(t, 0, slowHandler).String(), + }).String() + + req := map[string]any{ + "upstream_dns": []string{sleepyUps}, + } + reqBody, err := json.Marshal(req) + require.NoError(t, err) + + w := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody)) + require.NoError(t, err) + + srv.handleTestUpstreamDNS(w, r) + require.Equal(t, http.StatusOK, w.Code) + + resp := map[string]any{} + err = json.NewDecoder(w.Body).Decode(&resp) + require.NoError(t, err) + + require.Contains(t, resp, sleepyUps) + require.IsType(t, "", resp[sleepyUps]) + sleepyRes, _ := resp[sleepyUps].(string) + + // TODO(e.burkov): Improve the format of an error in dnsproxy. + assert.True(t, strings.HasSuffix(sleepyRes, "i/o timeout")) + }) +}