370 lines
11 KiB
Go
370 lines
11 KiB
Go
package dnsforward
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
|
"github.com/AdguardTeam/golibs/errors"
|
|
"github.com/AdguardTeam/golibs/log"
|
|
"github.com/AdguardTeam/golibs/netutil"
|
|
"github.com/AdguardTeam/golibs/stringutil"
|
|
"golang.org/x/exp/maps"
|
|
"golang.org/x/exp/slices"
|
|
)
|
|
|
|
const (
|
|
// errNotDomainSpecific is returned when the upstream should be
|
|
// domain-specific, but isn't.
|
|
errNotDomainSpecific errors.Error = "not a domain-specific upstream"
|
|
|
|
// errMissingSeparator is returned when the domain-specific part of the
|
|
// upstream configuration line isn't closed.
|
|
errMissingSeparator errors.Error = "missing separator"
|
|
|
|
// errDupSeparator is returned when the domain-specific part of the upstream
|
|
// configuration line contains more than one ending separator.
|
|
errDupSeparator errors.Error = "duplicated separator"
|
|
|
|
// errNoDefaultUpstreams is returned when there are no default upstreams
|
|
// specified in the upstream configuration.
|
|
errNoDefaultUpstreams errors.Error = "no default upstreams specified"
|
|
|
|
// errWrongResponse is returned when the checked upstream replies in an
|
|
// unexpected way.
|
|
errWrongResponse errors.Error = "wrong response"
|
|
)
|
|
|
|
// loadUpstreams returns the upstream DNS server lines. If
// s.conf.UpstreamDNSFileName is set, they're read from that file, one per
// line, with surrounding whitespace trimmed. Otherwise they're taken from
// s.conf.UpstreamDNS. In both cases comment and empty lines are removed, see
// [IsCommentOrEmpty].
func (s *Server) loadUpstreams() (upstreams []string, err error) {
	if s.conf.UpstreamDNSFileName == "" {
		return stringutil.FilterOut(s.conf.UpstreamDNS, IsCommentOrEmpty), nil
	}

	var data []byte
	data, err = os.ReadFile(s.conf.UpstreamDNSFileName)
	if err != nil {
		return nil, fmt.Errorf("reading upstream from file: %w", err)
	}

	upstreams = stringutil.SplitTrimmed(string(data), "\n")

	// Filter before logging so that the count doesn't include comment lines.
	upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)

	log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName)

	return upstreams, nil
}
|
|
|
|
// prepareUpstreamSettings sets upstream DNS server settings.
|
|
func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
|
|
// Load upstreams either from the file, or from the settings
|
|
var upstreams []string
|
|
upstreams, err = s.loadUpstreams()
|
|
if err != nil {
|
|
return fmt.Errorf("loading upstreams: %w", err)
|
|
}
|
|
|
|
s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
|
|
Bootstrap: boot,
|
|
Timeout: s.conf.UpstreamTimeout,
|
|
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
|
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
|
// Use a customized set of RootCAs, because Go's default mechanism of
|
|
// loading TLS roots does not always work properly on some routers so we're
|
|
// loading roots manually and pass it here.
|
|
//
|
|
// See [aghtls.SystemRootCAs].
|
|
//
|
|
// TODO(a.garipov): Investigate if that's true.
|
|
RootCAs: s.conf.TLSv12Roots,
|
|
CipherSuites: s.conf.TLSCiphers,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("preparing upstream config: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// prepareUpstreamConfig parses upstreams into an upstream configuration using
// opts. If the parsed configuration has no default upstreams and
// defaultUpstreams is not empty, its default upstreams are replaced with the
// ones parsed from defaultUpstreams. The rest of the configuration parsed from
// upstreams, such as the domain-specific upstreams, is kept as is.
func (s *Server) prepareUpstreamConfig(
	upstreams []string,
	defaultUpstreams []string,
	opts *upstream.Options,
) (uc *proxy.UpstreamConfig, err error) {
	uc, err = proxy.ParseUpstreamsConfig(upstreams, opts)
	if err != nil {
		return nil, fmt.Errorf("parsing upstream config: %w", err)
	}

	// Only fall back when there is something to fall back to. Checking the
	// length rather than nil-ness also avoids a misleading warning for an
	// empty, but non-nil, list of default upstreams.
	if len(uc.Upstreams) > 0 || len(defaultUpstreams) == 0 {
		return uc, nil
	}

	log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams)

	var defaultUpstreamConfig *proxy.UpstreamConfig
	defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts)
	if err != nil {
		return nil, fmt.Errorf("parsing default upstreams: %w", err)
	}

	uc.Upstreams = defaultUpstreamConfig.Upstreams

	return uc, nil
}
|
|
|
|
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
|
|
// depending on configuration.
|
|
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
|
if !http3 {
|
|
return upstream.DefaultHTTPVersions
|
|
}
|
|
|
|
return []upstream.HTTPVersion{
|
|
upstream.HTTPVersion3,
|
|
upstream.HTTPVersion2,
|
|
upstream.HTTPVersion11,
|
|
}
|
|
}
|
|
|
|
// setProxyUpstreamMode sets conf.UpstreamMode according to the flags.
// allServers takes precedence over fastestAddr, and load balancing is used if
// neither is set. fastestTimeout is stored in conf.FastestPingTimeout only when
// the fastest-address mode is chosen.
func setProxyUpstreamMode(
	conf *proxy.Config,
	allServers bool,
	fastestAddr bool,
	fastestTimeout time.Duration,
) {
	switch {
	case allServers:
		conf.UpstreamMode = proxy.UModeParallel
	case fastestAddr:
		conf.UpstreamMode = proxy.UModeFastestAddr
		conf.FastestPingTimeout = fastestTimeout
	default:
		conf.UpstreamMode = proxy.UModeLoadBalance
	}
}
|
|
|
|
// createBootstrap returns a bootstrap resolver based on the configuration of s.
|
|
// boots are the upstream resolvers that should be closed after use. r is the
|
|
// actual bootstrap resolver, which may include the system hosts.
|
|
//
|
|
// TODO(e.burkov): This function currently returns a resolver and a slice of
|
|
// the upstream resolvers, which are essentially the same. boots are returned
|
|
// for being able to close them afterwards, but it introduces an implicit
|
|
// contract that r could only be used before that. Anyway, this code should
|
|
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
|
|
// and be used here.
|
|
func (s *Server) createBootstrap(
|
|
addrs []string,
|
|
opts *upstream.Options,
|
|
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
|
|
if len(addrs) == 0 {
|
|
addrs = defaultBootstrap
|
|
}
|
|
|
|
boots, err = aghnet.ParseBootstraps(addrs, opts)
|
|
if err != nil {
|
|
// Don't wrap the error, since it's informative enough as is.
|
|
return nil, nil, err
|
|
}
|
|
|
|
var parallel upstream.ParallelResolver
|
|
for _, b := range boots {
|
|
parallel = append(parallel, b)
|
|
}
|
|
|
|
if s.etcHosts != nil {
|
|
r = upstream.ConsequentResolver{s.etcHosts, parallel}
|
|
} else {
|
|
r = parallel
|
|
}
|
|
|
|
return r, boots, nil
|
|
}
|
|
|
|
// IsCommentOrEmpty reports whether s is either an empty string or a comment,
// that is, a string whose very first character is "#". Leading whitespace is
// not skipped. It's used to drop the lines that don't describe upstreams from
// upstream configurations.
func IsCommentOrEmpty(s string) (ok bool) {
	if s == "" {
		return true
	}

	return s[:1] == "#"
}
|
|
|
|
// newUpstreamConfig validates upstreams and parses them into an upstream
// configuration. Comment and empty lines are ignored. If no other lines
// remain, it returns nil and no error, since that means the default server
// should be used. It returns errNoDefaultUpstreams if the parsed configuration
// contains no default upstreams.
//
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
// slice already so that this function may be considered useless.
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
	// Comments and empty lines don't need validation.
	lines := stringutil.FilterOut(upstreams, IsCommentOrEmpty)
	if len(lines) == 0 {
		// An empty configuration is valid, since it means that the default
		// server should be used.
		return nil, nil
	}

	// The errors below aren't wrapped, since they're informative enough as
	// they are.
	if err = validateUpstreamConfig(lines); err != nil {
		return nil, err
	}

	opts := &upstream.Options{
		Bootstrap: net.DefaultResolver,
		Timeout:   DefaultTimeout,
	}

	conf, err = proxy.ParseUpstreamsConfig(lines, opts)
	if err != nil {
		return nil, err
	}

	if len(conf.Upstreams) == 0 {
		return nil, errNoDefaultUpstreams
	}

	return conf, nil
}
|
|
|
|
// validateUpstreamConfig validates each upstream from the upstream
|
|
// configuration and returns an error if any upstream is invalid.
|
|
//
|
|
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
|
|
func validateUpstreamConfig(conf []string) (err error) {
|
|
for _, u := range conf {
|
|
var ups []string
|
|
var isSpecific bool
|
|
ups, isSpecific, err = splitUpstreamLine(u)
|
|
if err != nil {
|
|
// Don't wrap the error since it's informative enough as is.
|
|
return err
|
|
}
|
|
|
|
for _, addr := range ups {
|
|
_, err = validateUpstream(addr, isSpecific)
|
|
if err != nil {
|
|
return fmt.Errorf("validating upstream %q: %w", addr, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ValidateUpstreams validates each upstream and returns an error if any
|
|
// upstream is invalid or if there are no default upstreams specified.
|
|
//
|
|
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
|
|
func ValidateUpstreams(upstreams []string) (err error) {
|
|
_, err = newUpstreamConfig(upstreams)
|
|
|
|
return err
|
|
}
|
|
|
|
// ValidateUpstreamsPrivate performs the same checks as ValidateUpstreams.
// Additionally, every domain of the domain-specific upstreams is converted into
// a subnet with extractARPASubnet, and the subnet's address must be contained
// in privateNets. All such violations are reported together, in the sorted
// order of the domains. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
	conf, err := newUpstreamConfig(upstreams)
	switch {
	case err != nil:
		return fmt.Errorf("creating config: %w", err)
	case conf == nil:
		// No upstreams are configured, so there is nothing to check.
		return nil
	}

	domains := maps.Keys(conf.DomainReservedUpstreams)
	slices.Sort(domains)

	var errs []error
	for _, domain := range domains {
		subnet, arpaErr := extractARPASubnet(domain)
		switch {
		case arpaErr != nil:
			errs = append(errs, arpaErr)
		case !privateNets.Contains(subnet.Addr().AsSlice()):
			errs = append(
				errs,
				fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
			)
		}
	}

	return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
}
|
|
|
|
// protocols lists the URL schemes accepted in upstream addresses, in
// alphabetical order.
var protocols = []string{
	"h3",
	"https",
	"quic",
	"sdns",
	"tcp",
	"tls",
	"udp",
}
|
|
|
|
// validateUpstream returns an error if u is not a valid upstream address.
// isSpecific must be true if u comes from a domain-specific upstream line.
//
// useDefault is true if u is the special "#" address of a domain-specific line,
// meaning that the default upstream server should be used for those domains.
// Such an address is validated separately and isn't checked here.
//
// Otherwise u is accepted if it has a URL scheme from protocols, or if it's a
// plain IP address, with or without a port.
func validateUpstream(u string, isSpecific bool) (useDefault bool, err error) {
	// The special server address '#' means that default server must be used.
	if useDefault = u == "#" && isSpecific; useDefault {
		return useDefault, nil
	}

	// Check if the upstream has a valid protocol prefix.
	//
	// TODO(e.burkov): Validate the domain name.
	if proto, _, ok := strings.Cut(u, "://"); ok {
		if !slices.Contains(protocols, proto) {
			return false, fmt.Errorf("bad protocol %q", proto)
		}

		return false, nil
	}

	if _, err = netip.ParseAddr(u); err == nil {
		return false, nil
	}

	// An address with a port is the last accepted form, so its parsing error,
	// if any, is the one reported.
	_, err = netip.ParseAddrPort(u)

	return false, err
}
|
|
|
|
// splitUpstreamLine splits an upstream configuration line into its upstream
// addresses.
//
// A line that doesn't start with "[/" is returned as the only upstream with
// isSpecific set to false.
//
// A line that starts with "[/" is domain-specific and has the form
// "[/domain1/domain2/.../]upstream1 upstream2 ...". Each non-empty domain is
// validated, optionally with a leading "*." wildcard, and the whitespace-
// separated upstreams after the "/]" separator are returned. isSpecific is true
// only if at least one non-empty domain is present. Errors for such lines are
// annotated with the line itself.
func splitUpstreamLine(upstreamStr string) (upstreams []string, isSpecific bool, err error) {
	if !strings.HasPrefix(upstreamStr, "[/") {
		return []string{upstreamStr}, false, nil
	}

	defer func() { err = errors.Annotate(err, "splitting upstream line %q: %w", upstreamStr) }()

	// Cut off the leading "[/" and separate the domains from the upstreams.
	doms, ups, found := strings.Cut(upstreamStr[len("[/"):], "/]")
	if !found {
		return nil, false, errMissingSeparator
	}

	if strings.Contains(ups, "/]") {
		return nil, false, errDupSeparator
	}

	for i, host := range strings.Split(doms, "/") {
		// Skip empty parts, for example the ones between consecutive slashes.
		if host == "" {
			continue
		}

		err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
		if err != nil {
			return nil, false, fmt.Errorf("domain at index %d: %w", i, err)
		}

		isSpecific = true
	}

	return strings.Fields(ups), isSpecific, nil
}
|