dnsforward: add ups checker
This commit is contained in:
parent
09db7d2a60
commit
e2612e0e6f
|
@ -6,7 +6,6 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
|
@ -578,34 +577,34 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
defer closeBoots(boots)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
m := &sync.Map{}
|
||||
|
||||
wg.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams))
|
||||
|
||||
for _, ups := range req.Upstreams {
|
||||
go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m)
|
||||
}
|
||||
for _, ups := range req.FallbackDNS {
|
||||
go s.checkDNS(ups, opts, checkDNSUpstreamExc, wg, m)
|
||||
}
|
||||
for _, ups := range req.PrivateUpstreams {
|
||||
go s.checkDNS(ups, opts, checkPrivateUpstreamExc, wg, m)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
cv := newUpstreamConfigValidator(req.Upstreams, req.FallbackDNS, req.PrivateUpstreams, opts)
|
||||
cv.check()
|
||||
cv.close()
|
||||
|
||||
result := map[string]string{}
|
||||
m.Range(func(k, v any) bool {
|
||||
// TODO(e.burkov): The upstreams used for both common and private
|
||||
// resolving should be reported separately.
|
||||
ups := k.(string)
|
||||
status := v.(string)
|
||||
for _, res := range cv.general {
|
||||
if res.err != nil {
|
||||
result[res.original] = res.err.Error()
|
||||
} else {
|
||||
result[res.original] = "OK"
|
||||
}
|
||||
}
|
||||
|
||||
result[ups] = status
|
||||
for _, res := range cv.fallback {
|
||||
if res.err != nil {
|
||||
result[res.original] = res.err.Error()
|
||||
} else {
|
||||
result[res.original] = "OK"
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
for _, res := range cv.private {
|
||||
if res.err != nil {
|
||||
result[res.original] = res.err.Error()
|
||||
} else {
|
||||
result[res.original] = "OK"
|
||||
}
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, result)
|
||||
}
|
||||
|
|
|
@ -627,7 +627,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
|||
"upstream_dns": []string{"[/domain.example/]/]1.2.3.4"},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
"[/domain.example/]/]1.2.3.4": `wrong upstream format: ` +
|
||||
"[/domain.example/]/]1.2.3.4": `separating upstream line: ` +
|
||||
`bad upstream for domain "[/domain.example/]/]1.2.3.4": ` +
|
||||
`duplicated separator`,
|
||||
},
|
||||
|
|
|
@ -21,6 +21,28 @@ import (
|
|||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
const (
|
||||
// errNotDomainSpecific is returned when the upstream should be
|
||||
// domain-specific, but isn't.
|
||||
errNotDomainSpecific errors.Error = "not a domain-specific upstream"
|
||||
|
||||
// errMissingSeparator is returned when the domain-specific part of the
|
||||
// upstream configuration line isn't closed.
|
||||
errMissingSeparator errors.Error = "missing separator"
|
||||
|
||||
// errDupSeparator is returned when the domain-specific part of the upstream
|
||||
// configuration line contains more than one ending separator.
|
||||
errDupSeparator errors.Error = "duplicated separator"
|
||||
|
||||
// errNoDefaultUpstreams is returned when there are no default upstreams
|
||||
// specified in the upstream configuration.
|
||||
errNoDefaultUpstreams errors.Error = "no default upstreams specified"
|
||||
|
||||
// errWrongResponse is returned when the checked upstream replies in an
|
||||
// unexpected way.
|
||||
errWrongResponse errors.Error = "wrong response"
|
||||
)
|
||||
|
||||
// loadUpstreams parses upstream DNS servers from the configured file or from
|
||||
// the configuration itself.
|
||||
func (s *Server) loadUpstreams() (upstreams []string, err error) {
|
||||
|
@ -206,7 +228,7 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
|||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
} else if len(conf.Upstreams) == 0 {
|
||||
return nil, errors.Error("no default upstreams specified")
|
||||
return nil, errNoDefaultUpstreams
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
|
@ -286,15 +308,7 @@ func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet)
|
|||
}
|
||||
|
||||
// protocols are the supported URL schemes for upstreams.
|
||||
var protocols = []string{
|
||||
"h3://",
|
||||
"https://",
|
||||
"quic://",
|
||||
"sdns://",
|
||||
"tcp://",
|
||||
"tls://",
|
||||
"udp://",
|
||||
}
|
||||
var protocols = []string{"h3", "https", "quic", "sdns", "tcp", "tls", "udp"}
|
||||
|
||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
||||
// upstream configuration. useDefault is true if the upstream is
|
||||
|
@ -345,9 +359,9 @@ func separateUpstream(upstreamStr string) (upstreams, domains []string, err erro
|
|||
case 2:
|
||||
// Go on.
|
||||
case 1:
|
||||
return nil, nil, errors.Error("missing separator")
|
||||
return nil, nil, errMissingSeparator
|
||||
default:
|
||||
return nil, nil, errors.Error("duplicated separator")
|
||||
return nil, nil, errDupSeparator
|
||||
}
|
||||
|
||||
for i, host := range strings.Split(parts[0], "/") {
|
||||
|
@ -366,12 +380,199 @@ func separateUpstream(upstreamStr string) (upstreams, domains []string, err erro
|
|||
return strings.Fields(parts[1]), domains, nil
|
||||
}
|
||||
|
||||
// healthCheckFunc is a signature of function to check if upstream exchanges
|
||||
// properly.
|
||||
type healthCheckFunc func(u upstream.Upstream) (err error)
|
||||
// upsConfValidator parses the [*proxy.UpstreamConfig] and checks the
|
||||
// actual DNS availability of each upstream.
|
||||
type upsConfValidator struct {
|
||||
// general is the general upstream configuration.
|
||||
general []*upsResult
|
||||
|
||||
// fallback is the fallback upstream configuration.
|
||||
fallback []*upsResult
|
||||
|
||||
// private is the private upstream configuration.
|
||||
private []*upsResult
|
||||
}
|
||||
|
||||
// upsResult is a result of parsing and validation of an [upstream.Upstream].
|
||||
type upsResult struct {
|
||||
// server is the parsed upstream. It is nil when there was an error during
|
||||
// parsing.
|
||||
server upstream.Upstream
|
||||
|
||||
// err is the error either from parsing or from checking the upstream.
|
||||
err error
|
||||
|
||||
// original is the piece of configuration that was or wasn't parsed.
|
||||
original string
|
||||
|
||||
// isSpecific is true if the upstream is domain-specific.
|
||||
isSpecific bool
|
||||
}
|
||||
|
||||
// compare compares two [upsResult]s. It returns 0 if they are equal, -1 if ur
|
||||
// should be sorted before other, and 1 otherwise.
|
||||
//
|
||||
// TODO(e.burkov): Improve.
|
||||
func (ur *upsResult) compare(other *upsResult) (res int) {
|
||||
return strings.Compare(ur.original, other.original)
|
||||
}
|
||||
|
||||
// newUpstreamConfigValidator parses the upstream configuration and returns a
|
||||
// validator for it. cv already contains the parsed upstreams along with errors
|
||||
// related.
|
||||
func newUpstreamConfigValidator(
|
||||
general []string,
|
||||
fallback []string,
|
||||
private []string,
|
||||
opts *upstream.Options,
|
||||
) (cv *upsConfValidator) {
|
||||
cv = &upsConfValidator{}
|
||||
|
||||
for _, line := range general {
|
||||
cv.parseLine(&cv.general, line, opts)
|
||||
}
|
||||
for _, line := range fallback {
|
||||
cv.parseLine(&cv.fallback, line, opts)
|
||||
}
|
||||
for _, line := range private {
|
||||
cv.parseLine(&cv.private, line, opts)
|
||||
}
|
||||
|
||||
return cv
|
||||
}
|
||||
|
||||
// parseLine parses line and inserts the result into s. It can insert multiple
|
||||
// results as well as none.
|
||||
func (cv *upsConfValidator) parseLine(s *[]*upsResult, line string, opts *upstream.Options) {
|
||||
upstreams, domains, err := separateUpstream(line)
|
||||
if err != nil {
|
||||
cv.insert(s, &upsResult{
|
||||
err: fmt.Errorf("separating upstream line: %w", err),
|
||||
original: line,
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
specific := len(domains) > 0
|
||||
for _, upstreamAddr := range upstreams {
|
||||
r := cv.parseUpstream(upstreamAddr, specific, opts)
|
||||
if r != nil {
|
||||
cv.insert(s, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// insert inserts r into slice in a sorted order, except duplicates. slice must
|
||||
// not be nil.
|
||||
func (cv *upsConfValidator) insert(slice *[]*upsResult, r *upsResult) {
|
||||
i, has := slices.BinarySearchFunc(*slice, r, (*upsResult).compare)
|
||||
if has {
|
||||
log.Debug("dnsforward: duplicate configuration %q", r.original)
|
||||
} else {
|
||||
*slice = slices.Insert(*slice, i, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (cv *upsConfValidator) close() {
|
||||
for _, slice := range [][]*upsResult{cv.general, cv.fallback, cv.private} {
|
||||
for _, r := range slice {
|
||||
if r.server != nil {
|
||||
r.err = errors.WithDeferred(r.err, r.server.Close())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseUpstream parses urlStr and returns the result of parsing. It returns
|
||||
// nil if the specified server is domain-specific and points at the default
|
||||
// upstream server which is validated separately.
|
||||
func (cv *upsConfValidator) parseUpstream(
|
||||
urlStr string,
|
||||
specific bool,
|
||||
opts *upstream.Options,
|
||||
) (res *upsResult) {
|
||||
if urlStr == "#" {
|
||||
if specific {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &upsResult{
|
||||
err: errNotDomainSpecific,
|
||||
original: urlStr,
|
||||
isSpecific: specific,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Check if the upstream has a valid protocol prefix.
|
||||
//
|
||||
// TODO(e.burkov): Validate the domain name.
|
||||
if proto, _, ok := strings.Cut(urlStr, "://"); ok {
|
||||
if !slices.Contains(protocols, proto) {
|
||||
return &upsResult{
|
||||
err: fmt.Errorf("bad protocol %q", proto),
|
||||
original: urlStr,
|
||||
isSpecific: specific,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ups, err := upstream.AddressToUpstream(urlStr, opts)
|
||||
|
||||
return &upsResult{
|
||||
server: ups,
|
||||
err: err,
|
||||
original: urlStr,
|
||||
isSpecific: specific,
|
||||
}
|
||||
}
|
||||
|
||||
// check tries to exchange with each successfully parsed upstream and enriches
|
||||
// the results with the healthcheck errors.
|
||||
func (cv *upsConfValidator) check() {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(len(cv.general) + len(cv.fallback) + len(cv.private))
|
||||
|
||||
for _, res := range cv.general {
|
||||
go cv.checkCommon(res, wg)
|
||||
}
|
||||
|
||||
for _, res := range cv.fallback {
|
||||
go cv.checkCommon(res, wg)
|
||||
}
|
||||
|
||||
for _, res := range cv.private {
|
||||
go cv.checkPrivate(res, wg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
||||
// the tested upstream domain-specific and therefore consider its errors
|
||||
// non-critical.
|
||||
//
|
||||
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
|
||||
// warnings (non-critical errors) is desired.
|
||||
type domainSpecificTestError struct {
|
||||
error
|
||||
}
|
||||
|
||||
// Error implements the [error] interface for domainSpecificTestError.
|
||||
func (err domainSpecificTestError) Error() (msg string) {
|
||||
return fmt.Sprintf("WARNING: %s", err.error)
|
||||
}
|
||||
|
||||
// checkCommon checks the DNS upstream for common external DNS exchanges. It's
|
||||
// applicable to both general and fallback upstreams.
|
||||
func (cv *upsConfValidator) checkCommon(res *upsResult, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
if res.server == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
|
||||
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// testTLD is the special-use fully-qualified domain name for testing the
|
||||
// DNS server reachability.
|
||||
//
|
||||
|
@ -390,22 +591,27 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
|||
}},
|
||||
}
|
||||
|
||||
var reply *dns.Msg
|
||||
reply, err = u.Exchange(req)
|
||||
reply, err := res.server.Exchange(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
res.err = fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
} else if len(reply.Answer) != 0 {
|
||||
return errors.Error("wrong response")
|
||||
res.err = errWrongResponse
|
||||
}
|
||||
|
||||
return nil
|
||||
if res.err != nil && res.isSpecific {
|
||||
res.err = domainSpecificTestError{error: res.err}
|
||||
}
|
||||
}
|
||||
|
||||
// checkPrivateUpstreamExc checks if the upstream for resolving private
|
||||
// addresses exchanges correctly.
|
||||
//
|
||||
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
|
||||
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// checkPrivate checks the DNS upstream for private address resolution. It's
|
||||
// applicable to private upstreams.
|
||||
func (cv *upsConfValidator) checkPrivate(res *upsResult, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
if res.server == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
|
||||
// address resolution.
|
||||
//
|
||||
|
@ -424,108 +630,11 @@ func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
|||
}},
|
||||
}
|
||||
|
||||
if _, err = u.Exchange(req); err != nil {
|
||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
||||
// the tested upstream domain-specific and therefore consider its errors
|
||||
// non-critical.
|
||||
//
|
||||
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
|
||||
// warnings (non-critical errors) is desired.
|
||||
type domainSpecificTestError struct {
|
||||
error
|
||||
}
|
||||
|
||||
// Error implements the [error] interface for domainSpecificTestError.
|
||||
func (err domainSpecificTestError) Error() (msg string) {
|
||||
return fmt.Sprintf("WARNING: %s", err.error)
|
||||
}
|
||||
|
||||
// checkDNS parses line, creates DNS upstreams using opts, and checks if the
|
||||
// upstreams are exchanging correctly. It saves the result into a sync.Map
|
||||
// where key is an upstream address and value is "OK", if the upstream
|
||||
// exchanges correctly, or text of the error. It is intended to be used as a
|
||||
// goroutine.
|
||||
//
|
||||
// TODO(s.chzhen): Separate to a different structure/file.
|
||||
func (s *Server) checkDNS(
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
wg *sync.WaitGroup,
|
||||
m *sync.Map,
|
||||
) {
|
||||
defer wg.Done()
|
||||
defer log.OnPanic("dnsforward: checking upstreams")
|
||||
|
||||
upstreams, domains, err := separateUpstream(line)
|
||||
_, err := res.server.Exchange(req)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("wrong upstream format: %w", err)
|
||||
m.Store(line, err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
specific := len(domains) > 0
|
||||
|
||||
for _, upstreamAddr := range upstreams {
|
||||
var useDefault bool
|
||||
useDefault, err = validateUpstream(upstreamAddr, domains)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("wrong upstream format: %w", err)
|
||||
m.Store(upstreamAddr, err.Error())
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if useDefault {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
err = s.checkUpstreamAddr(upstreamAddr, specific, opts, check)
|
||||
if err != nil {
|
||||
m.Store(upstreamAddr, err.Error())
|
||||
} else {
|
||||
m.Store(upstreamAddr, "OK")
|
||||
res.err = fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
if res.isSpecific {
|
||||
res.err = domainSpecificTestError{error: res.err}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkUpstreamAddr creates the DNS upstream using opts and information from
|
||||
// [s.dnsFilter.EtcHosts]. Checks if the DNS upstream exchanges correctly. It
|
||||
// returns an error if addr is not valid DNS upstream address or the upstream
|
||||
// is not exchanging correctly.
|
||||
func (s *Server) checkUpstreamAddr(
|
||||
addr string,
|
||||
specific bool,
|
||||
opts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
) (err error) {
|
||||
defer func() {
|
||||
if err != nil && specific {
|
||||
err = domainSpecificTestError{error: err}
|
||||
}
|
||||
}()
|
||||
|
||||
opts = &upstream.Options{
|
||||
Bootstrap: opts.Bootstrap,
|
||||
Timeout: opts.Timeout,
|
||||
PreferIPv6: opts.PreferIPv6,
|
||||
}
|
||||
|
||||
u, err := upstream.AddressToUpstream(addr, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating upstream for %q: %w", addr, err)
|
||||
}
|
||||
|
||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
||||
|
||||
return check(u)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue