dnsforward: fix upstream check endpoint
This commit is contained in:
parent
f8fe9bfc8b
commit
deedc490e1
|
@ -483,7 +483,7 @@ func validateUpstreamConfig(conf []string) (err error) {
|
|||
}
|
||||
|
||||
for _, addr := range ups {
|
||||
_, err = validateUpstream(addr, domains)
|
||||
_, err = validateUpstream(addr, len(domains) > 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating upstream %q: %w", addr, err)
|
||||
}
|
||||
|
@ -556,10 +556,10 @@ var protocols = []string{
|
|||
// domain-specific and is configured to point at the default upstream server
|
||||
// which is validated separately. The upstream is considered domain-specific
|
||||
// only if domains is at least not nil.
|
||||
func validateUpstream(u string, domains []string) (useDefault bool, err error) {
|
||||
func validateUpstream(u string, specific bool) (useDefault bool, err error) {
|
||||
// The special server address '#' means that default server must be used.
|
||||
if useDefault = u == "#" && domains != nil; useDefault {
|
||||
return useDefault, nil
|
||||
if u == "#" && specific {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check if the upstream has a valid protocol prefix.
|
||||
|
@ -701,52 +701,8 @@ func (err domainSpecificTestError) Error() (msg string) {
|
|||
return fmt.Sprintf("WARNING: %s", err.error)
|
||||
}
|
||||
|
||||
// checkDNS parses line, creates DNS upstreams using opts, and checks if the
|
||||
// upstreams are exchanging correctly. It returns a map where key is an
|
||||
// upstream address and value is "OK", if the upstream exchanges correctly, or
|
||||
// text of the error.
|
||||
func (s *Server) checkDNS(
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
) (result map[string]string) {
|
||||
result = map[string]string{}
|
||||
upstreams, domains, err := separateUpstream(line)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
specific := len(domains) > 0
|
||||
|
||||
for _, upstreamAddr := range upstreams {
|
||||
var useDefault bool
|
||||
useDefault, err = validateUpstream(upstreamAddr, domains)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("wrong upstream format: %w", err)
|
||||
result[upstreamAddr] = err.Error()
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if useDefault {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
err = s.checkUpstreamAddr(upstreamAddr, specific, opts, check)
|
||||
if err != nil {
|
||||
result[upstreamAddr] = err.Error()
|
||||
} else {
|
||||
result[upstreamAddr] = "OK"
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// checkUpstreamAddr creates the DNS upstream using opts and information from
|
||||
// [s.dnsFilter.EtcHosts]. Checks if the DNS upstream exchanges correctly. It
|
||||
// system hosts files. Checks if the DNS upstream exchanges correctly. It
|
||||
// returns an error if addr is not valid DNS upstream address or the upstream
|
||||
// is not exchanging correctly.
|
||||
func (s *Server) checkUpstreamAddr(
|
||||
|
@ -755,18 +711,21 @@ func (s *Server) checkUpstreamAddr(
|
|||
opts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
) (err error) {
|
||||
useDefault, err := validateUpstream(addr, specific)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
} else if useDefault {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checking if upstream %q works", addr)
|
||||
|
||||
defer func() {
|
||||
if err != nil && specific {
|
||||
err = domainSpecificTestError{error: err}
|
||||
}
|
||||
}()
|
||||
|
||||
opts = &upstream.Options{
|
||||
Bootstrap: opts.Bootstrap,
|
||||
Timeout: opts.Timeout,
|
||||
PreferIPv6: opts.PreferIPv6,
|
||||
}
|
||||
|
||||
// dnsFilter can be nil during application update.
|
||||
if s.dnsFilter != nil {
|
||||
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(addr))
|
||||
|
@ -776,16 +735,113 @@ func (s *Server) checkUpstreamAddr(
|
|||
sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6)
|
||||
}
|
||||
|
||||
u, err := upstream.AddressToUpstream(addr, opts)
|
||||
u, err := upstream.AddressToUpstream(addr, &upstream.Options{
|
||||
Bootstrap: opts.Bootstrap,
|
||||
Timeout: opts.Timeout,
|
||||
PreferIPv6: opts.PreferIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating upstream for %q: %w", addr, err)
|
||||
}
|
||||
|
||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
||||
|
||||
return check(u)
|
||||
}
|
||||
|
||||
// checkResult is a result of checking an upstream server.
|
||||
type checkResult = struct {
|
||||
// status is an error message if the upstream server is not working.
|
||||
status error
|
||||
|
||||
// ups is the upstream server address as given in the request. It may
|
||||
// appear a domain-specific upstream line if it isn't correct itself.
|
||||
ups string
|
||||
}
|
||||
|
||||
// checkDNS parses an upstream configuration line using opts and checks if the
|
||||
// specified upstreams are working using check. addWG is decremented when the
|
||||
// expected number of results is added to resWG, then results are sent to resCh.
|
||||
func (s *Server) checkDNS(
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
addWG *sync.WaitGroup,
|
||||
resWG *sync.WaitGroup,
|
||||
resCh chan checkResult,
|
||||
) {
|
||||
defer log.OnPanic("dnsforward: checking upstreams")
|
||||
|
||||
upstreams, domains, err := separateUpstream(line)
|
||||
if err != nil {
|
||||
resWG.Add(1)
|
||||
addWG.Done()
|
||||
|
||||
resCh <- checkResult{
|
||||
ups: line,
|
||||
status: fmt.Errorf("wrong upstream format: %w", err),
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resWG.Add(len(upstreams))
|
||||
addWG.Done()
|
||||
|
||||
specific := len(domains) > 0
|
||||
for _, ups := range upstreams {
|
||||
cr := checkResult{ups: ups}
|
||||
|
||||
checkErr := s.checkUpstreamAddr(ups, specific, opts, check)
|
||||
if checkErr != nil {
|
||||
cr.status = checkErr
|
||||
}
|
||||
|
||||
resCh <- cr
|
||||
}
|
||||
}
|
||||
|
||||
// check returns the mapping of upstream addresses to their check results.
|
||||
func (s *Server) check(req *upstreamJSON, opts *upstream.Options) (result map[string]string) {
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
|
||||
result = map[string]string{}
|
||||
resCh := make(chan checkResult)
|
||||
resWG := &sync.WaitGroup{}
|
||||
go func() {
|
||||
for res := range resCh {
|
||||
if res.status != nil {
|
||||
result[res.ups] = res.status.Error()
|
||||
} else {
|
||||
result[res.ups] = "OK"
|
||||
}
|
||||
resWG.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
// addWG is used to wait for all goroutines to count the expected number of
|
||||
// results and to add it to resWG.
|
||||
addWG := &sync.WaitGroup{}
|
||||
for _, ups := range req.Upstreams {
|
||||
go s.checkDNS(ups, opts, checkDNSUpstreamExc, addWG, resWG, resCh)
|
||||
addWG.Add(1)
|
||||
}
|
||||
for _, ups := range req.FallbackDNS {
|
||||
go s.checkDNS(ups, opts, checkDNSUpstreamExc, addWG, resWG, resCh)
|
||||
addWG.Add(1)
|
||||
}
|
||||
for _, ups := range req.PrivateUpstreams {
|
||||
go s.checkDNS(ups, opts, checkPrivateUpstreamExc, addWG, resWG, resCh)
|
||||
addWG.Add(1)
|
||||
}
|
||||
|
||||
addWG.Wait()
|
||||
resWG.Wait()
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// handleTestUpstreamDNS handles requests to the POST /control/test_upstream_dns
|
||||
// endpoint.
|
||||
func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -797,59 +853,18 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
bootstrapAddrs := stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
|
||||
if len(bootstrapAddrs) == 0 {
|
||||
bootstrapAddrs = defaultBootstrap
|
||||
}
|
||||
|
||||
opts := &upstream.Options{
|
||||
Bootstrap: req.BootstrapDNS,
|
||||
Bootstrap: bootstrapAddrs,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
}
|
||||
if len(opts.Bootstrap) == 0 {
|
||||
opts.Bootstrap = defaultBootstrap
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
m := &sync.Map{}
|
||||
|
||||
// TODO(s.chzhen): Separate to a different structure/file.
|
||||
worker := func(upstreamLine string, check healthCheckFunc) {
|
||||
defer log.OnPanic("dnsforward: checking upstreams")
|
||||
|
||||
res := s.checkDNS(upstreamLine, opts, check)
|
||||
for ups, status := range res {
|
||||
m.Store(ups, status)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
wg.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams))
|
||||
|
||||
for _, ups := range req.Upstreams {
|
||||
go worker(ups, checkDNSUpstreamExc)
|
||||
}
|
||||
for _, ups := range req.FallbackDNS {
|
||||
go worker(ups, checkDNSUpstreamExc)
|
||||
}
|
||||
for _, ups := range req.PrivateUpstreams {
|
||||
go worker(ups, checkPrivateUpstreamExc)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
result := map[string]string{}
|
||||
m.Range(func(k, v any) bool {
|
||||
ups := k.(string)
|
||||
status := v.(string)
|
||||
|
||||
result[ups] = status
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, result)
|
||||
aghhttp.WriteJSONResponseOK(w, r, s.check(req, opts))
|
||||
}
|
||||
|
||||
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
||||
|
|
Loading…
Reference in New Issue