diff --git a/dnsforward/config.go b/dnsforward/config.go new file mode 100644 index 00000000..5679a834 --- /dev/null +++ b/dnsforward/config.go @@ -0,0 +1,270 @@ +package dnsforward + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net" + "net/http" + "sort" + + "github.com/AdguardTeam/golibs/log" + "github.com/joomcode/errorx" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" +) + +// FilteringConfig represents the DNS filtering configuration of AdGuard Home +// The zero FilteringConfig is empty and ready for use. +type FilteringConfig struct { + // Callbacks for other modules + // -- + + // Filtering callback function + FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` + + // This callback function returns the list of upstream servers for a client specified by IP address + GetUpstreamsByClient func(clientAddr string) []upstream.Upstream `yaml:"-"` + + // Protection configuration + // -- + + ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features + BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests + BlockingIPv4 string `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request + BlockingIPv6 string `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request + BlockingIPAddrv4 net.IP `yaml:"-"` + BlockingIPAddrv6 net.IP `yaml:"-"` + BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) + + // IP (or domain name) which is used to respond to DNS requests blocked by parental control or safe-browsing + ParentalBlockHost string `yaml:"parental_block_host"` + SafeBrowsingBlockHost string `yaml:"safebrowsing_block_host"` + + // Anti-DNS amplification + // -- + + Ratelimit uint32 `yaml:"ratelimit"` // max number of requests per second from a given IP (0 to disable) + RatelimitWhitelist []string `yaml:"ratelimit_whitelist"` // a list of whitelisted client IP addresses + RefuseAny bool `yaml:"refuse_any"` // if true, refuse ANY requests + + // Upstream DNS servers configuration + // -- + + UpstreamDNS []string `yaml:"upstream_dns"` + BootstrapDNS []string `yaml:"bootstrap_dns"` // a list of bootstrap DNS for DoH and DoT (plain DNS only) + AllServers bool `yaml:"all_servers"` // if true, parallel queries to all configured upstream servers are enabled + FastestAddr bool `yaml:"fastest_addr"` // use Fastest Address algorithm + + // ECS settings + // -- + + EnableEDNSClientSubnet bool `yaml:"edns_client_subnet"` // Enable EDNS Client Subnet option + EnableDNSSEC bool `yaml:"enable_dnssec"` // Set DNSSEC flag in outcoming DNS request + + // Access settings + // -- + + AllowedClients []string `yaml:"allowed_clients"` // IP addresses of whitelist clients + DisallowedClients []string `yaml:"disallowed_clients"` // IP addresses of clients that should be blocked + BlockedHosts []string `yaml:"blocked_hosts"` // hosts that should be blocked + + // DNS cache settings + // -- + + CacheSize uint32 `yaml:"cache_size"` // DNS cache size (in bytes) + CacheMinTTL uint32 `yaml:"cache_ttl_min"` // override TTL value (minimum) received from upstream server + CacheMaxTTL uint32 `yaml:"cache_ttl_max"` // override TTL value (maximum) received from upstream server + + // Other settings + // -- + + BogusNXDomain []string `yaml:"bogus_nxdomain"` // transform responses with these IP addresses to NXDOMAIN + AAAADisabled bool `yaml:"aaaa_disabled"` // Respond with an empty answer to all AAAA requests +} + +// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS +type TLSConfig struct { + TLSListenAddr *net.TCPAddr `yaml:"-" json:"-"` + StrictSNICheck bool `yaml:"strict_sni_check" json:"-"` // Reject connection if the client uses server name (in SNI) that doesn't match the certificate + CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"` // PEM-encoded certificates chain + PrivateKey string `yaml:"private_key" json:"private_key"` // PEM-encoded private key + + CertificatePath string `yaml:"certificate_path" json:"certificate_path"` // certificate file name + PrivateKeyPath string `yaml:"private_key_path" json:"private_key_path"` // private key file name + + CertificateChainData []byte `yaml:"-" json:"-"` + PrivateKeyData []byte `yaml:"-" json:"-"` + + cert tls.Certificate // nolint(structcheck) - linter thinks that this field is unused, while TLSConfig is directly included into ServerConfig + dnsNames []string // nolint(structcheck) // DNS names from certificate (SAN) or CN value from Subject +} + +// ServerConfig represents server configuration. +// The zero ServerConfig is empty and ready for use. +type ServerConfig struct { + UDPListenAddr *net.UDPAddr // UDP listen address + TCPListenAddr *net.TCPAddr // TCP listen address + Upstreams []upstream.Upstream // Configured upstreams + DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams + OnDNSRequest func(d *proxy.DNSContext) + + FilteringConfig + TLSConfig + TLSAllowUnencryptedDOH bool + + TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2 + TLSCiphers []uint16 // list of TLS ciphers to use + + // Called when the configuration is changed by HTTP request + ConfigModified func() + + // Register an HTTP handler + HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) +} + +// if any of ServerConfig values are zero, then default values from below are used +var defaultValues = ServerConfig{ + UDPListenAddr: &net.UDPAddr{Port: 53}, + TCPListenAddr: &net.TCPAddr{Port: 53}, + FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, +} + +// createProxyConfig creates and validates configuration for the main proxy +func (s *Server) createProxyConfig() (proxy.Config, error) { + proxyConfig := proxy.Config{ + UDPListenAddr: s.conf.UDPListenAddr, + TCPListenAddr: s.conf.TCPListenAddr, + Ratelimit: int(s.conf.Ratelimit), + RatelimitWhitelist: s.conf.RatelimitWhitelist, + RefuseAny: s.conf.RefuseAny, + CacheEnabled: true, + CacheSizeBytes: int(s.conf.CacheSize), + CacheMinTTL: s.conf.CacheMinTTL, + CacheMaxTTL: s.conf.CacheMaxTTL, + Upstreams: s.conf.Upstreams, + DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams, + BeforeRequestHandler: s.beforeRequestHandler, + RequestHandler: s.handleDNSRequest, + AllServers: s.conf.AllServers, + EnableEDNSClientSubnet: s.conf.EnableEDNSClientSubnet, + FindFastestAddr: s.conf.FastestAddr, + } + + if len(s.conf.BogusNXDomain) > 0 { + for _, s := range s.conf.BogusNXDomain { + ip := net.ParseIP(s) + if ip == nil { + log.Error("Invalid bogus IP: %s", s) + } else { + proxyConfig.BogusNXDomain = append(proxyConfig.BogusNXDomain, ip) + } + } + } + + // TLS settings + err := s.prepareTLS(&proxyConfig) + if err != nil { + return proxyConfig, err + } + + // Validate proxy config + if len(proxyConfig.Upstreams) == 0 { + return proxyConfig, errors.New("no upstream servers configured") + } + + return proxyConfig, nil +} + +// initDefaultSettings initializes default settings if nothing +// is configured +func (s *Server) initDefaultSettings() { + if len(s.conf.UpstreamDNS) == 0 { + s.conf.UpstreamDNS = defaultDNS + } + if len(s.conf.BootstrapDNS) == 0 { + s.conf.BootstrapDNS = defaultBootstrap + } + if len(s.conf.ParentalBlockHost) == 0 { + s.conf.ParentalBlockHost = parentalBlockHost + } + if len(s.conf.SafeBrowsingBlockHost) == 0 { + s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost + } + if s.conf.UDPListenAddr == nil { + s.conf.UDPListenAddr = defaultValues.UDPListenAddr + } + if s.conf.TCPListenAddr == nil { + s.conf.TCPListenAddr = defaultValues.TCPListenAddr + } +} + +// prepareUpstreamSettings - prepares upstream DNS server settings +func (s *Server) prepareUpstreamSettings() error { + upstreamConfig, err := proxy.ParseUpstreamsConfig(s.conf.UpstreamDNS, s.conf.BootstrapDNS, DefaultTimeout) + if err != nil { + return fmt.Errorf("DNS: proxy.ParseUpstreamsConfig: %s", err) + } + s.conf.Upstreams = upstreamConfig.Upstreams + s.conf.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams + return nil +} + +// prepareIntlProxy - initializes DNS proxy that we use for internal DNS queries +func (s *Server) prepareIntlProxy() { + intlProxyConfig := proxy.Config{ + CacheEnabled: true, + CacheSizeBytes: 4096, + Upstreams: s.conf.Upstreams, + DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams, + } + s.internalProxy = &proxy.Proxy{Config: intlProxyConfig} +} + +// prepareTLS - prepares TLS configuration for the DNS proxy +func (s *Server) prepareTLS(proxyConfig *proxy.Config) error { + if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 { + proxyConfig.TLSListenAddr = s.conf.TLSListenAddr + var err error + s.conf.cert, err = tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData) + if err != nil { + return errorx.Decorate(err, "Failed to parse TLS keypair") + } + + if s.conf.StrictSNICheck { + x, err := x509.ParseCertificate(s.conf.cert.Certificate[0]) + if err != nil { + return errorx.Decorate(err, "x509.ParseCertificate(): %s", err) + } + if len(x.DNSNames) != 0 { + s.conf.dnsNames = x.DNSNames + log.Debug("DNS: using DNS names from certificate's SAN: %v", x.DNSNames) + sort.Strings(s.conf.dnsNames) + } else { + s.conf.dnsNames = append(s.conf.dnsNames, x.Subject.CommonName) + log.Debug("DNS: using DNS name from certificate's CN: %s", x.Subject.CommonName) + } + } + + proxyConfig.TLSConfig = &tls.Config{ + GetCertificate: s.onGetCertificate, + MinVersion: tls.VersionTLS12, + } + } + upstream.RootCAs = s.conf.TLSv12Roots + upstream.CipherSuites = s.conf.TLSCiphers + return nil +} + +// Called by 'tls' package when Client Hello is received +// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake. +func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { + if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) { + log.Info("DNS: TLS: unknown SNI in Client Hello: %s", ch.ServerName) + return nil, fmt.Errorf("invalid SNI") + } + return &s.conf.cert, nil +} diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 4f630afb..e8d8a6b0 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -1,14 +1,10 @@ package dnsforward import ( - "crypto/tls" - "crypto/x509" "fmt" "net" "net/http" "runtime" - "sort" - "strings" "sync" "time" @@ -16,9 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/dnsproxy/proxy" - "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/utils" "github.com/joomcode/errorx" "github.com/miekg/dns" ) @@ -90,12 +84,6 @@ func (s *Server) Close() { s.Unlock() } -func stringArrayDup(a []string) []string { - a2 := make([]string, len(a)) - copy(a2, a) - return a2 -} - // WriteDiskConfig - write configuration func (s *Server) WriteDiskConfig(c *FilteringConfig) { s.RLock() @@ -110,100 +98,6 @@ func (s *Server) WriteDiskConfig(c *FilteringConfig) { s.RUnlock() } -// FilteringConfig represents the DNS filtering configuration of AdGuard Home -// The zero FilteringConfig is empty and ready for use. -type FilteringConfig struct { - // Filtering callback function - FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` - - // This callback function returns the list of upstream servers for a client specified by IP address - GetUpstreamsByClient func(clientAddr string) []upstream.Upstream `yaml:"-"` - - ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features - - BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests - BlockingIPv4 string `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request - BlockingIPv6 string `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request - BlockingIPAddrv4 net.IP `yaml:"-"` - BlockingIPAddrv6 net.IP `yaml:"-"` - - BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) - Ratelimit uint32 `yaml:"ratelimit"` // max number of requests per second from a given IP (0 to disable) - RatelimitWhitelist []string `yaml:"ratelimit_whitelist"` // a list of whitelisted client IP addresses - RefuseAny bool `yaml:"refuse_any"` // if true, refuse ANY requests - BootstrapDNS []string `yaml:"bootstrap_dns"` // a list of bootstrap DNS for DoH and DoT (plain DNS only) - AllServers bool `yaml:"all_servers"` // if true, parallel queries to all configured upstream servers are enabled - - EnableEDNSClientSubnet bool `yaml:"edns_client_subnet"` // Enable EDNS Client Subnet option - - EnableDNSSEC bool `yaml:"enable_dnssec"` // Set DNSSEC flag in outcoming DNS request - - // Respond with an empty answer to all AAAA requests - AAAADisabled bool `yaml:"aaaa_disabled"` - - FastestAddrAlgo bool `yaml:"fastest_addr"` // use Fastest Address algorithm - - AllowedClients []string `yaml:"allowed_clients"` // IP addresses of whitelist clients - DisallowedClients []string `yaml:"disallowed_clients"` // IP addresses of clients that should be blocked - BlockedHosts []string `yaml:"blocked_hosts"` // hosts that should be blocked - - // IP (or domain name) which is used to respond to DNS requests blocked by parental control or safe-browsing - ParentalBlockHost string `yaml:"parental_block_host"` - SafeBrowsingBlockHost string `yaml:"safebrowsing_block_host"` - - CacheSize uint32 `yaml:"cache_size"` // DNS cache size (in bytes) - CacheMinTTL uint32 `yaml:"cache_ttl_min"` // override TTL value (minimum) received from upstream server - CacheMaxTTL uint32 `yaml:"cache_ttl_max"` // override TTL value (maximum) received from upstream server - UpstreamDNS []string `yaml:"upstream_dns"` -} - -// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS -type TLSConfig struct { - TLSListenAddr *net.TCPAddr `yaml:"-" json:"-"` - StrictSNICheck bool `yaml:"strict_sni_check" json:"-"` // Reject connection if the client uses server name (in SNI) that doesn't match the certificate - CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"` // PEM-encoded certificates chain - PrivateKey string `yaml:"private_key" json:"private_key"` // PEM-encoded private key - - CertificatePath string `yaml:"certificate_path" json:"certificate_path"` // certificate file name - PrivateKeyPath string `yaml:"private_key_path" json:"private_key_path"` // private key file name - - CertificateChainData []byte `yaml:"-" json:"-"` - PrivateKeyData []byte `yaml:"-" json:"-"` - - cert tls.Certificate // nolint(structcheck) - linter thinks that this field is unused, while TLSConfig is directly included into ServerConfig - dnsNames []string // nolint(structcheck) // DNS names from certificate (SAN) or CN value from Subject -} - -// ServerConfig represents server configuration. -// The zero ServerConfig is empty and ready for use. -type ServerConfig struct { - UDPListenAddr *net.UDPAddr // UDP listen address - TCPListenAddr *net.TCPAddr // TCP listen address - Upstreams []upstream.Upstream // Configured upstreams - DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams - OnDNSRequest func(d *proxy.DNSContext) - - FilteringConfig - TLSConfig - TLSAllowUnencryptedDOH bool - - TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2 - TLSCiphers []uint16 // list of TLS ciphers to use - - // Called when the configuration is changed by HTTP request - ConfigModified func() - - // Register an HTTP handler - HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) -} - -// if any of ServerConfig values are zero, then default values from below are used -var defaultValues = ServerConfig{ - UDPListenAddr: &net.UDPAddr{Port: 53}, - TCPListenAddr: &net.TCPAddr{Port: 53}, - FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, -} - // Resolve - get IP addresses by host name from an upstream server. // No request/response filtering is performed. // Query log and Stats are not updated. @@ -251,8 +145,9 @@ func (s *Server) startInternal() error { } // Prepare the object -// nolint(gocyclo) func (s *Server) Prepare(config *ServerConfig) error { + // 1. Initialize the server configuration + // -- if config != nil { s.conf = *config if s.conf.BlockingMode == "custom_ip" { @@ -264,157 +159,50 @@ func (s *Server) Prepare(config *ServerConfig) error { } } - if len(s.conf.UpstreamDNS) == 0 { - s.conf.UpstreamDNS = defaultDNS - } - if len(s.conf.BootstrapDNS) == 0 { - s.conf.BootstrapDNS = defaultBootstrap - } + // 2. Set default values in the case if nothing is configured + // -- + s.initDefaultSettings() - upstreamConfig, err := proxy.ParseUpstreamsConfig(s.conf.UpstreamDNS, s.conf.BootstrapDNS, DefaultTimeout) + // 3. Prepare DNS servers settings + // -- + err := s.prepareUpstreamSettings() if err != nil { - return fmt.Errorf("DNS: proxy.ParseUpstreamsConfig: %s", err) - } - s.conf.Upstreams = upstreamConfig.Upstreams - s.conf.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams - - if len(s.conf.ParentalBlockHost) == 0 { - s.conf.ParentalBlockHost = parentalBlockHost - } - if len(s.conf.SafeBrowsingBlockHost) == 0 { - s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost - } - if s.conf.UDPListenAddr == nil { - s.conf.UDPListenAddr = defaultValues.UDPListenAddr - } - if s.conf.TCPListenAddr == nil { - s.conf.TCPListenAddr = defaultValues.TCPListenAddr + return err } - proxyConfig := proxy.Config{ - UDPListenAddr: s.conf.UDPListenAddr, - TCPListenAddr: s.conf.TCPListenAddr, - Ratelimit: int(s.conf.Ratelimit), - RatelimitWhitelist: s.conf.RatelimitWhitelist, - RefuseAny: s.conf.RefuseAny, - CacheEnabled: true, - CacheSizeBytes: int(s.conf.CacheSize), - CacheMinTTL: s.conf.CacheMinTTL, - CacheMaxTTL: s.conf.CacheMaxTTL, - Upstreams: s.conf.Upstreams, - DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams, - BeforeRequestHandler: s.beforeRequestHandler, - RequestHandler: s.handleDNSRequest, - AllServers: s.conf.AllServers, - EnableEDNSClientSubnet: s.conf.EnableEDNSClientSubnet, - FindFastestAddr: s.conf.FastestAddrAlgo, + // 3. Create DNS proxy configuration + // -- + var proxyConfig proxy.Config + proxyConfig, err = s.createProxyConfig() + if err != nil { + return err } - intlProxyConfig := proxy.Config{ - CacheEnabled: true, - CacheSizeBytes: 4096, - Upstreams: s.conf.Upstreams, - DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams, - } - s.internalProxy = &proxy.Proxy{Config: intlProxyConfig} + // 4. Prepare a DNS proxy instance that we use for internal DNS queries + // -- + s.prepareIntlProxy() + // 5. Initialize DNS access module + // -- s.access = &accessCtx{} err = s.access.Init(s.conf.AllowedClients, s.conf.DisallowedClients, s.conf.BlockedHosts) if err != nil { return err } - if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 { - proxyConfig.TLSListenAddr = s.conf.TLSListenAddr - s.conf.cert, err = tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData) - if err != nil { - return errorx.Decorate(err, "Failed to parse TLS keypair") - } - - if s.conf.StrictSNICheck { - x, err := x509.ParseCertificate(s.conf.cert.Certificate[0]) - if err != nil { - return errorx.Decorate(err, "x509.ParseCertificate(): %s", err) - } - if len(x.DNSNames) != 0 { - s.conf.dnsNames = x.DNSNames - log.Debug("DNS: using DNS names from certificate's SAN: %v", x.DNSNames) - sort.Strings(s.conf.dnsNames) - } else { - s.conf.dnsNames = append(s.conf.dnsNames, x.Subject.CommonName) - log.Debug("DNS: using DNS name from certificate's CN: %s", x.Subject.CommonName) - } - } - - proxyConfig.TLSConfig = &tls.Config{ - GetCertificate: s.onGetCertificate, - MinVersion: tls.VersionTLS12, - } - } - upstream.RootCAs = s.conf.TLSv12Roots - upstream.CipherSuites = s.conf.TLSCiphers - - if len(proxyConfig.Upstreams) == 0 { - log.Fatal("len(proxyConfig.Upstreams) == 0") - } - + // 6. Register web handlers if necessary + // -- if !webRegistered && s.conf.HTTPRegister != nil { webRegistered = true s.registerHandlers() } - // Initialize and start the DNS proxy + // 7. Create the main DNS proxy instance + // -- s.dnsProxy = &proxy.Proxy{Config: proxyConfig} return nil } -// Find value in a sorted array -func findSorted(ar []string, val string) int { - i := sort.SearchStrings(ar, val) - if i == len(ar) || ar[i] != val { - return -1 - } - return i -} - -func isWildcard(host string) bool { - return len(host) >= 2 && - host[0] == '*' && host[1] == '.' -} - -// Return TRUE if host name matches a wildcard pattern -func matchDomainWildcard(host, wildcard string) bool { - return isWildcard(wildcard) && - strings.HasSuffix(host, wildcard[1:]) -} - -// Return TRUE if client's SNI value matches DNS names from certificate -func matchDNSName(dnsNames []string, sni string) bool { - if utils.IsValidHostname(sni) != nil { - return false - } - if findSorted(dnsNames, sni) != -1 { - return true - } - - for _, dn := range dnsNames { - if matchDomainWildcard(sni, dn) { - return true - } - } - return false -} - -// Called by 'tls' package when Client Hello is received -// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake. -func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { - if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) { - log.Info("DNS: TLS: unknown SNI in Client Hello: %s", ch.ServerName) - return nil, fmt.Errorf("invalid SNI") - } - return &s.conf.cert, nil -} - // Stop stops the DNS server func (s *Server) Stop() error { s.Lock() @@ -479,695 +267,3 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { p.ServeHTTP(w, r) } } - -// Get IP address from net.Addr object -// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone: -// https://github.com/AdguardTeam/AdGuardHome/issues/1261 -func ipFromAddr(a net.Addr) string { - switch addr := a.(type) { - case *net.UDPAddr: - return addr.IP.String() - case *net.TCPAddr: - return addr.IP.String() - } - return "" -} - -func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool, error) { - ip := ipFromAddr(d.Addr) - if s.access.IsBlockedIP(ip) { - log.Tracef("Client IP %s is blocked by settings", ip) - return false, nil - } - - if len(d.Req.Question) == 1 { - host := strings.TrimSuffix(d.Req.Question[0].Name, ".") - if s.access.IsBlockedDomain(host) { - log.Tracef("Domain %s is blocked by settings", host) - return false, nil - } - } - - return true, nil -} - -// To transfer information between modules -type dnsContext struct { - srv *Server - proxyCtx *proxy.DNSContext - setts *dnsfilter.RequestFilteringSettings // filtering settings for this client - startTime time.Time - result *dnsfilter.Result - origResp *dns.Msg // response received from upstream servers. Set when response is modified by filtering - origQuestion dns.Question // question received from client. Set when Rewrites are used. - err error // error returned from the module - protectionEnabled bool // filtering is enabled, dnsfilter object is ready - responseFromUpstream bool // response is received from upstream servers - origReqDNSSEC bool // DNSSEC flag in the original request from user -} - -const ( - resultDone = iota // module has completed its job, continue - resultFinish // module has completed its job, exit normally - resultError // an error occurred, exit with an error -) - -// Perform initial checks; process WHOIS & rDNS -func processInitial(ctx *dnsContext) int { - s := ctx.srv - d := ctx.proxyCtx - if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA { - _ = proxy.CheckDisabledAAAARequest(d, true) - return resultFinish - } - - if s.conf.OnDNSRequest != nil { - s.conf.OnDNSRequest(d) - } - - // disable Mozilla DoH - if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) && - d.Req.Question[0].Name == "use-application-dns.net." { - d.Res = s.genNXDomain(d.Req) - return resultFinish - } - - return resultDone -} - -// Apply filtering logic -func processFilteringBeforeRequest(ctx *dnsContext) int { - s := ctx.srv - d := ctx.proxyCtx - - s.RLock() - // Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use. - // This could happen after proxy server has been stopped, but its workers are not yet exited. - // - // A better approach is for proxy.Stop() to wait until all its workers exit, - // but this would require the Upstream interface to have Close() function - // (to prevent from hanging while waiting for unresponsive DNS server to respond). - - var err error - ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil - if ctx.protectionEnabled { - ctx.setts = s.getClientRequestFilteringSettings(d) - ctx.result, err = s.filterDNSRequest(ctx) - } - s.RUnlock() - - if err != nil { - ctx.err = err - return resultError - } - return resultDone -} - -// Pass request to upstream servers; process the response -func processUpstream(ctx *dnsContext) int { - s := ctx.srv - d := ctx.proxyCtx - if d.Res != nil { - return resultDone // response is already set - nothing to do - } - - if d.Addr != nil && s.conf.GetUpstreamsByClient != nil { - clientIP := ipFromAddr(d.Addr) - upstreams := s.conf.GetUpstreamsByClient(clientIP) - if len(upstreams) > 0 { - log.Debug("Using custom upstreams for %s", clientIP) - d.Upstreams = upstreams - } - } - - if s.conf.EnableDNSSEC { - opt := d.Req.IsEdns0() - if opt == nil { - log.Debug("DNS: Adding OPT record with DNSSEC flag") - d.Req.SetEdns0(4096, true) - } else if !opt.Do() { - opt.SetDo(true) - } else { - ctx.origReqDNSSEC = true - } - } - - // request was not filtered so let it be processed further - err := s.dnsProxy.Resolve(d) - if err != nil { - ctx.err = err - return resultError - } - - ctx.responseFromUpstream = true - return resultDone -} - -// Process DNSSEC after response from upstream server -func processDNSSECAfterResponse(ctx *dnsContext) int { - d := ctx.proxyCtx - - if !ctx.responseFromUpstream || // don't process response if it's not from upstream servers - !ctx.srv.conf.EnableDNSSEC { - return resultDone - } - - optResp := d.Res.IsEdns0() - if !ctx.origReqDNSSEC && optResp != nil && optResp.Do() { - return resultDone - } - - // Remove RRSIG records from response - // because there is no DO flag in the original request from client, - // but we have EnableDNSSEC set, so we have set DO flag ourselves, - // and now we have to clean up the DNS records our client didn't ask for. - - answers := []dns.RR{} - for _, a := range d.Res.Answer { - switch a.(type) { - case *dns.RRSIG: - log.Debug("Removing RRSIG record from response: %v", a) - default: - answers = append(answers, a) - } - } - d.Res.Answer = answers - - answers = []dns.RR{} - for _, a := range d.Res.Ns { - switch a.(type) { - case *dns.RRSIG: - log.Debug("Removing RRSIG record from response: %v", a) - default: - answers = append(answers, a) - } - } - d.Res.Ns = answers - - return resultDone -} - -// Apply filtering logic after we have received response from upstream servers -func processFilteringAfterResponse(ctx *dnsContext) int { - s := ctx.srv - d := ctx.proxyCtx - res := ctx.result - var err error - - switch res.Reason { - case dnsfilter.ReasonRewrite: - if len(res.CanonName) == 0 { - break - } - d.Req.Question[0] = ctx.origQuestion - d.Res.Question[0] = ctx.origQuestion - - if len(d.Res.Answer) != 0 { - answer := []dns.RR{} - answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName)) - answer = append(answer, d.Res.Answer...) // host -> IP - d.Res.Answer = answer - } - - case dnsfilter.NotFilteredWhiteList: - // nothing - - default: - if !ctx.protectionEnabled || // filters are disabled: there's nothing to check for - !ctx.responseFromUpstream { // only check response if it's from an upstream server - break - } - origResp2 := d.Res - ctx.result, err = s.filterDNSResponse(ctx) - if err != nil { - ctx.err = err - return resultError - } - if ctx.result != nil { - ctx.origResp = origResp2 // matched by response - } else { - ctx.result = &dnsfilter.Result{} - } - } - - return resultDone -} - -// Write Stats data and logs -func processQueryLogsAndStats(ctx *dnsContext) int { - elapsed := time.Since(ctx.startTime) - s := ctx.srv - d := ctx.proxyCtx - - shouldLog := true - msg := d.Req - - // don't log ANY request if refuseAny is enabled - if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.conf.RefuseAny { - shouldLog = false - } - - s.RLock() - // Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use. - // This can happen after proxy server has been stopped, but its workers haven't yet exited. - if shouldLog && s.queryLog != nil { - p := querylog.AddParams{ - Question: msg, - Answer: d.Res, - OrigAnswer: ctx.origResp, - Result: ctx.result, - Elapsed: elapsed, - ClientIP: getIP(d.Addr), - } - if d.Upstream != nil { - p.Upstream = d.Upstream.Address() - } - s.queryLog.Add(p) - } - - s.updateStats(d, elapsed, *ctx.result) - s.RUnlock() - - return resultDone -} - -// handleDNSRequest filters the incoming DNS requests and writes them to the query log -// nolint (gocyclo) -func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { - ctx := &dnsContext{srv: s, proxyCtx: d} - ctx.result = &dnsfilter.Result{} - ctx.startTime = time.Now() - - type modProcessFunc func(ctx *dnsContext) int - mods := []modProcessFunc{ - processInitial, - processFilteringBeforeRequest, - processUpstream, - processDNSSECAfterResponse, - processFilteringAfterResponse, - processQueryLogsAndStats, - } - for _, process := range mods { - r := process(ctx) - switch r { - case resultDone: - // continue: call the next filter - - case resultFinish: - return nil - - case resultError: - return ctx.err - } - } - - if d.Res != nil { - d.Res.Compress = true // some devices require DNS message compression - } - return nil -} - -// Get IP address from net.Addr -func getIP(addr net.Addr) net.IP { - switch addr := addr.(type) { - case *net.UDPAddr: - return addr.IP - case *net.TCPAddr: - return addr.IP - } - return nil -} - -func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dnsfilter.Result) { - if s.stats == nil { - return - } - - e := stats.Entry{} - e.Domain = strings.ToLower(d.Req.Question[0].Name) - e.Domain = e.Domain[:len(e.Domain)-1] // remove last "." - switch addr := d.Addr.(type) { - case *net.UDPAddr: - e.Client = addr.IP - case *net.TCPAddr: - e.Client = addr.IP - } - e.Time = uint32(elapsed / 1000) - e.Result = stats.RNotFiltered - - switch res.Reason { - - case dnsfilter.FilteredSafeBrowsing: - e.Result = stats.RSafeBrowsing - - case dnsfilter.FilteredParental: - e.Result = stats.RParental - - case dnsfilter.FilteredSafeSearch: - e.Result = stats.RSafeSearch - - case dnsfilter.FilteredBlackList: - fallthrough - case dnsfilter.FilteredInvalid: - fallthrough - case dnsfilter.FilteredBlockedService: - e.Result = stats.RFiltered - } - - s.stats.Update(e) -} - -// getClientRequestFilteringSettings lookups client filtering settings -// using the client's IP address from the DNSContext -func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilter.RequestFilteringSettings { - setts := s.dnsFilter.GetConfig() - setts.FilteringEnabled = true - if s.conf.FilterHandler != nil { - clientAddr := ipFromAddr(d.Addr) - s.conf.FilterHandler(clientAddr, &setts) - } - return &setts -} - -// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered -func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { - d := ctx.proxyCtx - req := d.Req - host := strings.TrimSuffix(req.Question[0].Name, ".") - res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, ctx.setts) - if err != nil { - // Return immediately if there's an error - return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host) - - } else if res.IsFiltered { - // log.Tracef("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) - d.Res = s.genDNSFilterMessage(d, &res) - - } else if (res.Reason == dnsfilter.ReasonRewrite || res.Reason == dnsfilter.RewriteEtcHosts) && - len(res.IPList) != 0 { - resp := s.makeResponse(req) - - name := host - if len(res.CanonName) != 0 { - resp.Answer = append(resp.Answer, s.genCNAMEAnswer(req, res.CanonName)) - name = res.CanonName - } - - for _, ip := range res.IPList { - ip4 := ip.To4() - if req.Question[0].Qtype == dns.TypeA && ip4 != nil { - a := s.genAAnswer(req, ip4) - a.Hdr.Name = dns.Fqdn(name) - resp.Answer = append(resp.Answer, a) - } else if req.Question[0].Qtype == dns.TypeAAAA && ip4 == nil { - a := s.genAAAAAnswer(req, ip) - a.Hdr.Name = dns.Fqdn(name) - resp.Answer = append(resp.Answer, a) - } - } - - d.Res = resp - - } else if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 { - ctx.origQuestion = d.Req.Question[0] - // resolve canonical name, not the original host name - d.Req.Question[0].Name = dns.Fqdn(res.CanonName) - - } else if res.Reason == dnsfilter.RewriteEtcHosts && len(res.ReverseHost) != 0 { - - resp := s.makeResponse(req) - ptr := &dns.PTR{} - ptr.Hdr = dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypePTR, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, - } - ptr.Ptr = res.ReverseHost - resp.Answer = append(resp.Answer, ptr) - d.Res = resp - } - - return &res, err -} - -// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address. -// If this is a match, we set a new response in d.Res and return. -func (s *Server) filterDNSResponse(ctx *dnsContext) (*dnsfilter.Result, error) { - d := ctx.proxyCtx - for _, a := range d.Res.Answer { - host := "" - - switch v := a.(type) { - case *dns.CNAME: - log.Debug("DNSFwd: Checking CNAME %s for %s", v.Target, v.Hdr.Name) - host = strings.TrimSuffix(v.Target, ".") - - case *dns.A: - host = v.A.String() - log.Debug("DNSFwd: Checking record A (%s) for %s", host, v.Hdr.Name) - - case *dns.AAAA: - host = v.AAAA.String() - log.Debug("DNSFwd: Checking record AAAA (%s) for %s", host, v.Hdr.Name) - - default: - continue - } - - s.RLock() - // Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use. - // This could happen after proxy server has been stopped, but its workers are not yet exited. - if !s.conf.ProtectionEnabled || s.dnsFilter == nil { - s.RUnlock() - continue - } - res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, ctx.setts) - s.RUnlock() - - if err != nil { - return nil, err - - } else if res.IsFiltered { - d.Res = s.genDNSFilterMessage(d, &res) - log.Debug("DNSFwd: Matched %s by response: %s", d.Req.Question[0].Name, host) - return &res, nil - } - } - - return nil, nil -} - -// Create a DNS response by DNS request and set necessary flags -func (s *Server) makeResponse(req *dns.Msg) *dns.Msg { - resp := dns.Msg{} - resp.SetReply(req) - resp.RecursionAvailable = true - resp.Compress = true - return &resp -} - -// genDNSFilterMessage generates a DNS message corresponding to the filtering result -func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Result) *dns.Msg { - m := d.Req - - if m.Question[0].Qtype != dns.TypeA && m.Question[0].Qtype != dns.TypeAAAA { - return s.genNXDomain(m) - } - - switch result.Reason { - case dnsfilter.FilteredSafeBrowsing: - return s.genBlockedHost(m, s.conf.SafeBrowsingBlockHost, d) - case dnsfilter.FilteredParental: - return s.genBlockedHost(m, s.conf.ParentalBlockHost, d) - default: - // If the query was filtered by "Safe search", dnsfilter also must return - // the IP address that must be used in response. - // In this case regardless of the filtering method, we should return it - if result.Reason == dnsfilter.FilteredSafeSearch && result.IP != nil { - return s.genResponseWithIP(m, result.IP) - } - - if s.conf.BlockingMode == "null_ip" { - // it means that we should return 0.0.0.0 or :: for any blocked request - - switch m.Question[0].Qtype { - case dns.TypeA: - return s.genARecord(m, []byte{0, 0, 0, 0}) - case dns.TypeAAAA: - return s.genAAAARecord(m, net.IPv6zero) - } - - } else if s.conf.BlockingMode == "custom_ip" { - // means that we should return custom IP for any blocked request - - switch m.Question[0].Qtype { - case dns.TypeA: - return s.genARecord(m, s.conf.BlockingIPAddrv4) - case dns.TypeAAAA: - return s.genAAAARecord(m, s.conf.BlockingIPAddrv6) - } - - } else if s.conf.BlockingMode == "nxdomain" { - // means that we should return NXDOMAIN for any blocked request - - return s.genNXDomain(m) - } - - // Default blocking mode - // If there's an IP specified in the rule, return it - // If there is no IP, return NXDOMAIN - if result.IP != nil { - return s.genResponseWithIP(m, result.IP) - } - return s.genNXDomain(m) - } -} - -func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { - resp := dns.Msg{} - resp.SetRcode(request, dns.RcodeServerFailure) - resp.RecursionAvailable = true - return &resp -} - -func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg { - resp := s.makeResponse(request) - resp.Answer = append(resp.Answer, s.genAAnswer(request, ip)) - return resp -} - -func (s *Server) genAAAARecord(request *dns.Msg, ip net.IP) *dns.Msg { - resp := s.makeResponse(request) - resp.Answer = append(resp.Answer, s.genAAAAAnswer(request, ip)) - return resp -} - -func (s *Server) genAAnswer(req *dns.Msg, ip net.IP) *dns.A { - answer := new(dns.A) - answer.Hdr = dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypeA, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, - } - answer.A = ip - return answer -} - -func (s *Server) genAAAAAnswer(req *dns.Msg, ip net.IP) *dns.AAAA { - answer := new(dns.AAAA) - answer.Hdr = dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypeAAAA, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, - } - answer.AAAA = ip - return answer -} - -// generate DNS response message with an IP address -func (s *Server) genResponseWithIP(req *dns.Msg, ip net.IP) *dns.Msg { - if req.Question[0].Qtype == dns.TypeA && ip.To4() != nil { - return s.genARecord(req, ip.To4()) - } else if req.Question[0].Qtype == dns.TypeAAAA && - len(ip) == net.IPv6len && ip.To4() == nil { - return s.genAAAARecord(req, ip) - } - - // empty response - resp := s.makeResponse(req) - return resp -} - -func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg { - - ip := net.ParseIP(newAddr) - if ip != nil { - return s.genResponseWithIP(request, ip) - } - - // look up the hostname, TODO: cache - replReq := dns.Msg{} - replReq.SetQuestion(dns.Fqdn(newAddr), request.Question[0].Qtype) - replReq.RecursionDesired = true - - newContext := &proxy.DNSContext{ - Proto: d.Proto, - Addr: d.Addr, - StartTime: time.Now(), - Req: &replReq, - } - - err := s.dnsProxy.Resolve(newContext) - if err != nil { - log.Printf("Couldn't look up replacement host '%s': %s", newAddr, err) - return s.genServerFailure(request) - } - - resp := s.makeResponse(request) - if newContext.Res != nil { - for _, answer := range newContext.Res.Answer { - answer.Header().Name = request.Question[0].Name - resp.Answer = append(resp.Answer, answer) - } - } - - return resp -} - -// Make a CNAME response -func (s *Server) genCNAMEAnswer(req *dns.Msg, cname string) *dns.CNAME { - answer := new(dns.CNAME) - answer.Hdr = dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypeCNAME, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, - } - answer.Target = dns.Fqdn(cname) - return answer -} - -func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg { - resp := dns.Msg{} - resp.SetRcode(request, dns.RcodeNameError) - resp.RecursionAvailable = true - resp.Ns = s.genSOA(request) - return &resp -} - -func (s *Server) genSOA(request *dns.Msg) []dns.RR { - zone := "" - if len(request.Question) > 0 { - zone = request.Question[0].Name - } - - soa := dns.SOA{ - // values copied from verisign's nonexistent .com domain - // their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers - Refresh: 1800, - Retry: 900, - Expire: 604800, - Minttl: 86400, - // copied from AdGuard DNS - Ns: "fake-for-negative-caching.adguard.com.", - Serial: 100500, - // rest is request-specific - Hdr: dns.RR_Header{ - Name: zone, - Rrtype: dns.TypeSOA, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, - }, - Mbox: "hostmaster.", // zone will be appended later if it's not empty or "." - } - if soa.Hdr.Ttl == 0 { - soa.Hdr.Ttl = defaultValues.BlockedResponseTTL - } - if len(zone) > 0 && zone[0] != '.' { - soa.Mbox += zone - } - return []dns.RR{&soa} -} diff --git a/dnsforward/dnsforward_http.go b/dnsforward/dnsforward_http.go index 4658cbbf..a628bc74 100644 --- a/dnsforward/dnsforward_http.go +++ b/dnsforward/dnsforward_http.go @@ -51,7 +51,7 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { resp.EDNSCSEnabled = s.conf.EnableEDNSClientSubnet resp.DNSSECEnabled = s.conf.EnableDNSSEC resp.DisableIPv6 = s.conf.AAAADisabled - resp.FastestAddr = s.conf.FastestAddrAlgo + resp.FastestAddr = s.conf.FastestAddr resp.ParallelRequests = s.conf.AllServers s.RUnlock() @@ -170,7 +170,7 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { } if js.Exists("fastest_addr") { - s.conf.FastestAddrAlgo = req.FastestAddr + s.conf.FastestAddr = req.FastestAddr } if js.Exists("parallel_requests") { diff --git a/dnsforward/filter.go b/dnsforward/filter.go new file mode 100644 index 00000000..f6b9da63 --- /dev/null +++ b/dnsforward/filter.go @@ -0,0 +1,150 @@ +package dnsforward + +import ( + "strings" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/golibs/log" + "github.com/joomcode/errorx" + "github.com/miekg/dns" +) + +func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool, error) { + ip := ipFromAddr(d.Addr) + if s.access.IsBlockedIP(ip) { + log.Tracef("Client IP %s is blocked by settings", ip) + return false, nil + } + + if len(d.Req.Question) == 1 { + host := strings.TrimSuffix(d.Req.Question[0].Name, ".") + if s.access.IsBlockedDomain(host) { + log.Tracef("Domain %s is blocked by settings", host) + return false, nil + } + } + + return true, nil +} + +// getClientRequestFilteringSettings lookups client filtering settings +// using the client's IP address from the DNSContext +func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilter.RequestFilteringSettings { + setts := s.dnsFilter.GetConfig() + setts.FilteringEnabled = true + if s.conf.FilterHandler != nil { + clientAddr := ipFromAddr(d.Addr) + s.conf.FilterHandler(clientAddr, &setts) + } + return &setts +} + +// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered +func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { + d := ctx.proxyCtx + req := d.Req + host := strings.TrimSuffix(req.Question[0].Name, ".") + res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, ctx.setts) + if err != nil { + // Return immediately if there's an error + return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host) + + } else if res.IsFiltered { + // log.Tracef("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) + d.Res = s.genDNSFilterMessage(d, &res) + + } else if (res.Reason == dnsfilter.ReasonRewrite || res.Reason == dnsfilter.RewriteEtcHosts) && + len(res.IPList) != 0 { + resp := s.makeResponse(req) + + name := host + if len(res.CanonName) != 0 { + resp.Answer = append(resp.Answer, s.genCNAMEAnswer(req, res.CanonName)) + name = res.CanonName + } + + for _, ip := range res.IPList { + ip4 := ip.To4() + if req.Question[0].Qtype == dns.TypeA && ip4 != nil { + a := s.genAAnswer(req, ip4) + a.Hdr.Name = dns.Fqdn(name) + resp.Answer = append(resp.Answer, a) + } else if req.Question[0].Qtype == dns.TypeAAAA && ip4 == nil { + a := s.genAAAAAnswer(req, ip) + a.Hdr.Name = dns.Fqdn(name) + resp.Answer = append(resp.Answer, a) + } + } + + d.Res = resp + + } else if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 { + ctx.origQuestion = d.Req.Question[0] + // resolve canonical name, not the original host name + d.Req.Question[0].Name = dns.Fqdn(res.CanonName) + + } else if res.Reason == dnsfilter.RewriteEtcHosts && len(res.ReverseHost) != 0 { + + resp := s.makeResponse(req) + ptr := &dns.PTR{} + ptr.Hdr = dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypePTR, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + ptr.Ptr = res.ReverseHost + resp.Answer = append(resp.Answer, ptr) + d.Res = resp + } + + return &res, err +} + +// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address. +// If this is a match, we set a new response in d.Res and return. +func (s *Server) filterDNSResponse(ctx *dnsContext) (*dnsfilter.Result, error) { + d := ctx.proxyCtx + for _, a := range d.Res.Answer { + host := "" + + switch v := a.(type) { + case *dns.CNAME: + log.Debug("DNSFwd: Checking CNAME %s for %s", v.Target, v.Hdr.Name) + host = strings.TrimSuffix(v.Target, ".") + + case *dns.A: + host = v.A.String() + log.Debug("DNSFwd: Checking record A (%s) for %s", host, v.Hdr.Name) + + case *dns.AAAA: + host = v.AAAA.String() + log.Debug("DNSFwd: Checking record AAAA (%s) for %s", host, v.Hdr.Name) + + default: + continue + } + + s.RLock() + // Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use. + // This could happen after proxy server has been stopped, but its workers are not yet exited. + if !s.conf.ProtectionEnabled || s.dnsFilter == nil { + s.RUnlock() + continue + } + res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, ctx.setts) + s.RUnlock() + + if err != nil { + return nil, err + + } else if res.IsFiltered { + d.Res = s.genDNSFilterMessage(d, &res) + log.Debug("DNSFwd: Matched %s by response: %s", d.Req.Question[0].Name, host) + return &res, nil + } + } + + return nil, nil +} diff --git a/dnsforward/handle_dns.go b/dnsforward/handle_dns.go new file mode 100644 index 00000000..5bd663cf --- /dev/null +++ b/dnsforward/handle_dns.go @@ -0,0 +1,247 @@ +package dnsforward + +import ( + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" +) + +// To transfer information between modules +type dnsContext struct { + srv *Server + proxyCtx *proxy.DNSContext + setts *dnsfilter.RequestFilteringSettings // filtering settings for this client + startTime time.Time + result *dnsfilter.Result + origResp *dns.Msg // response received from upstream servers. Set when response is modified by filtering + origQuestion dns.Question // question received from client. Set when Rewrites are used. + err error // error returned from the module + protectionEnabled bool // filtering is enabled, dnsfilter object is ready + responseFromUpstream bool // response is received from upstream servers + origReqDNSSEC bool // DNSSEC flag in the original request from user +} + +const ( + resultDone = iota // module has completed its job, continue + resultFinish // module has completed its job, exit normally + resultError // an error occurred, exit with an error +) + +// handleDNSRequest filters the incoming DNS requests and writes them to the query log +func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { + ctx := &dnsContext{srv: s, proxyCtx: d} + ctx.result = &dnsfilter.Result{} + ctx.startTime = time.Now() + + type modProcessFunc func(ctx *dnsContext) int + mods := []modProcessFunc{ + processInitial, + processFilteringBeforeRequest, + processUpstream, + processDNSSECAfterResponse, + processFilteringAfterResponse, + processQueryLogsAndStats, + } + for _, process := range mods { + r := process(ctx) + switch r { + case resultDone: + // continue: call the next filter + + case resultFinish: + return nil + + case resultError: + return ctx.err + } + } + + if d.Res != nil { + d.Res.Compress = true // some devices require DNS message compression + } + return nil +} + +// Perform initial checks; process WHOIS & rDNS +func processInitial(ctx *dnsContext) int { + s := ctx.srv + d := ctx.proxyCtx + if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA { + _ = proxy.CheckDisabledAAAARequest(d, true) + return resultFinish + } + + if s.conf.OnDNSRequest != nil { + s.conf.OnDNSRequest(d) + } + + // disable Mozilla DoH + if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) && + d.Req.Question[0].Name == "use-application-dns.net." { + d.Res = s.genNXDomain(d.Req) + return resultFinish + } + + return resultDone +} + +// Apply filtering logic +func processFilteringBeforeRequest(ctx *dnsContext) int { + s := ctx.srv + d := ctx.proxyCtx + + s.RLock() + // Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use. + // This could happen after proxy server has been stopped, but its workers are not yet exited. + // + // A better approach is for proxy.Stop() to wait until all its workers exit, + // but this would require the Upstream interface to have Close() function + // (to prevent from hanging while waiting for unresponsive DNS server to respond). + + var err error + ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil + if ctx.protectionEnabled { + ctx.setts = s.getClientRequestFilteringSettings(d) + ctx.result, err = s.filterDNSRequest(ctx) + } + s.RUnlock() + + if err != nil { + ctx.err = err + return resultError + } + return resultDone +} + +// Pass request to upstream servers; process the response +func processUpstream(ctx *dnsContext) int { + s := ctx.srv + d := ctx.proxyCtx + if d.Res != nil { + return resultDone // response is already set - nothing to do + } + + if d.Addr != nil && s.conf.GetUpstreamsByClient != nil { + clientIP := ipFromAddr(d.Addr) + upstreams := s.conf.GetUpstreamsByClient(clientIP) + if len(upstreams) > 0 { + log.Debug("Using custom upstreams for %s", clientIP) + d.Upstreams = upstreams + } + } + + if s.conf.EnableDNSSEC { + opt := d.Req.IsEdns0() + if opt == nil { + log.Debug("DNS: Adding OPT record with DNSSEC flag") + d.Req.SetEdns0(4096, true) + } else if !opt.Do() { + opt.SetDo(true) + } else { + ctx.origReqDNSSEC = true + } + } + + // request was not filtered so let it be processed further + err := s.dnsProxy.Resolve(d) + if err != nil { + ctx.err = err + return resultError + } + + ctx.responseFromUpstream = true + return resultDone +} + +// Process DNSSEC after response from upstream server +func processDNSSECAfterResponse(ctx *dnsContext) int { + d := ctx.proxyCtx + + if !ctx.responseFromUpstream || // don't process response if it's not from upstream servers + !ctx.srv.conf.EnableDNSSEC { + return resultDone + } + + optResp := d.Res.IsEdns0() + if !ctx.origReqDNSSEC && optResp != nil && optResp.Do() { + return resultDone + } + + // Remove RRSIG records from response + // because there is no DO flag in the original request from client, + // but we have EnableDNSSEC set, so we have set DO flag ourselves, + // and now we have to clean up the DNS records our client didn't ask for. + + answers := []dns.RR{} + for _, a := range d.Res.Answer { + switch a.(type) { + case *dns.RRSIG: + log.Debug("Removing RRSIG record from response: %v", a) + default: + answers = append(answers, a) + } + } + d.Res.Answer = answers + + answers = []dns.RR{} + for _, a := range d.Res.Ns { + switch a.(type) { + case *dns.RRSIG: + log.Debug("Removing RRSIG record from response: %v", a) + default: + answers = append(answers, a) + } + } + d.Res.Ns = answers + + return resultDone +} + +// Apply filtering logic after we have received response from upstream servers +func processFilteringAfterResponse(ctx *dnsContext) int { + s := ctx.srv + d := ctx.proxyCtx + res := ctx.result + var err error + + switch res.Reason { + case dnsfilter.ReasonRewrite: + if len(res.CanonName) == 0 { + break + } + d.Req.Question[0] = ctx.origQuestion + d.Res.Question[0] = ctx.origQuestion + + if len(d.Res.Answer) != 0 { + answer := []dns.RR{} + answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName)) + answer = append(answer, d.Res.Answer...) // host -> IP + d.Res.Answer = answer + } + + case dnsfilter.NotFilteredWhiteList: + // nothing + + default: + if !ctx.protectionEnabled || // filters are disabled: there's nothing to check for + !ctx.responseFromUpstream { // only check response if it's from an upstream server + break + } + origResp2 := d.Res + ctx.result, err = s.filterDNSResponse(ctx) + if err != nil { + ctx.err = err + return resultError + } + if ctx.result != nil { + ctx.origResp = origResp2 // matched by response + } else { + ctx.result = &dnsfilter.Result{} + } + } + + return resultDone +} diff --git a/dnsforward/msg.go b/dnsforward/msg.go new file mode 100644 index 00000000..a1078539 --- /dev/null +++ b/dnsforward/msg.go @@ -0,0 +1,225 @@ +package dnsforward + +import ( + "log" + "net" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/miekg/dns" +) + +// Create a DNS response by DNS request and set necessary flags +func (s *Server) makeResponse(req *dns.Msg) *dns.Msg { + resp := dns.Msg{} + resp.SetReply(req) + resp.RecursionAvailable = true + resp.Compress = true + return &resp +} + +// genDNSFilterMessage generates a DNS message corresponding to the filtering result +func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Result) *dns.Msg { + m := d.Req + + if m.Question[0].Qtype != dns.TypeA && m.Question[0].Qtype != dns.TypeAAAA { + return s.genNXDomain(m) + } + + switch result.Reason { + case dnsfilter.FilteredSafeBrowsing: + return s.genBlockedHost(m, s.conf.SafeBrowsingBlockHost, d) + case dnsfilter.FilteredParental: + return s.genBlockedHost(m, s.conf.ParentalBlockHost, d) + default: + // If the query was filtered by "Safe search", dnsfilter also must return + // the IP address that must be used in response. + // In this case regardless of the filtering method, we should return it + if result.Reason == dnsfilter.FilteredSafeSearch && result.IP != nil { + return s.genResponseWithIP(m, result.IP) + } + + if s.conf.BlockingMode == "null_ip" { + // it means that we should return 0.0.0.0 or :: for any blocked request + + switch m.Question[0].Qtype { + case dns.TypeA: + return s.genARecord(m, []byte{0, 0, 0, 0}) + case dns.TypeAAAA: + return s.genAAAARecord(m, net.IPv6zero) + } + + } else if s.conf.BlockingMode == "custom_ip" { + // means that we should return custom IP for any blocked request + + switch m.Question[0].Qtype { + case dns.TypeA: + return s.genARecord(m, s.conf.BlockingIPAddrv4) + case dns.TypeAAAA: + return s.genAAAARecord(m, s.conf.BlockingIPAddrv6) + } + + } else if s.conf.BlockingMode == "nxdomain" { + // means that we should return NXDOMAIN for any blocked request + + return s.genNXDomain(m) + } + + // Default blocking mode + // If there's an IP specified in the rule, return it + // If there is no IP, return NXDOMAIN + if result.IP != nil { + return s.genResponseWithIP(m, result.IP) + } + return s.genNXDomain(m) + } +} + +func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeServerFailure) + resp.RecursionAvailable = true + return &resp +} + +func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg { + resp := s.makeResponse(request) + resp.Answer = append(resp.Answer, s.genAAnswer(request, ip)) + return resp +} + +func (s *Server) genAAAARecord(request *dns.Msg, ip net.IP) *dns.Msg { + resp := s.makeResponse(request) + resp.Answer = append(resp.Answer, s.genAAAAAnswer(request, ip)) + return resp +} + +func (s *Server) genAAnswer(req *dns.Msg, ip net.IP) *dns.A { + answer := new(dns.A) + answer.Hdr = dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeA, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + answer.A = ip + return answer +} + +func (s *Server) genAAAAAnswer(req *dns.Msg, ip net.IP) *dns.AAAA { + answer := new(dns.AAAA) + answer.Hdr = dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeAAAA, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + answer.AAAA = ip + return answer +} + +// generate DNS response message with an IP address +func (s *Server) genResponseWithIP(req *dns.Msg, ip net.IP) *dns.Msg { + if req.Question[0].Qtype == dns.TypeA && ip.To4() != nil { + return s.genARecord(req, ip.To4()) + } else if req.Question[0].Qtype == dns.TypeAAAA && + len(ip) == net.IPv6len && ip.To4() == nil { + return s.genAAAARecord(req, ip) + } + + // empty response + resp := s.makeResponse(req) + return resp +} + +func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg { + + ip := net.ParseIP(newAddr) + if ip != nil { + return s.genResponseWithIP(request, ip) + } + + // look up the hostname, TODO: cache + replReq := dns.Msg{} + replReq.SetQuestion(dns.Fqdn(newAddr), request.Question[0].Qtype) + replReq.RecursionDesired = true + + newContext := &proxy.DNSContext{ + Proto: d.Proto, + Addr: d.Addr, + StartTime: time.Now(), + Req: &replReq, + } + + err := s.dnsProxy.Resolve(newContext) + if err != nil { + log.Printf("Couldn't look up replacement host '%s': %s", newAddr, err) + return s.genServerFailure(request) + } + + resp := s.makeResponse(request) + if newContext.Res != nil { + for _, answer := range newContext.Res.Answer { + answer.Header().Name = request.Question[0].Name + resp.Answer = append(resp.Answer, answer) + } + } + + return resp +} + +// Make a CNAME response +func (s *Server) genCNAMEAnswer(req *dns.Msg, cname string) *dns.CNAME { + answer := new(dns.CNAME) + answer.Hdr = dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypeCNAME, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + answer.Target = dns.Fqdn(cname) + return answer +} + +func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeNameError) + resp.RecursionAvailable = true + resp.Ns = s.genSOA(request) + return &resp +} + +func (s *Server) genSOA(request *dns.Msg) []dns.RR { + zone := "" + if len(request.Question) > 0 { + zone = request.Question[0].Name + } + + soa := dns.SOA{ + // values copied from verisign's nonexistent .com domain + // their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers + Refresh: 1800, + Retry: 900, + Expire: 604800, + Minttl: 86400, + // copied from AdGuard DNS + Ns: "fake-for-negative-caching.adguard.com.", + Serial: 100500, + // rest is request-specific + Hdr: dns.RR_Header{ + Name: zone, + Rrtype: dns.TypeSOA, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + }, + Mbox: "hostmaster.", // zone will be appended later if it's not empty or "." + } + if soa.Hdr.Ttl == 0 { + soa.Hdr.Ttl = defaultValues.BlockedResponseTTL + } + if len(zone) > 0 && zone[0] != '.' { + soa.Mbox += zone + } + return []dns.RR{&soa} +} diff --git a/dnsforward/stats.go b/dnsforward/stats.go new file mode 100644 index 00000000..e9591abc --- /dev/null +++ b/dnsforward/stats.go @@ -0,0 +1,91 @@ +package dnsforward + +import ( + "net" + "strings" + "time" + + "github.com/AdguardTeam/AdGuardHome/querylog" + "github.com/miekg/dns" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/AdGuardHome/stats" + "github.com/AdguardTeam/dnsproxy/proxy" +) + +// Write Stats data and logs +func processQueryLogsAndStats(ctx *dnsContext) int { + elapsed := time.Since(ctx.startTime) + s := ctx.srv + d := ctx.proxyCtx + + shouldLog := true + msg := d.Req + + // don't log ANY request if refuseAny is enabled + if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.conf.RefuseAny { + shouldLog = false + } + + s.RLock() + // Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use. + // This can happen after proxy server has been stopped, but its workers haven't yet exited. + if shouldLog && s.queryLog != nil { + p := querylog.AddParams{ + Question: msg, + Answer: d.Res, + OrigAnswer: ctx.origResp, + Result: ctx.result, + Elapsed: elapsed, + ClientIP: getIP(d.Addr), + } + if d.Upstream != nil { + p.Upstream = d.Upstream.Address() + } + s.queryLog.Add(p) + } + + s.updateStats(d, elapsed, *ctx.result) + s.RUnlock() + + return resultDone +} + +func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dnsfilter.Result) { + if s.stats == nil { + return + } + + e := stats.Entry{} + e.Domain = strings.ToLower(d.Req.Question[0].Name) + e.Domain = e.Domain[:len(e.Domain)-1] // remove last "." + switch addr := d.Addr.(type) { + case *net.UDPAddr: + e.Client = addr.IP + case *net.TCPAddr: + e.Client = addr.IP + } + e.Time = uint32(elapsed / 1000) + e.Result = stats.RNotFiltered + + switch res.Reason { + + case dnsfilter.FilteredSafeBrowsing: + e.Result = stats.RSafeBrowsing + + case dnsfilter.FilteredParental: + e.Result = stats.RParental + + case dnsfilter.FilteredSafeSearch: + e.Result = stats.RSafeSearch + + case dnsfilter.FilteredBlackList: + fallthrough + case dnsfilter.FilteredInvalid: + fallthrough + case dnsfilter.FilteredBlockedService: + e.Result = stats.RFiltered + } + + s.stats.Update(e) +} diff --git a/dnsforward/util.go b/dnsforward/util.go new file mode 100644 index 00000000..f5c62cb8 --- /dev/null +++ b/dnsforward/util.go @@ -0,0 +1,76 @@ +package dnsforward + +import ( + "net" + "sort" + "strings" + + "github.com/AdguardTeam/golibs/utils" +) + +func stringArrayDup(a []string) []string { + a2 := make([]string, len(a)) + copy(a2, a) + return a2 +} + +// Get IP address from net.Addr object +// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone: +// https://github.com/AdguardTeam/AdGuardHome/issues/1261 +func ipFromAddr(a net.Addr) string { + switch addr := a.(type) { + case *net.UDPAddr: + return addr.IP.String() + case *net.TCPAddr: + return addr.IP.String() + } + return "" +} + +// Get IP address from net.Addr +func getIP(addr net.Addr) net.IP { + switch addr := addr.(type) { + case *net.UDPAddr: + return addr.IP + case *net.TCPAddr: + return addr.IP + } + return nil +} + +// Find value in a sorted array +func findSorted(ar []string, val string) int { + i := sort.SearchStrings(ar, val) + if i == len(ar) || ar[i] != val { + return -1 + } + return i +} + +func isWildcard(host string) bool { + return len(host) >= 2 && + host[0] == '*' && host[1] == '.' +} + +// Return TRUE if host name matches a wildcard pattern +func matchDomainWildcard(host, wildcard string) bool { + return isWildcard(wildcard) && + strings.HasSuffix(host, wildcard[1:]) +} + +// Return TRUE if client's SNI value matches DNS names from certificate +func matchDNSName(dnsNames []string, sni string) bool { + if utils.IsValidHostname(sni) != nil { + return false + } + if findSorted(dnsNames, sni) != -1 { + return true + } + + for _, dn := range dnsNames { + if matchDomainWildcard(sni, dn) { + return true + } + } + return false +} diff --git a/go.mod b/go.mod index 85c12c02..b33129c5 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome go 1.14 require ( - github.com/AdguardTeam/dnsproxy v0.27.1 + github.com/AdguardTeam/dnsproxy v0.28.0 github.com/AdguardTeam/golibs v0.4.2 github.com/AdguardTeam/urlfilter v0.10.0 github.com/NYTimes/gziphandler v1.1.1 diff --git a/go.sum b/go.sum index 636d419e..14b7faba 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/AdguardTeam/dnsproxy v0.27.1 h1:CQ3vtGSNbHNeYkxC6pALwugTSssP2MnsjdxkvVMzEp4= -github.com/AdguardTeam/dnsproxy v0.27.1/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPqjyGTxUae86s58= +github.com/AdguardTeam/dnsproxy v0.28.0 h1:w6ITGjSMLztUOTVNVVcE0JU1bV2U0bOPyDHGwyZgTc4= +github.com/AdguardTeam/dnsproxy v0.28.0/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPqjyGTxUae86s58= github.com/AdguardTeam/golibs v0.4.0 h1:4VX6LoOqFe9p9Gf55BeD8BvJD6M6RDYmgEiHrENE9KU= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o=