dnsforward: imp code

This commit is contained in:
Eugene Burkov 2023-11-22 15:48:05 +03:00
parent 9487f1fd62
commit 70775975ce
1 changed files with 98 additions and 86 deletions

View File

@ -237,7 +237,7 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
// validateUpstreamConfig validates each upstream from the upstream
// configuration and returns an error if any upstream is invalid.
//
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
func validateUpstreamConfig(conf []string) (err error) {
for _, u := range conf {
var ups []string
@ -262,7 +262,7 @@ func validateUpstreamConfig(conf []string) (err error) {
// ValidateUpstreams validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified.
//
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
func ValidateUpstreams(upstreams []string) (err error) {
_, err = newUpstreamConfig(upstreams)
@ -393,7 +393,8 @@ type upstreamResult struct {
// err is the error either from parsing or from checking the upstream.
err error
// original is the piece of configuration that was or wasn't parsed.
// original is the piece of configuration that have either been turned to an
// upstream or caused an error.
original string
// isSpecific is true if the upstream is domain-specific.
@ -403,7 +404,8 @@ type upstreamResult struct {
// compare compares two [upstreamResult]s. It returns 0 if they are equal, -1
// if ur should be sorted before other, and 1 otherwise.
//
// TODO(e.burkov): Improve.
// TODO(e.burkov): Perhaps it makes sense to sort the results with errors near
// the end.
func (ur *upstreamResult) compare(other *upstreamResult) (res int) {
return strings.Compare(ur.original, other.original)
}
@ -438,7 +440,7 @@ func (cv *upstreamConfigValidator) insertLineResults(
s []*upstreamResult,
line string,
opts *upstream.Options,
) (res []*upstreamResult) {
) (result []*upstreamResult) {
upstreams, isSpecific, err := splitUpstreamLine(line)
if err != nil {
return cv.insert(s, &upstreamResult{
@ -448,20 +450,20 @@ func (cv *upstreamConfigValidator) insertLineResults(
}
for _, upstreamAddr := range upstreams {
r := cv.parseUpstream(upstreamAddr, opts)
if r == nil {
if isSpecific {
continue
}
r = &upstreamResult{
var res *upstreamResult
if upstreamAddr != "#" {
res = cv.parseUpstream(upstreamAddr, opts)
} else if !isSpecific {
res = &upstreamResult{
err: errNotDomainSpecific,
original: upstreamAddr,
}
} else {
continue
}
r.isSpecific = isSpecific
s = cv.insert(s, r)
res.isSpecific = isSpecific
s = cv.insert(s, res)
}
return s
@ -480,14 +482,13 @@ func (cv *upstreamConfigValidator) insert(s []*upstreamResult, r *upstreamResult
return slices.Insert(s, i, r)
}
// parseUpstream parses urlStr and returns the result of parsing. It returns
// nil if the specified server points at the default upstream server which is
// parseUpstream parses addr and returns the result of parsing. It returns nil
// if the specified server points at the default upstream server which is
// validated separately.
func (cv *upstreamConfigValidator) parseUpstream(addr string, opts *upstream.Options) (r *upstreamResult) {
if addr == "#" {
return nil
}
func (cv *upstreamConfigValidator) parseUpstream(
addr string,
opts *upstream.Options,
) (r *upstreamResult) {
// Check if the upstream has a valid protocol prefix.
//
// TODO(e.burkov): Validate the domain name.
@ -514,22 +515,68 @@ func (cv *upstreamConfigValidator) parseUpstream(addr string, opts *upstream.Opt
// [upsConfValidator.close] method, since it makes no sense to check the closed
// upstreams.
func (cv *upstreamConfigValidator) check() {
const (
// testTLD is the special-use fully-qualified domain name for testing
// the DNS server reachability.
//
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
testTLD = "test."
// inAddrARPATLD is the special-use fully-qualified domain name for PTR
// IP address resolution.
//
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
inAddrARPATLD = "in-addr.arpa."
)
commonChecker := &healthchecker{
hostname: testTLD,
qtype: dns.TypeA,
ansEmpty: true,
}
arpaChecker := &healthchecker{
hostname: inAddrARPATLD,
qtype: dns.TypePTR,
ansEmpty: false,
}
wg := &sync.WaitGroup{}
wg.Add(len(cv.general) + len(cv.fallback) + len(cv.private))
for _, res := range cv.general {
go cv.checkSrv(res, wg, exchangeTest)
go cv.checkSrv(res, wg, commonChecker)
}
for _, res := range cv.fallback {
go cv.checkSrv(res, wg, exchangeTest)
go cv.checkSrv(res, wg, commonChecker)
}
for _, res := range cv.private {
go cv.checkSrv(res, wg, exchangeARPATest)
go cv.checkSrv(res, wg, arpaChecker)
}
wg.Wait()
}
// checkSrv runs hc on the server from res, if any, and stores any occurred
// error in res. wg is always marked done in the end. It used to be called in
// a separate goroutine.
func (cv *upstreamConfigValidator) checkSrv(
res *upstreamResult,
wg *sync.WaitGroup,
hc *healthchecker,
) {
defer wg.Done()
if res.server == nil {
return
}
res.err = hc.check(res.server)
if res.err != nil && res.isSpecific {
res.err = domainSpecificTestError{Err: res.err}
}
}
// close closes all the upstreams that were successfully parsed. It enriches
// the results with deferred closing errors.
func (cv *upstreamConfigValidator) close() {
@ -577,55 +624,49 @@ func (cv *upstreamConfigValidator) status() (results map[string]string) {
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
// warnings (non-critical errors) is desired.
type domainSpecificTestError struct {
// error is the actual error occurred during healthcheck test.
error
// Err is the actual error occurred during healthcheck test.
Err error
}
// type check
var _ error = domainSpecificTestError{}
// Error implements the [error] interface for domainSpecificTestError.
func (err domainSpecificTestError) Error() (msg string) {
return fmt.Sprintf("WARNING: %s", err.error)
return fmt.Sprintf("WARNING: %s", err.Err)
}
// checkSrv runs hc on the server from res, if any, and stores any occurred
// error in res. wg is always marked done in the end. It used to be called in
// a separate goroutine.
func (cv *upstreamConfigValidator) checkSrv(
res *upstreamResult,
wg *sync.WaitGroup,
hc healthcheckFunc,
) {
defer wg.Done()
// type check
var _ errors.Wrapper = domainSpecificTestError{}
if res.server == nil {
return
}
res.err = hc(res.server)
if res.err != nil && res.isSpecific {
res.err = domainSpecificTestError{error: res.err}
}
// Unwrap implements the [errors.Wrapper] interface for domainSpecificTestError.
func (err domainSpecificTestError) Unwrap() (wrapped error) {
return err.Err
}
// healthcheckFunc is a function that checks the health of an upstream.
type healthcheckFunc func(u upstream.Upstream) (err error)
// healthchecker checks the upstream's status by exchanging with it.
type healthchecker struct {
// hostname is the name of the host to put into healthcheck DNS request.
hostname string
// exchangeTest check the health of a common upstream server, either the general
// or the fallback one.
func exchangeTest(u upstream.Upstream) (err error) {
// testTLD is the special-use fully-qualified domain name for testing the
// DNS server reachability.
//
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
const testTLD = "test."
// qtype is the type of DNS request to use for healthcheck.
qtype uint16
// ansEmpty defines if the answer section within the response is expected to
// be empty.
ansEmpty bool
}
// check exchanges with u and validates the response.
func (h *healthchecker) check(u upstream.Upstream) (err error) {
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: testTLD,
Qtype: dns.TypeA,
Name: h.hostname,
Qtype: h.qtype,
Qclass: dns.ClassINET,
}},
}
@ -633,38 +674,9 @@ func exchangeTest(u upstream.Upstream) (err error) {
reply, err := u.Exchange(req)
if err != nil {
return fmt.Errorf("couldn't communicate with upstream: %w", err)
} else if len(reply.Answer) != 0 {
} else if h.ansEmpty && len(reply.Answer) > 0 {
return errWrongResponse
}
return nil
}
// exchangeARPATest check the health of an upstream server for private RDNS
// resolution.
func exchangeARPATest(u upstream.Upstream) (err error) {
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
// address resolution.
//
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
const inAddrArpaTLD = "in-addr.arpa."
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: inAddrArpaTLD,
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
}},
}
_, err = u.Exchange(req)
if err != nil {
return fmt.Errorf("couldn't communicate with upstream: %w", err)
}
return nil
}