Merge: dnsfilter: prevent recursion when both parental control and safebrowsing are enabled

Close #732

* commit 'c4e67690f4fcceb055cbea73610b5974855db96f':
  * dnsfilter: don't use global variable for custom resolver function
  - dnsfilter: prevent recursion when both parental control and safebrowsing are enabled
This commit is contained in:
Simon Zolin 2019-04-24 12:52:16 +03:00
commit e1bb89c393
1 changed files with 43 additions and 46 deletions

View File

@ -162,8 +162,6 @@ var (
safeSearchCache gcache.Cache safeSearchCache gcache.Cache
) )
var resolverAddr string // DNS server address
// Result holds state of hostname check // Result holds state of hostname check
type Result struct { type Result struct {
IsFiltered bool `json:",omitempty"` // True if the host name is filtered IsFiltered bool `json:",omitempty"` // True if the host name is filtered
@ -185,6 +183,10 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) {
return Result{Reason: NotFilteredNotFound}, nil return Result{Reason: NotFilteredNotFound}, nil
} }
host = strings.ToLower(host) host = strings.ToLower(host)
// prevent recursion
if host == d.parentalServer || host == d.safeBrowsingServer {
return Result{}, nil
}
// try filter lists first // try filter lists first
result, err := d.matchHost(host) result, err := d.matchHost(host)
@ -674,10 +676,6 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
defer timer.LogElapsed("SafeBrowsing HTTP lookup for %s", host) defer timer.LogElapsed("SafeBrowsing HTTP lookup for %s", host)
} }
// prevent recursion -- checking the host of safebrowsing server makes no sense
if host == d.safeBrowsingServer {
return Result{}, nil
}
format := func(hashparam string) string { format := func(hashparam string) string {
url := fmt.Sprintf(defaultSafebrowsingURL, d.safeBrowsingServer, hashparam) url := fmt.Sprintf(defaultSafebrowsingURL, d.safeBrowsingServer, hashparam)
return url return url
@ -720,10 +718,6 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
defer timer.LogElapsed("Parental HTTP lookup for %s", host) defer timer.LogElapsed("Parental HTTP lookup for %s", host)
} }
// prevent recursion -- checking the host of parental safety server makes no sense
if host == d.parentalServer {
return Result{}, nil
}
format := func(hashparam string) string { format := func(hashparam string) string {
url := fmt.Sprintf(defaultParentalURL, d.parentalServer, hashparam, d.ParentalSensitivity) url := fmt.Sprintf(defaultParentalURL, d.parentalServer, hashparam, d.ParentalSensitivity)
return url return url
@ -978,8 +972,11 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) {
// lifecycle helper functions // lifecycle helper functions
// //
type dialFunctionType func(ctx context.Context, network, addr string) (net.Conn, error)
// Connect to a remote server resolving hostname using our own DNS server // Connect to a remote server resolving hostname using our own DNS server
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) { func createCustomDialContext(resolverAddr string) dialFunctionType {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
log.Tracef("network:%v addr:%v", network, addr) log.Tracef("network:%v addr:%v", network, addr)
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
@ -1018,6 +1015,7 @@ func customDialContext(ctx context.Context, network, addr string) (net.Conn, err
} }
return nil, firstErr return nil, firstErr
} }
}
// New creates properly initialized DNS Filter that is ready to be used // New creates properly initialized DNS Filter that is ready to be used
func New(c *Config) *Dnsfilter { func New(c *Config) *Dnsfilter {
@ -1039,8 +1037,7 @@ func New(c *Config) *Dnsfilter {
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
} }
if c != nil && len(c.ResolverAddress) != 0 { if c != nil && len(c.ResolverAddress) != 0 {
resolverAddr = c.ResolverAddress d.transport.DialContext = createCustomDialContext(c.ResolverAddress)
d.transport.DialContext = customDialContext
} }
d.client = http.Client{ d.client = http.Client{
Transport: d.transport, Transport: d.transport,