diff --git a/control.go b/control.go index fe81eaa2..8ded7372 100644 --- a/control.go +++ b/control.go @@ -35,8 +35,13 @@ var protocols = []string{"tls://", "https://", "tcp://", "sdns://"} const versionCheckURL = "https://adguardteam.github.io/AdGuardHome/version.json" const versionCheckPeriod = time.Hour * 8 +var transport = &http.Transport{ + DialContext: customDialContext, +} + var client = &http.Client{ - Timeout: time.Minute * 5, + Timeout: time.Minute * 5, + Transport: transport, } var controlLock sync.Mutex diff --git a/dns.go b/dns.go index 9abbc80e..b135babf 100644 --- a/dns.go +++ b/dns.go @@ -50,6 +50,7 @@ func generateServerConfig() dnsforward.ServerConfig { FilteringConfig: config.DNS.FilteringConfig, Filters: filters, } + newconfig.ResolverAddress = fmt.Sprintf("%s:%d", config.DNS.BindHost, config.DNS.Port) if config.TLS.Enabled { newconfig.TLSConfig = config.TLS.TLSConfig diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 83818f1b..c487d7d2 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -3,6 +3,7 @@ package dnsfilter import ( "bufio" "bytes" + "context" "crypto/sha256" "encoding/json" "errors" @@ -16,6 +17,7 @@ import ( "sync/atomic" "time" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/bluele/gcache" "golang.org/x/net/publicsuffix" @@ -45,10 +47,11 @@ const enableDelayedCompilation = true // flag for debugging, must be true in pro // Config allows you to configure DNS filtering with New() or just change variables directly. type Config struct { - ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17 - ParentalEnabled bool `yaml:"parental_enabled"` - SafeSearchEnabled bool `yaml:"safesearch_enabled"` - SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` + ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17 + ParentalEnabled bool `yaml:"parental_enabled"` + SafeSearchEnabled bool `yaml:"safesearch_enabled"` + SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` + ResolverAddress string // DNS server address } type privateConfig struct { @@ -159,6 +162,8 @@ var ( safeSearchCache gcache.Cache ) +var resolverAddr string // DNS server address + // Result holds state of hostname check type Result struct { IsFiltered bool `json:",omitempty"` // True if the host name is filtered @@ -971,6 +976,47 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) { // lifecycle helper functions // +// Connect to a remote server resolving hostname using our own DNS server +func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) { + log.Tracef("network:%v addr:%v", network, addr) + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + dialer := &net.Dialer{ + Timeout: time.Minute * 5, + } + + if net.ParseIP(host) != nil { + con, err := dialer.DialContext(ctx, network, addr) + return con, err + } + + r := upstream.NewResolver(resolverAddr, 30*time.Second) + addrs, e := r.LookupIPAddr(ctx, host) + log.Tracef("LookupIPAddr: %s: %v", host, addrs) + if e != nil { + return nil, e + } + + var firstErr error + firstErr = nil + for _, a := range addrs { + addr = fmt.Sprintf("%s:%s", a.String(), port) + con, err := dialer.DialContext(ctx, network, addr) + if err != nil { + if firstErr == nil { + firstErr = err + } + continue + } + return con, err + } + return nil, firstErr +} + // New creates properly initialized DNS Filter that is ready to be used func New(c *Config) *Dnsfilter { d := new(Dnsfilter) @@ -990,6 +1036,10 @@ func New(c *Config) *Dnsfilter { TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } + if len(c.ResolverAddress) != 0 { + resolverAddr = c.ResolverAddress + d.transport.DialContext = customDialContext + } d.client = http.Client{ Transport: d.transport, Timeout: defaultHTTPTimeout, diff --git a/helpers.go b/helpers.go index a304634e..184789d4 100644 --- a/helpers.go +++ b/helpers.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "errors" "fmt" "io" @@ -14,7 +15,10 @@ import ( "runtime" "strconv" "strings" + "time" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/log" "github.com/joomcode/errorx" ) @@ -300,6 +304,48 @@ func checkPacketPortAvailable(host string, port int) error { return err } +// Connect to a remote server resolving hostname using our own DNS server +func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) { + log.Tracef("network:%v addr:%v", network, addr) + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + dialer := &net.Dialer{ + Timeout: time.Minute * 5, + } + + if net.ParseIP(host) != nil { + con, err := dialer.DialContext(ctx, network, addr) + return con, err + } + + resolverAddr := fmt.Sprintf("%s:%d", config.DNS.BindHost, config.DNS.Port) + r := upstream.NewResolver(resolverAddr, 30*time.Second) + addrs, e := r.LookupIPAddr(ctx, host) + log.Tracef("LookupIPAddr: %s: %v", host, addrs) + if e != nil { + return nil, e + } + + var firstErr error + firstErr = nil + for _, a := range addrs { + addr = fmt.Sprintf("%s:%s", a.String(), port) + con, err := dialer.DialContext(ctx, network, addr) + if err != nil { + if firstErr == nil { + firstErr = err + } + continue + } + return con, err + } + return nil, firstErr +} + // --------------------- // debug logging helpers // ---------------------