dnsforward: add ups checker

This commit is contained in:
Eugene Burkov 2023-11-21 16:12:20 +03:00
parent 09db7d2a60
commit e2612e0e6f
3 changed files with 262 additions and 154 deletions

View File

@ -6,7 +6,6 @@ import (
"io"
"net/http"
"net/netip"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
@ -578,34 +577,34 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
}
defer closeBoots(boots)
wg := &sync.WaitGroup{}
m := &sync.Map{}
wg.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams))
for _, ups := range req.Upstreams {
go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m)
}
for _, ups := range req.FallbackDNS {
go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m)
}
for _, ups := range req.PrivateUpstreams {
go s.checkDNS(ups, opts, checkPrivateUpstreamExc, wg, m)
}
wg.Wait()
cv := newUpstreamConfigValidator(req.Upstreams, req.FallbackDNS, req.PrivateUpstreams, opts)
cv.check()
cv.close()
result := map[string]string{}
m.Range(func(k, v any) bool {
// TODO(e.burkov): The upstreams used for both common and private
// resolving should be reported separately.
ups := k.(string)
status := v.(string)
for _, res := range cv.general {
if res.err != nil {
result[res.original] = res.err.Error()
} else {
result[res.original] = "OK"
}
}
result[ups] = status
for _, res := range cv.fallback {
if res.err != nil {
result[res.original] = res.err.Error()
} else {
result[res.original] = "OK"
}
}
return true
})
for _, res := range cv.private {
if res.err != nil {
result[res.original] = res.err.Error()
} else {
result[res.original] = "OK"
}
}
aghhttp.WriteJSONResponseOK(w, r, result)
}

View File

@ -627,7 +627,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
"upstream_dns": []string{"[/domain.example/]/]1.2.3.4"},
},
wantResp: map[string]any{
"[/domain.example/]/]1.2.3.4": `wrong upstream format: ` +
"[/domain.example/]/]1.2.3.4": `separating upstream line: ` +
`bad upstream for domain "[/domain.example/]/]1.2.3.4": ` +
`duplicated separator`,
},

View File

@ -21,6 +21,28 @@ import (
"golang.org/x/exp/slices"
)
const (
// errNotDomainSpecific is returned when the upstream should be
// domain-specific, but isn't.
errNotDomainSpecific errors.Error = "not a domain-specific upstream"
// errMissingSeparator is returned when the domain-specific part of the
// upstream configuration line isn't closed.
errMissingSeparator errors.Error = "missing separator"
// errDupSeparator is returned when the domain-specific part of the upstream
// configuration line contains more than one ending separator.
errDupSeparator errors.Error = "duplicated separator"
// errNoDefaultUpstreams is returned when there are no default upstreams
// specified in the upstream configuration.
errNoDefaultUpstreams errors.Error = "no default upstreams specified"
// errWrongResponse is returned when the checked upstream replies in an
// unexpected way.
errWrongResponse errors.Error = "wrong response"
)
// loadUpstreams parses upstream DNS servers from the configured file or from
// the configuration itself.
func (s *Server) loadUpstreams() (upstreams []string, err error) {
@ -206,7 +228,7 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
// Don't wrap the error since it's informative enough as is.
return nil, err
} else if len(conf.Upstreams) == 0 {
return nil, errors.Error("no default upstreams specified")
return nil, errNoDefaultUpstreams
}
return conf, nil
@ -286,15 +308,7 @@ func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet)
}
// protocols are the supported URL schemes for upstreams.
var protocols = []string{
"h3://",
"https://",
"quic://",
"sdns://",
"tcp://",
"tls://",
"udp://",
}
var protocols = []string{"h3", "https", "quic", "sdns", "tcp", "tls", "udp"}
// validateUpstream returns an error if u alongside with domains is not a valid
// upstream configuration. useDefault is true if the upstream is
@ -345,9 +359,9 @@ func separateUpstream(upstreamStr string) (upstreams, domains []string, err erro
case 2:
// Go on.
case 1:
return nil, nil, errors.Error("missing separator")
return nil, nil, errMissingSeparator
default:
return nil, nil, errors.Error("duplicated separator")
return nil, nil, errDupSeparator
}
for i, host := range strings.Split(parts[0], "/") {
@ -366,12 +380,199 @@ func separateUpstream(upstreamStr string) (upstreams, domains []string, err erro
return strings.Fields(parts[1]), domains, nil
}
// healthCheckFunc is a signature of function to check if upstream exchanges
// properly.
type healthCheckFunc func(u upstream.Upstream) (err error)
// upsConfValidator parses the [*proxy.UpstreamConfig] and checks the
// actual DNS availability of each upstream.
type upsConfValidator struct {
// general is the general upstream configuration.
general []*upsResult
// fallback is the fallback upstream configuration.
fallback []*upsResult
// private is the private upstream configuration.
private []*upsResult
}
// upsResult is a result of parsing and validation of an [upstream.Upstream].
type upsResult struct {
// server is the parsed upstream. It is nil when there was an error during
// parsing.
server upstream.Upstream
// err is the error either from parsing or from checking the upstream.
err error
// original is the piece of configuration that was or wasn't parsed.
original string
// isSpecific is true if the upstream is domain-specific.
isSpecific bool
}
// compare compares two [upsResult]s. It returns 0 if they are equal, -1 if ur
// should be sorted before other, and 1 otherwise.
//
// TODO(e.burkov): Improve.
func (ur *upsResult) compare(other *upsResult) (res int) {
return strings.Compare(ur.original, other.original)
}
// newUpstreamConfigValidator parses the upstream configuration and returns a
// validator for it. cv already contains the parsed upstreams along with errors
// related.
func newUpstreamConfigValidator(
general []string,
fallback []string,
private []string,
opts *upstream.Options,
) (cv *upsConfValidator) {
cv = &upsConfValidator{}
for _, line := range general {
cv.parseLine(&cv.general, line, opts)
}
for _, line := range fallback {
cv.parseLine(&cv.fallback, line, opts)
}
for _, line := range private {
cv.parseLine(&cv.private, line, opts)
}
return cv
}
// parseLine parses line and inserts the result into s. It can insert multiple
// results as well as none.
func (cv *upsConfValidator) parseLine(s *[]*upsResult, line string, opts *upstream.Options) {
upstreams, domains, err := separateUpstream(line)
if err != nil {
cv.insert(s, &upsResult{
err: fmt.Errorf("separating upstream line: %w", err),
original: line,
})
return
}
specific := len(domains) > 0
for _, upstreamAddr := range upstreams {
r := cv.parseUpstream(upstreamAddr, specific, opts)
if r != nil {
cv.insert(s, r)
}
}
}
// insert inserts r into slice in a sorted order, except duplicates. slice must
// not be nil.
func (cv *upsConfValidator) insert(slice *[]*upsResult, r *upsResult) {
i, has := slices.BinarySearchFunc(*slice, r, (*upsResult).compare)
if has {
log.Debug("dnsforward: duplicate configuration %q", r.original)
} else {
*slice = slices.Insert(*slice, i, r)
}
}
func (cv *upsConfValidator) close() {
for _, slice := range [][]*upsResult{cv.general, cv.fallback, cv.private} {
for _, r := range slice {
if r.server != nil {
r.err = errors.WithDeferred(r.err, r.server.Close())
}
}
}
}
// parseUpstream parses urlStr and returns the result of parsing. It returns
// nil if the specified server is domain-specific and points at the default
// upstream server which is validated separately.
func (cv *upsConfValidator) parseUpstream(
urlStr string,
specific bool,
opts *upstream.Options,
) (res *upsResult) {
if urlStr == "#" {
if specific {
return nil
}
return &upsResult{
err: errNotDomainSpecific,
original: urlStr,
isSpecific: specific,
}
}
// Check if the upstream has a valid protocol prefix.
//
// TODO(e.burkov): Validate the domain name.
if proto, _, ok := strings.Cut(urlStr, "://"); ok {
if !slices.Contains(protocols, proto) {
return &upsResult{
err: fmt.Errorf("bad protocol %q", proto),
original: urlStr,
isSpecific: specific,
}
}
}
ups, err := upstream.AddressToUpstream(urlStr, opts)
return &upsResult{
server: ups,
err: err,
original: urlStr,
isSpecific: specific,
}
}
// check tries to exchange with each successfully parsed upstream and enriches
// the results with the healthcheck errors.
func (cv *upsConfValidator) check() {
wg := &sync.WaitGroup{}
wg.Add(len(cv.general) + len(cv.fallback) + len(cv.private))
for _, res := range cv.general {
go cv.checkCommon(res, wg)
}
for _, res := range cv.fallback {
go cv.checkCommon(res, wg)
}
for _, res := range cv.private {
go cv.checkPrivate(res, wg)
}
wg.Wait()
}
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
// the tested upstream domain-specific and therefore consider its errors
// non-critical.
//
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
// warnings (non-critical errors) is desired.
type domainSpecificTestError struct {
error
}
// Error implements the [error] interface for domainSpecificTestError.
func (err domainSpecificTestError) Error() (msg string) {
return fmt.Sprintf("WARNING: %s", err.error)
}
// checkCommon checks the DNS upstream for common external DNS exchanges. It's
// applicable to both general and fallback upstreams.
func (cv *upsConfValidator) checkCommon(res *upsResult, wg *sync.WaitGroup) {
defer wg.Done()
if res.server == nil {
return
}
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
// testTLD is the special-use fully-qualified domain name for testing the
// DNS server reachability.
//
@ -390,22 +591,27 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
}},
}
var reply *dns.Msg
reply, err = u.Exchange(req)
reply, err := res.server.Exchange(req)
if err != nil {
return fmt.Errorf("couldn't communicate with upstream: %w", err)
res.err = fmt.Errorf("couldn't communicate with upstream: %w", err)
} else if len(reply.Answer) != 0 {
return errors.Error("wrong response")
res.err = errWrongResponse
}
return nil
if res.err != nil && res.isSpecific {
res.err = domainSpecificTestError{error: res.err}
}
}
// checkPrivateUpstreamExc checks if the upstream for resolving private
// addresses exchanges correctly.
//
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
// checkPrivate checks the DNS upstream for private address resolution. It's
// applicable to private upstreams.
func (cv *upsConfValidator) checkPrivate(res *upsResult, wg *sync.WaitGroup) {
defer wg.Done()
if res.server == nil {
return
}
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
// address resolution.
//
@ -424,108 +630,11 @@ func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
}},
}
if _, err = u.Exchange(req); err != nil {
return fmt.Errorf("couldn't communicate with upstream: %w", err)
}
return nil
}
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
// the tested upstream domain-specific and therefore consider its errors
// non-critical.
//
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
// warnings (non-critical errors) is desired.
type domainSpecificTestError struct {
error
}
// Error implements the [error] interface for domainSpecificTestError.
func (err domainSpecificTestError) Error() (msg string) {
return fmt.Sprintf("WARNING: %s", err.error)
}
// checkDNS parses line, creates DNS upstreams using opts, and checks if the
// upstreams are exchanging correctly. It saves the result into a sync.Map
// where key is an upstream address and value is "OK", if the upstream
// exchanges correctly, or text of the error. It is intended to be used as a
// goroutine.
//
// TODO(s.chzhen): Separate to a different structure/file.
func (s *Server) checkDNS(
line string,
opts *upstream.Options,
check healthCheckFunc,
wg *sync.WaitGroup,
m *sync.Map,
) {
defer wg.Done()
defer log.OnPanic("dnsforward: checking upstreams")
upstreams, domains, err := separateUpstream(line)
_, err := res.server.Exchange(req)
if err != nil {
err = fmt.Errorf("wrong upstream format: %w", err)
m.Store(line, err.Error())
return
}
specific := len(domains) > 0
for _, upstreamAddr := range upstreams {
var useDefault bool
useDefault, err = validateUpstream(upstreamAddr, domains)
if err != nil {
err = fmt.Errorf("wrong upstream format: %w", err)
m.Store(upstreamAddr, err.Error())
continue
}
if useDefault {
continue
}
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
err = s.checkUpstreamAddr(upstreamAddr, specific, opts, check)
if err != nil {
m.Store(upstreamAddr, err.Error())
} else {
m.Store(upstreamAddr, "OK")
res.err = fmt.Errorf("couldn't communicate with upstream: %w", err)
if res.isSpecific {
res.err = domainSpecificTestError{error: res.err}
}
}
}
// checkUpstreamAddr creates the DNS upstream using opts and information from
// [s.dnsFilter.EtcHosts]. Checks if the DNS upstream exchanges correctly. It
// returns an error if addr is not valid DNS upstream address or the upstream
// is not exchanging correctly.
func (s *Server) checkUpstreamAddr(
addr string,
specific bool,
opts *upstream.Options,
check healthCheckFunc,
) (err error) {
defer func() {
if err != nil && specific {
err = domainSpecificTestError{error: err}
}
}()
opts = &upstream.Options{
Bootstrap: opts.Bootstrap,
Timeout: opts.Timeout,
PreferIPv6: opts.PreferIPv6,
}
u, err := upstream.AddressToUpstream(addr, opts)
if err != nil {
return fmt.Errorf("creating upstream for %q: %w", addr, err)
}
defer func() { err = errors.WithDeferred(err, u.Close()) }()
return check(u)
}