From deedc490e169b77cd4b4e7f62b59df3943ad2e8c Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Wed, 8 Nov 2023 17:51:34 +0300 Subject: [PATCH] dnsforward: fix upstream check endpoint --- internal/dnsforward/http.go | 223 +++++++++++++++++++----------------- 1 file changed, 119 insertions(+), 104 deletions(-) diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 99ced3ef..c3704a69 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -483,7 +483,7 @@ func validateUpstreamConfig(conf []string) (err error) { } for _, addr := range ups { - _, err = validateUpstream(addr, domains) + _, err = validateUpstream(addr, len(domains) > 0) if err != nil { return fmt.Errorf("validating upstream %q: %w", addr, err) } @@ -556,10 +556,10 @@ var protocols = []string{ // domain-specific and is configured to point at the default upstream server // which is validated separately. The upstream is considered domain-specific // only if domains is at least not nil. -func validateUpstream(u string, domains []string) (useDefault bool, err error) { +func validateUpstream(u string, specific bool) (useDefault bool, err error) { // The special server address '#' means that default server must be used. - if useDefault = u == "#" && domains != nil; useDefault { - return useDefault, nil + if u == "#" && specific { + return true, nil } // Check if the upstream has a valid protocol prefix. @@ -701,52 +701,8 @@ func (err domainSpecificTestError) Error() (msg string) { return fmt.Sprintf("WARNING: %s", err.error) } -// checkDNS parses line, creates DNS upstreams using opts, and checks if the -// upstreams are exchanging correctly. It returns a map where key is an -// upstream address and value is "OK", if the upstream exchanges correctly, or -// text of the error. -func (s *Server) checkDNS( - line string, - opts *upstream.Options, - check healthCheckFunc, -) (result map[string]string) { - result = map[string]string{} - upstreams, domains, err := separateUpstream(line) - if err != nil { - return nil - } - - specific := len(domains) > 0 - - for _, upstreamAddr := range upstreams { - var useDefault bool - useDefault, err = validateUpstream(upstreamAddr, domains) - if err != nil { - err = fmt.Errorf("wrong upstream format: %w", err) - result[upstreamAddr] = err.Error() - - continue - } - - if useDefault { - continue - } - - log.Debug("dnsforward: checking if upstream %q works", upstreamAddr) - - err = s.checkUpstreamAddr(upstreamAddr, specific, opts, check) - if err != nil { - result[upstreamAddr] = err.Error() - } else { - result[upstreamAddr] = "OK" - } - } - - return result -} - // checkUpstreamAddr creates the DNS upstream using opts and information from -// [s.dnsFilter.EtcHosts]. Checks if the DNS upstream exchanges correctly. It +// system hosts files. Checks if the DNS upstream exchanges correctly. It // returns an error if addr is not valid DNS upstream address or the upstream // is not exchanging correctly. func (s *Server) checkUpstreamAddr( @@ -755,18 +711,21 @@ func (s *Server) checkUpstreamAddr( opts *upstream.Options, check healthCheckFunc, ) (err error) { + useDefault, err := validateUpstream(addr, specific) + if err != nil { + return fmt.Errorf("wrong upstream format: %w", err) + } else if useDefault { + return nil + } + + log.Debug("dnsforward: checking if upstream %q works", addr) + defer func() { if err != nil && specific { err = domainSpecificTestError{error: err} } }() - opts = &upstream.Options{ - Bootstrap: opts.Bootstrap, - Timeout: opts.Timeout, - PreferIPv6: opts.PreferIPv6, - } - // dnsFilter can be nil during application update. if s.dnsFilter != nil { recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(addr)) @@ -776,16 +735,113 @@ func (s *Server) checkUpstreamAddr( sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6) } - u, err := upstream.AddressToUpstream(addr, opts) + u, err := upstream.AddressToUpstream(addr, &upstream.Options{ + Bootstrap: opts.Bootstrap, + Timeout: opts.Timeout, + PreferIPv6: opts.PreferIPv6, + }) if err != nil { return fmt.Errorf("creating upstream for %q: %w", addr, err) } - defer func() { err = errors.WithDeferred(err, u.Close()) }() return check(u) } +// checkResult is a result of checking an upstream server. +type checkResult = struct { + // status is an error message if the upstream server is not working. + status error + + // ups is the upstream server address as given in the request. It may + // appear a domain-specific upstream line if it isn't correct itself. + ups string +} + +// checkDNS parses an upstream configuration line using opts and checks if the +// specified upstreams are working using check. addWG is decremented when the +// expected number of results is added to resWG, then results are sent to resCh. +func (s *Server) checkDNS( + line string, + opts *upstream.Options, + check healthCheckFunc, + addWG *sync.WaitGroup, + resWG *sync.WaitGroup, + resCh chan checkResult, +) { + defer log.OnPanic("dnsforward: checking upstreams") + + upstreams, domains, err := separateUpstream(line) + if err != nil { + resWG.Add(1) + addWG.Done() + + resCh <- checkResult{ + ups: line, + status: fmt.Errorf("wrong upstream format: %w", err), + } + + return + } + + resWG.Add(len(upstreams)) + addWG.Done() + + specific := len(domains) > 0 + for _, ups := range upstreams { + cr := checkResult{ups: ups} + + checkErr := s.checkUpstreamAddr(ups, specific, opts, check) + if checkErr != nil { + cr.status = checkErr + } + + resCh <- cr + } +} + +// check returns the mapping of upstream addresses to their check results. +func (s *Server) check(req *upstreamJSON, opts *upstream.Options) (result map[string]string) { + req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty) + req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty) + req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty) + + result = map[string]string{} + resCh := make(chan checkResult) + resWG := &sync.WaitGroup{} + go func() { + for res := range resCh { + if res.status != nil { + result[res.ups] = res.status.Error() + } else { + result[res.ups] = "OK" + } + resWG.Done() + } + }() + + // addWG is used to wait for all goroutines to count the expected number of + // results and to add it to resWG. + addWG := &sync.WaitGroup{} + for _, ups := range req.Upstreams { + go s.checkDNS(ups, opts, checkDNSUpstreamExc, addWG, resWG, resCh) + addWG.Add(1) + } + for _, ups := range req.FallbackDNS { + go s.checkDNS(ups, opts, checkDNSUpstreamExc, addWG, resWG, resCh) + addWG.Add(1) + } + for _, ups := range req.PrivateUpstreams { + go s.checkDNS(ups, opts, checkPrivateUpstreamExc, addWG, resWG, resCh) + addWG.Add(1) + } + + addWG.Wait() + resWG.Wait() + + return result +} + // handleTestUpstreamDNS handles requests to the POST /control/test_upstream_dns // endpoint. func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { @@ -797,59 +853,18 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { return } - req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty) - req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty) - req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty) + bootstrapAddrs := stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty) + if len(bootstrapAddrs) == 0 { + bootstrapAddrs = defaultBootstrap + } opts := &upstream.Options{ - Bootstrap: req.BootstrapDNS, + Bootstrap: bootstrapAddrs, Timeout: s.conf.UpstreamTimeout, PreferIPv6: s.conf.BootstrapPreferIPv6, } - if len(opts.Bootstrap) == 0 { - opts.Bootstrap = defaultBootstrap - } - wg := &sync.WaitGroup{} - m := &sync.Map{} - - // TODO(s.chzhen): Separate to a different structure/file. - worker := func(upstreamLine string, check healthCheckFunc) { - defer log.OnPanic("dnsforward: checking upstreams") - - res := s.checkDNS(upstreamLine, opts, check) - for ups, status := range res { - m.Store(ups, status) - } - - wg.Done() - } - - wg.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams)) - - for _, ups := range req.Upstreams { - go worker(ups, checkDNSUpstreamExc) - } - for _, ups := range req.FallbackDNS { - go worker(ups, checkDNSUpstreamExc) - } - for _, ups := range req.PrivateUpstreams { - go worker(ups, checkPrivateUpstreamExc) - } - - wg.Wait() - - result := map[string]string{} - m.Range(func(k, v any) bool { - ups := k.(string) - status := v.(string) - - result[ups] = status - - return true - }) - - aghhttp.WriteJSONResponseOK(w, r, result) + aghhttp.WriteJSONResponseOK(w, r, s.check(req, opts)) } // handleCacheClear is the handler for the POST /control/cache_clear HTTP API.