Merge: * dnsfilter: major refactoring

Close #928

* commit '31ec7f9652c6fb5947eca9e96f1507aad3e5ed13':
  * doc: new arch picture
  * dnsfilter: major refactoring
This commit is contained in:
Simon Zolin 2019-10-10 12:33:18 +03:00
commit 51f9d7e4df
15 changed files with 578 additions and 588 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 76 KiB

After

Width:  |  Height:  |  Size: 84 KiB

View File

@ -14,6 +14,7 @@ import (
"net/http" "net/http"
"os" "os"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -66,7 +67,7 @@ type Config struct {
UsePlainHTTP bool `yaml:"-"` // use plain HTTP for requests to parental and safe browsing servers UsePlainHTTP bool `yaml:"-"` // use plain HTTP for requests to parental and safe browsing servers
SafeSearchEnabled bool `yaml:"safesearch_enabled"` SafeSearchEnabled bool `yaml:"safesearch_enabled"`
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
ResolverAddress string // DNS server address ResolverAddress string `yaml:"-"` // DNS server address
SafeBrowsingCacheSize uint `yaml:"safebrowsing_cache_size"` // (in bytes) SafeBrowsingCacheSize uint `yaml:"safebrowsing_cache_size"` // (in bytes)
SafeSearchCacheSize uint `yaml:"safesearch_cache_size"` // (in bytes) SafeSearchCacheSize uint `yaml:"safesearch_cache_size"` // (in bytes)
@ -75,13 +76,11 @@ type Config struct {
Rewrites []RewriteEntry `yaml:"rewrites"` Rewrites []RewriteEntry `yaml:"rewrites"`
// Filtering callback function // Called when the configuration is changed by HTTP request
FilterHandler func(clientAddr string, settings *RequestFilteringSettings) `yaml:"-"` ConfigModified func() `yaml:"-"`
}
type privateConfig struct { // Register an HTTP handler
parentalServer string // access via methods HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
safeBrowsingServer string // access via methods
} }
// LookupStats store stats collected during safebrowsing or parental checks // LookupStats store stats collected during safebrowsing or parental checks
@ -99,17 +98,30 @@ type Stats struct {
Safesearch LookupStats Safesearch LookupStats
} }
// Parameters to pass to filters-initializer goroutine
type filtersInitializerParams struct {
filters map[int]string
}
// Dnsfilter holds added rules and performs hostname matches against the rules // Dnsfilter holds added rules and performs hostname matches against the rules
type Dnsfilter struct { type Dnsfilter struct {
rulesStorage *urlfilter.RuleStorage rulesStorage *urlfilter.RuleStorage
filteringEngine *urlfilter.DNSEngine filteringEngine *urlfilter.DNSEngine
engineLock sync.RWMutex
// HTTP lookups for safebrowsing and parental // HTTP lookups for safebrowsing and parental
client http.Client // handle for http client -- single instance as recommended by docs client http.Client // handle for http client -- single instance as recommended by docs
transport *http.Transport // handle for http transport used by http client transport *http.Transport // handle for http transport used by http client
parentalServer string // access via methods
safeBrowsingServer string // access via methods
Config // for direct access by library users, even a = assignment Config // for direct access by library users, even a = assignment
privateConfig confLock sync.RWMutex
// Channel for passing data to filters-initializer goroutine
filtersInitializerChan chan filtersInitializerParams
filtersInitializerLock sync.Mutex
} }
// Filter represents a filter list // Filter represents a filter list
@ -119,8 +131,6 @@ type Filter struct {
FilePath string `yaml:"-"` // Path to a filtering rules file FilePath string `yaml:"-"` // Path to a filtering rules file
} }
//go:generate stringer -type=Reason
// Reason holds an enum detailing why it was filtered or not filtered // Reason holds an enum detailing why it was filtered or not filtered
type Reason int type Reason int
@ -153,8 +163,7 @@ const (
ReasonRewrite ReasonRewrite
) )
func (r Reason) String() string { var reasonNames = []string{
names := []string{
"NotFilteredNotFound", "NotFilteredNotFound",
"NotFilteredWhiteList", "NotFilteredWhiteList",
"NotFilteredError", "NotFilteredError",
@ -167,11 +176,86 @@ func (r Reason) String() string {
"FilteredBlockedService", "FilteredBlockedService",
"Rewrite", "Rewrite",
} }
if uint(r) >= uint(len(names)) {
func (r Reason) String() string {
if uint(r) >= uint(len(reasonNames)) {
return "" return ""
} }
return names[r] return reasonNames[r]
}
// GetConfig - get configuration
func (d *Dnsfilter) GetConfig() RequestFilteringSettings {
c := RequestFilteringSettings{}
// d.confLock.RLock()
c.SafeSearchEnabled = d.Config.SafeSearchEnabled
c.SafeBrowsingEnabled = d.Config.SafeBrowsingEnabled
c.ParentalEnabled = d.Config.ParentalEnabled
// d.confLock.RUnlock()
return c
}
// WriteDiskConfig - write configuration
func (d *Dnsfilter) WriteDiskConfig(c *Config) {
*c = d.Config
}
// SetFilters - set new filters (synchronously or asynchronously)
// When filters are set asynchronously, the old filters continue working until the new filters are ready.
// In this case the caller must ensure that the old filter files are intact.
func (d *Dnsfilter) SetFilters(filters map[int]string, async bool) error {
if async {
params := filtersInitializerParams{
filters: filters,
}
d.filtersInitializerLock.Lock() // prevent multiple writers from adding more than 1 task
// remove all pending tasks
stop := false
for !stop {
select {
case <-d.filtersInitializerChan:
//
default:
stop = true
}
}
d.filtersInitializerChan <- params
d.filtersInitializerLock.Unlock()
return nil
}
err := d.initFiltering(filters)
if err != nil {
log.Error("Can't initialize filtering subsystem: %s", err)
return err
}
return nil
}
// Starts initializing new filters by signal from channel
func (d *Dnsfilter) filtersInitializer() {
for {
params := <-d.filtersInitializerChan
err := d.initFiltering(params.filters)
if err != nil {
log.Error("Can't initialize filtering subsystem: %s", err)
continue
}
}
}
// Close - close the object
func (d *Dnsfilter) Close() {
if d != nil && d.transport != nil {
d.transport.CloseIdleConnections()
}
if d.rulesStorage != nil {
d.rulesStorage.Close()
}
} }
type dnsFilterContext struct { type dnsFilterContext struct {
@ -294,6 +378,9 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result { func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result {
var res Result var res Result
d.confLock.RLock()
defer d.confLock.RUnlock()
for _, r := range d.Rewrites { for _, r := range d.Rewrites {
if r.Domain != host { if r.Domain != host {
continue continue
@ -704,17 +791,28 @@ func (d *Dnsfilter) initFiltering(filters map[int]string) error {
listArray = append(listArray, list) listArray = append(listArray, list)
} }
var err error rulesStorage, err := urlfilter.NewRuleStorage(listArray)
d.rulesStorage, err = urlfilter.NewRuleStorage(listArray)
if err != nil { if err != nil {
return fmt.Errorf("urlfilter.NewRuleStorage(): %s", err) return fmt.Errorf("urlfilter.NewRuleStorage(): %s", err)
} }
d.filteringEngine = urlfilter.NewDNSEngine(d.rulesStorage) filteringEngine := urlfilter.NewDNSEngine(rulesStorage)
d.engineLock.Lock()
if d.rulesStorage != nil {
d.rulesStorage.Close()
}
d.rulesStorage = rulesStorage
d.filteringEngine = filteringEngine
d.engineLock.Unlock()
log.Debug("initialized filtering engine")
return nil return nil
} }
// matchHost is a low-level way to check only if hostname is filtered by rules, skipping expensive safebrowsing and parental lookups // matchHost is a low-level way to check only if hostname is filtered by rules, skipping expensive safebrowsing and parental lookups
func (d *Dnsfilter) matchHost(host string, qtype uint16) (Result, error) { func (d *Dnsfilter) matchHost(host string, qtype uint16) (Result, error) {
d.engineLock.RLock()
defer d.engineLock.RUnlock()
if d.filteringEngine == nil { if d.filteringEngine == nil {
return Result{}, nil return Result{}, nil
} }
@ -926,27 +1024,21 @@ func New(c *Config, filters map[int]string) *Dnsfilter {
err := d.initFiltering(filters) err := d.initFiltering(filters)
if err != nil { if err != nil {
log.Error("Can't initialize filtering subsystem: %s", err) log.Error("Can't initialize filtering subsystem: %s", err)
d.Destroy() d.Close()
return nil return nil
} }
} }
d.filtersInitializerChan = make(chan filtersInitializerParams, 1)
go d.filtersInitializer()
if d.Config.HTTPRegister != nil { // for tests
d.registerSecurityHandlers()
d.registerRewritesHandlers()
}
return d return d
} }
// Destroy is optional if you want to tidy up goroutines without waiting for them to die off
// right now it closes idle HTTP connections if there are any
func (d *Dnsfilter) Destroy() {
if d != nil && d.transport != nil {
d.transport.CloseIdleConnections()
}
if d.rulesStorage != nil {
d.rulesStorage.Close()
d.rulesStorage = nil
}
}
// //
// config manipulation helpers // config manipulation helpers
// //

View File

@ -108,7 +108,7 @@ func TestEtcHostsMatching(t *testing.T) {
filters := make(map[int]string) filters := make(map[int]string)
filters[0] = text filters[0] = text
d := NewForTest(nil, filters) d := NewForTest(nil, filters)
defer d.Destroy() defer d.Close()
d.checkMatchIP(t, "google.com", addr, dns.TypeA) d.checkMatchIP(t, "google.com", addr, dns.TypeA)
d.checkMatchIP(t, "www.google.com", addr, dns.TypeA) d.checkMatchIP(t, "www.google.com", addr, dns.TypeA)
@ -133,7 +133,7 @@ func TestSafeBrowsing(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(fmt.Sprintf("%s in %s", tc, _Func()), func(t *testing.T) { t.Run(fmt.Sprintf("%s in %s", tc, _Func()), func(t *testing.T) {
d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil)
defer d.Destroy() defer d.Close()
gctx.stats.Safebrowsing.Requests = 0 gctx.stats.Safebrowsing.Requests = 0
d.checkMatch(t, "wmconvirus.narod.ru") d.checkMatch(t, "wmconvirus.narod.ru")
d.checkMatch(t, "wmconvirus.narod.ru") d.checkMatch(t, "wmconvirus.narod.ru")
@ -158,7 +158,7 @@ func TestSafeBrowsing(t *testing.T) {
func TestParallelSB(t *testing.T) { func TestParallelSB(t *testing.T) {
d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil)
defer d.Destroy() defer d.Close()
t.Run("group", func(t *testing.T) { t.Run("group", func(t *testing.T) {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
@ -175,7 +175,7 @@ func TestParallelSB(t *testing.T) {
// the only way to verify that custom server option is working is to point it at a server that does serve safebrowsing // the only way to verify that custom server option is working is to point it at a server that does serve safebrowsing
func TestSafeBrowsingCustomServerFail(t *testing.T) { func TestSafeBrowsingCustomServerFail(t *testing.T) {
d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil)
defer d.Destroy() defer d.Close()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// w.Write("Hello, client") // w.Write("Hello, client")
fmt.Fprintln(w, "Hello, client") fmt.Fprintln(w, "Hello, client")
@ -192,14 +192,14 @@ func TestSafeBrowsingCustomServerFail(t *testing.T) {
func TestSafeSearch(t *testing.T) { func TestSafeSearch(t *testing.T) {
d := NewForTest(nil, nil) d := NewForTest(nil, nil)
defer d.Destroy() defer d.Close()
_, ok := d.SafeSearchDomain("www.google.com") _, ok := d.SafeSearchDomain("www.google.com")
if ok { if ok {
t.Errorf("Expected safesearch to error when disabled") t.Errorf("Expected safesearch to error when disabled")
} }
d = NewForTest(&Config{SafeSearchEnabled: true}, nil) d = NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Destroy() defer d.Close()
val, ok := d.SafeSearchDomain("www.google.com") val, ok := d.SafeSearchDomain("www.google.com")
if !ok { if !ok {
t.Errorf("Expected safesearch to find result for www.google.com") t.Errorf("Expected safesearch to find result for www.google.com")
@ -211,7 +211,7 @@ func TestSafeSearch(t *testing.T) {
func TestCheckHostSafeSearchYandex(t *testing.T) { func TestCheckHostSafeSearchYandex(t *testing.T) {
d := NewForTest(&Config{SafeSearchEnabled: true}, nil) d := NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Destroy() defer d.Close()
// Slice of yandex domains // Slice of yandex domains
yandex := []string{"yAndeX.ru", "YANdex.COM", "yandex.ua", "yandex.by", "yandex.kz", "www.yandex.com"} yandex := []string{"yAndeX.ru", "YANdex.COM", "yandex.ua", "yandex.by", "yandex.kz", "www.yandex.com"}
@ -231,7 +231,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
func TestCheckHostSafeSearchGoogle(t *testing.T) { func TestCheckHostSafeSearchGoogle(t *testing.T) {
d := NewForTest(&Config{SafeSearchEnabled: true}, nil) d := NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Destroy() defer d.Close()
// Slice of google domains // Slice of google domains
googleDomains := []string{"www.google.com", "www.google.im", "www.google.co.in", "www.google.iq", "www.google.is", "www.google.it", "www.google.je"} googleDomains := []string{"www.google.com", "www.google.im", "www.google.co.in", "www.google.iq", "www.google.is", "www.google.it", "www.google.je"}
@ -251,7 +251,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
func TestSafeSearchCacheYandex(t *testing.T) { func TestSafeSearchCacheYandex(t *testing.T) {
d := NewForTest(nil, nil) d := NewForTest(nil, nil)
defer d.Destroy() defer d.Close()
domain := "yandex.ru" domain := "yandex.ru"
var result Result var result Result
@ -267,7 +267,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
} }
d = NewForTest(&Config{SafeSearchEnabled: true}, nil) d = NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Destroy() defer d.Close()
result, err = d.CheckHost(domain, dns.TypeA, &setts) result, err = d.CheckHost(domain, dns.TypeA, &setts)
if err != nil { if err != nil {
@ -293,7 +293,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
func TestSafeSearchCacheGoogle(t *testing.T) { func TestSafeSearchCacheGoogle(t *testing.T) {
d := NewForTest(nil, nil) d := NewForTest(nil, nil)
defer d.Destroy() defer d.Close()
domain := "www.google.ru" domain := "www.google.ru"
result, err := d.CheckHost(domain, dns.TypeA, &setts) result, err := d.CheckHost(domain, dns.TypeA, &setts)
if err != nil { if err != nil {
@ -304,7 +304,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
} }
d = NewForTest(&Config{SafeSearchEnabled: true}, nil) d = NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Destroy() defer d.Close()
// Let's lookup for safesearch domain // Let's lookup for safesearch domain
safeDomain, ok := d.SafeSearchDomain(domain) safeDomain, ok := d.SafeSearchDomain(domain)
@ -352,7 +352,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
func TestParentalControl(t *testing.T) { func TestParentalControl(t *testing.T) {
d := NewForTest(&Config{ParentalEnabled: true}, nil) d := NewForTest(&Config{ParentalEnabled: true}, nil)
defer d.Destroy() defer d.Close()
d.ParentalSensitivity = 3 d.ParentalSensitivity = 3
d.checkMatch(t, "pornhub.com") d.checkMatch(t, "pornhub.com")
d.checkMatch(t, "pornhub.com") d.checkMatch(t, "pornhub.com")
@ -435,7 +435,7 @@ func TestMatching(t *testing.T) {
filters := make(map[int]string) filters := make(map[int]string)
filters[0] = test.rules filters[0] = test.rules
d := NewForTest(nil, filters) d := NewForTest(nil, filters)
defer d.Destroy() defer d.Close()
ret, err := d.CheckHost(test.hostname, dns.TypeA, &setts) ret, err := d.CheckHost(test.hostname, dns.TypeA, &setts)
if err != nil { if err != nil {
@ -472,7 +472,7 @@ func TestClientSettings(t *testing.T) {
filters := make(map[int]string) filters := make(map[int]string)
filters[0] = "||example.org^\n" filters[0] = "||example.org^\n"
d := NewForTest(&Config{ParentalEnabled: true, SafeBrowsingEnabled: false}, filters) d := NewForTest(&Config{ParentalEnabled: true, SafeBrowsingEnabled: false}, filters)
defer d.Destroy() defer d.Close()
d.ParentalSensitivity = 3 d.ParentalSensitivity = 3
// no client settings: // no client settings:
@ -529,7 +529,7 @@ func TestClientSettings(t *testing.T) {
func BenchmarkSafeBrowsing(b *testing.B) { func BenchmarkSafeBrowsing(b *testing.B) {
d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil)
defer d.Destroy() defer d.Close()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
hostname := "wmconvirus.narod.ru" hostname := "wmconvirus.narod.ru"
ret, err := d.CheckHost(hostname, dns.TypeA, &setts) ret, err := d.CheckHost(hostname, dns.TypeA, &setts)
@ -544,7 +544,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
func BenchmarkSafeBrowsingParallel(b *testing.B) { func BenchmarkSafeBrowsingParallel(b *testing.B) {
d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil)
defer d.Destroy() defer d.Close()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
hostname := "wmconvirus.narod.ru" hostname := "wmconvirus.narod.ru"
@ -561,7 +561,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
func BenchmarkSafeSearch(b *testing.B) { func BenchmarkSafeSearch(b *testing.B) {
d := NewForTest(&Config{SafeSearchEnabled: true}, nil) d := NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Destroy() defer d.Close()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
val, ok := d.SafeSearchDomain("www.google.com") val, ok := d.SafeSearchDomain("www.google.com")
if !ok { if !ok {
@ -575,7 +575,7 @@ func BenchmarkSafeSearch(b *testing.B) {
func BenchmarkSafeSearchParallel(b *testing.B) { func BenchmarkSafeSearchParallel(b *testing.B) {
d := NewForTest(&Config{SafeSearchEnabled: true}, nil) d := NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Destroy() defer d.Close()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
val, ok := d.SafeSearchDomain("www.google.com") val, ok := d.SafeSearchDomain("www.google.com")

93
dnsfilter/rewrites.go Normal file
View File

@ -0,0 +1,93 @@
// DNS Rewrites
package dnsfilter
import (
"encoding/json"
"net/http"
"github.com/AdguardTeam/golibs/log"
)
type rewriteEntryJSON struct {
Domain string `json:"domain"`
Answer string `json:"answer"`
}
func (d *Dnsfilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
arr := []*rewriteEntryJSON{}
d.confLock.Lock()
for _, ent := range d.Config.Rewrites {
jsent := rewriteEntryJSON{
Domain: ent.Domain,
Answer: ent.Answer,
}
arr = append(arr, &jsent)
}
d.confLock.Unlock()
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(arr)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
return
}
}
func (d *Dnsfilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
jsent := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&jsent)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
ent := RewriteEntry{
Domain: jsent.Domain,
Answer: jsent.Answer,
}
d.confLock.Lock()
d.Config.Rewrites = append(d.Config.Rewrites, ent)
d.confLock.Unlock()
log.Debug("Rewrites: added element: %s -> %s [%d]",
ent.Domain, ent.Answer, len(d.Config.Rewrites))
d.Config.ConfigModified()
}
func (d *Dnsfilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) {
jsent := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&jsent)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
entDel := RewriteEntry{
Domain: jsent.Domain,
Answer: jsent.Answer,
}
arr := []RewriteEntry{}
d.confLock.Lock()
for _, ent := range d.Config.Rewrites {
if ent == entDel {
log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer)
continue
}
arr = append(arr, ent)
}
d.Config.Rewrites = arr
d.confLock.Unlock()
d.Config.ConfigModified()
}
func (d *Dnsfilter) registerRewritesHandlers() {
d.Config.HTTPRegister("GET", "/control/rewrite/list", d.handleRewriteList)
d.Config.HTTPRegister("POST", "/control/rewrite/add", d.handleRewriteAdd)
d.Config.HTTPRegister("POST", "/control/rewrite/delete", d.handleRewriteDelete)
}

179
dnsfilter/security.go Normal file
View File

@ -0,0 +1,179 @@
// Parental Control, Safe Browsing, Safe Search
package dnsfilter
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"github.com/AdguardTeam/golibs/log"
)
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("DNSFilter: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
func (d *Dnsfilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
d.Config.SafeBrowsingEnabled = true
d.Config.ConfigModified()
}
func (d *Dnsfilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) {
d.Config.SafeBrowsingEnabled = false
d.Config.ConfigModified()
}
func (d *Dnsfilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": d.Config.SafeBrowsingEnabled,
}
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
func parseParametersFromBody(r io.Reader) (map[string]string, error) {
parameters := map[string]string{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if len(line) == 0 {
// skip empty lines
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
return parameters, errors.New("Got invalid request body")
}
parameters[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
}
return parameters, nil
}
func (d *Dnsfilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body)
if err != nil {
httpError(r, w, http.StatusBadRequest, "failed to parse parameters from body: %s", err)
return
}
sensitivity, ok := parameters["sensitivity"]
if !ok {
http.Error(w, "Sensitivity parameter was not specified", 400)
return
}
switch sensitivity {
case "3":
break
case "EARLY_CHILDHOOD":
sensitivity = "3"
case "10":
break
case "YOUNG":
sensitivity = "10"
case "13":
break
case "TEEN":
sensitivity = "13"
case "17":
break
case "MATURE":
sensitivity = "17"
default:
http.Error(w, "Sensitivity must be set to valid value", 400)
return
}
i, err := strconv.Atoi(sensitivity)
if err != nil {
http.Error(w, "Sensitivity must be set to valid value", 400)
return
}
d.Config.ParentalSensitivity = i
d.Config.ParentalEnabled = true
d.Config.ConfigModified()
}
func (d *Dnsfilter) handleParentalDisable(w http.ResponseWriter, r *http.Request) {
d.Config.ParentalEnabled = false
d.Config.ConfigModified()
}
func (d *Dnsfilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": d.Config.ParentalEnabled,
}
if d.Config.ParentalEnabled {
data["sensitivity"] = d.Config.ParentalSensitivity
}
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
func (d *Dnsfilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
d.Config.SafeSearchEnabled = true
d.Config.ConfigModified()
}
func (d *Dnsfilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
d.Config.SafeSearchEnabled = false
d.Config.ConfigModified()
}
func (d *Dnsfilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": d.Config.SafeSearchEnabled,
}
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
func (d *Dnsfilter) registerSecurityHandlers() {
d.Config.HTTPRegister("POST", "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
d.Config.HTTPRegister("POST", "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
d.Config.HTTPRegister("GET", "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
d.Config.HTTPRegister("POST", "/control/parental/enable", d.handleParentalEnable)
d.Config.HTTPRegister("POST", "/control/parental/disable", d.handleParentalDisable)
d.Config.HTTPRegister("GET", "/control/parental/status", d.handleParentalStatus)
d.Config.HTTPRegister("POST", "/control/safesearch/enable", d.handleSafeSearchEnable)
d.Config.HTTPRegister("POST", "/control/safesearch/disable", d.handleSafeSearchDisable)
d.Config.HTTPRegister("GET", "/control/safesearch/status", d.handleSafeSearchStatus)
}

View File

@ -3,7 +3,6 @@ package dnsforward
import ( import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"net" "net"
"net/http" "net/http"
"strings" "strings"
@ -44,12 +43,6 @@ type Server struct {
queryLog querylog.QueryLog // Query log instance queryLog querylog.QueryLog // Query log instance
stats stats.Stats stats stats.Stats
// How many times the server was started
// While creating a dnsfilter object,
// we use this value to set s.dnsFilter property only with the most recent settings.
startCounter uint32
dnsfilterCreatorChan chan dnsfilterCreatorParams
AllowedClients map[string]bool // IP addresses of whitelist clients AllowedClients map[string]bool // IP addresses of whitelist clients
DisallowedClients map[string]bool // IP addresses of clients that should be blocked DisallowedClients map[string]bool // IP addresses of clients that should be blocked
AllowedClientsIPNet []net.IPNet // CIDRs of whitelist clients AllowedClientsIPNet []net.IPNet // CIDRs of whitelist clients
@ -60,15 +53,11 @@ type Server struct {
conf ServerConfig conf ServerConfig
} }
type dnsfilterCreatorParams struct {
conf dnsfilter.Config
filters map[int]string
}
// NewServer creates a new instance of the dnsforward.Server // NewServer creates a new instance of the dnsforward.Server
// Note: this function must be called only once // Note: this function must be called only once
func NewServer(stats stats.Stats, queryLog querylog.QueryLog) *Server { func NewServer(dnsFilter *dnsfilter.Dnsfilter, stats stats.Stats, queryLog querylog.QueryLog) *Server {
s := &Server{} s := &Server{}
s.dnsFilter = dnsFilter
s.stats = stats s.stats = stats
s.queryLog = queryLog s.queryLog = queryLog
return s return s
@ -76,6 +65,7 @@ func NewServer(stats stats.Stats, queryLog querylog.QueryLog) *Server {
func (s *Server) Close() { func (s *Server) Close() {
s.Lock() s.Lock()
s.dnsFilter = nil
s.stats = nil s.stats = nil
s.queryLog = nil s.queryLog = nil
s.Unlock() s.Unlock()
@ -84,11 +74,8 @@ func (s *Server) Close() {
// FilteringConfig represents the DNS filtering configuration of AdGuard Home // FilteringConfig represents the DNS filtering configuration of AdGuard Home
// The zero FilteringConfig is empty and ready for use. // The zero FilteringConfig is empty and ready for use.
type FilteringConfig struct { type FilteringConfig struct {
// Create dnsfilter asynchronously. // Filtering callback function
// Requests won't be filtered until dnsfilter is created. FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
// If "restart" command is received while we're creating an old dnsfilter object,
// we delay creation of the new object until the old one is created.
AsyncStartup bool `yaml:"-"`
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
@ -117,7 +104,8 @@ type FilteringConfig struct {
BlockedServices []string `yaml:"blocked_services"` BlockedServices []string `yaml:"blocked_services"`
CacheSize uint `yaml:"cache_size"` // DNS cache size (in bytes) CacheSize uint `yaml:"cache_size"` // DNS cache size (in bytes)
dnsfilter.Config `yaml:",inline"`
DnsfilterConf dnsfilter.Config `yaml:",inline"`
} }
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS // TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
@ -140,7 +128,6 @@ type ServerConfig struct {
TCPListenAddr *net.TCPAddr // TCP listen address TCPListenAddr *net.TCPAddr // TCP listen address
Upstreams []upstream.Upstream // Configured upstreams Upstreams []upstream.Upstream // Configured upstreams
DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams
Filters []dnsfilter.Filter // A list of filters to use
OnDNSRequest func(d *proxy.DNSContext) OnDNSRequest func(d *proxy.DNSContext)
FilteringConfig FilteringConfig
@ -204,13 +191,18 @@ func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []strin
// startInternal starts without locking // startInternal starts without locking
func (s *Server) startInternal(config *ServerConfig) error { func (s *Server) startInternal(config *ServerConfig) error {
if s.dnsFilter != nil || s.dnsProxy != nil { if s.dnsProxy != nil {
return errors.New("DNS server is already started") return errors.New("DNS server is already started")
} }
err := s.initDNSFilter(config) if config != nil {
if err != nil { s.conf = *config
return err }
if len(s.conf.ParentalBlockHost) == 0 {
s.conf.ParentalBlockHost = parentalBlockHost
}
if len(s.conf.SafeBrowsingBlockHost) == 0 {
s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost
} }
proxyConfig := proxy.Config{ proxyConfig := proxy.Config{
@ -228,7 +220,7 @@ func (s *Server) startInternal(config *ServerConfig) error {
AllServers: s.conf.AllServers, AllServers: s.conf.AllServers,
} }
err = processIPCIDRArray(&s.AllowedClients, &s.AllowedClientsIPNet, s.conf.AllowedClients) err := processIPCIDRArray(&s.AllowedClients, &s.AllowedClientsIPNet, s.conf.AllowedClients)
if err != nil { if err != nil {
return err return err
} }
@ -269,97 +261,6 @@ func (s *Server) startInternal(config *ServerConfig) error {
return s.dnsProxy.Start() return s.dnsProxy.Start()
} }
// Initializes the DNS filter
func (s *Server) initDNSFilter(config *ServerConfig) error {
if config != nil {
s.conf = *config
}
var filters map[int]string
filters = nil
if s.conf.FilteringEnabled {
filters = make(map[int]string)
for _, f := range s.conf.Filters {
if f.ID == 0 {
filters[int(f.ID)] = string(f.Data)
} else {
filters[int(f.ID)] = f.FilePath
}
}
}
if len(s.conf.ParentalBlockHost) == 0 {
s.conf.ParentalBlockHost = parentalBlockHost
}
if len(s.conf.SafeBrowsingBlockHost) == 0 {
s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost
}
if s.conf.AsyncStartup {
params := dnsfilterCreatorParams{
conf: s.conf.Config,
filters: filters,
}
s.startCounter++
if s.startCounter == 1 {
s.dnsfilterCreatorChan = make(chan dnsfilterCreatorParams, 1)
go s.dnsfilterCreator()
}
// remove all pending tasks
stop := false
for !stop {
select {
case <-s.dnsfilterCreatorChan:
//
default:
stop = true
}
}
s.dnsfilterCreatorChan <- params
} else {
log.Debug("creating dnsfilter...")
f := dnsfilter.New(&s.conf.Config, filters)
if f == nil {
return fmt.Errorf("could not initialize dnsfilter")
}
log.Debug("created dnsfilter")
s.dnsFilter = f
}
return nil
}
func (s *Server) dnsfilterCreator() {
for {
params := <-s.dnsfilterCreatorChan
s.Lock()
counter := s.startCounter
s.Unlock()
log.Debug("creating dnsfilter...")
f := dnsfilter.New(&params.conf, params.filters)
if f == nil {
log.Error("could not initialize dnsfilter")
continue
}
set := false
s.Lock()
if counter == s.startCounter {
s.dnsFilter = f
set = true
}
s.Unlock()
if set {
log.Debug("created and activated dnsfilter")
} else {
log.Debug("created dnsfilter")
}
}
}
// Stop stops the DNS server // Stop stops the DNS server
func (s *Server) Stop() error { func (s *Server) Stop() error {
s.Lock() s.Lock()
@ -377,11 +278,6 @@ func (s *Server) stopInternal() error {
} }
} }
if s.dnsFilter != nil {
s.dnsFilter.Destroy()
s.dnsFilter = nil
}
return nil return nil
} }
@ -607,33 +503,24 @@ func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dns
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered // filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) { func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) {
var res dnsfilter.Result
req := d.Req
host := strings.TrimSuffix(req.Question[0].Name, ".")
dnsFilter := s.dnsFilter
if !s.conf.ProtectionEnabled || s.dnsFilter == nil { if !s.conf.ProtectionEnabled || s.dnsFilter == nil {
return &dnsfilter.Result{}, nil return &dnsfilter.Result{}, nil
} }
var err error
clientAddr := "" clientAddr := ""
if d.Addr != nil { if d.Addr != nil {
clientAddr, _, _ = net.SplitHostPort(d.Addr.String()) clientAddr, _, _ = net.SplitHostPort(d.Addr.String())
} }
var setts dnsfilter.RequestFilteringSettings setts := s.dnsFilter.GetConfig()
setts.FilteringEnabled = true setts.FilteringEnabled = true
setts.SafeSearchEnabled = s.conf.SafeSearchEnabled
setts.SafeBrowsingEnabled = s.conf.SafeBrowsingEnabled
setts.ParentalEnabled = s.conf.ParentalEnabled
if s.conf.FilterHandler != nil { if s.conf.FilterHandler != nil {
s.conf.FilterHandler(clientAddr, &setts) s.conf.FilterHandler(clientAddr, &setts)
} }
res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts) req := d.Req
host := strings.TrimSuffix(req.Question[0].Name, ".")
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts)
if err != nil { if err != nil {
// Return immediately if there's an error // Return immediately if there's an error
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host) return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)

View File

@ -148,7 +148,6 @@ func TestServerRace(t *testing.T) {
func TestSafeSearch(t *testing.T) { func TestSafeSearch(t *testing.T) {
s := createTestServer(t) s := createTestServer(t)
s.conf.SafeSearchEnabled = true
err := s.Start(nil) err := s.Start(nil)
if err != nil { if err != nil {
t.Fatalf("Failed to start server: %s", err) t.Fatalf("Failed to start server: %s", err)
@ -376,23 +375,24 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
} }
func createTestServer(t *testing.T) *Server { func createTestServer(t *testing.T) *Server {
s := NewServer(nil, nil) rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n"
filters := map[int]string{}
filters[0] = rules
c := dnsfilter.Config{}
c.SafeBrowsingEnabled = true
c.SafeBrowsingCacheSize = 1000
c.SafeSearchEnabled = true
c.SafeSearchCacheSize = 1000
c.ParentalCacheSize = 1000
c.CacheTime = 30
f := dnsfilter.New(&c, filters)
s := NewServer(f, nil, nil)
s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} s.conf.UDPListenAddr = &net.UDPAddr{Port: 0}
s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
s.conf.FilteringConfig.FilteringEnabled = true s.conf.FilteringConfig.FilteringEnabled = true
s.conf.FilteringConfig.ProtectionEnabled = true s.conf.FilteringConfig.ProtectionEnabled = true
s.conf.FilteringConfig.SafeBrowsingEnabled = true
s.conf.Filters = make([]dnsfilter.Filter, 0)
s.conf.SafeBrowsingCacheSize = 1000
s.conf.SafeSearchCacheSize = 1000
s.conf.ParentalCacheSize = 1000
s.conf.CacheTime = 30
rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n"
filter := dnsfilter.Filter{ID: 0, Data: []byte(rules)}
s.conf.Filters = append(s.conf.Filters, filter)
return s return s
} }

View File

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/AdGuardHome/stats"
@ -71,7 +72,6 @@ type configuration struct {
client *http.Client client *http.Client
stats stats.Stats // statistics module stats stats.Stats // statistics module
queryLog querylog.QueryLog // query log module queryLog querylog.QueryLog // query log module
filteringStarted bool // TRUE if filtering module is started
auth *Auth // HTTP authentication module auth *Auth // HTTP authentication module
// cached version.json to avoid hammering github.io for each page reload // cached version.json to avoid hammering github.io for each page reload
@ -79,6 +79,7 @@ type configuration struct {
versionCheckLastTime time.Time versionCheckLastTime time.Time
dnsctx dnsContext dnsctx dnsContext
dnsFilter *dnsfilter.Dnsfilter
dnsServer *dnsforward.Server dnsServer *dnsforward.Server
dhcpServer dhcpd.Server dhcpServer dhcpd.Server
httpServer *http.Server httpServer *http.Server
@ -217,10 +218,10 @@ func initConfig() {
} }
config.DNS.CacheSize = 4 * 1024 * 1024 config.DNS.CacheSize = 4 * 1024 * 1024
config.DNS.SafeBrowsingCacheSize = 1 * 1024 * 1024 config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024
config.DNS.SafeSearchCacheSize = 1 * 1024 * 1024 config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024
config.DNS.ParentalCacheSize = 1 * 1024 * 1024 config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024
config.DNS.CacheTime = 30 config.DNS.DnsfilterConf.CacheTime = 30
config.Filters = defaultFilters() config.Filters = defaultFilters()
} }
@ -367,6 +368,12 @@ func (c *configuration) write() error {
config.DNS.QueryLogInterval = dc.Interval config.DNS.QueryLogInterval = dc.Interval
} }
if config.dnsFilter != nil {
c := dnsfilter.Config{}
config.dnsFilter.WriteDiskConfig(&c)
config.DNS.DnsfilterConf = c
}
configFile := config.getConfigFilename() configFile := config.getConfigFilename()
log.Debug("Writing YAML file: %s", configFile) log.Debug("Writing YAML file: %s", configFile)
yamlText, err := yaml.Marshal(&config) yamlText, err := yaml.Marshal(&config)

View File

@ -377,142 +377,6 @@ func checkDNS(input string, bootstrap []string) error {
return nil return nil
} }
// ------------
// safebrowsing
// ------------
func handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
config.DNS.SafeBrowsingEnabled = true
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) {
config.DNS.SafeBrowsingEnabled = false
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": config.DNS.SafeBrowsingEnabled,
}
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
// --------
// parental
// --------
func handleParentalEnable(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "failed to parse parameters from body: %s", err)
return
}
sensitivity, ok := parameters["sensitivity"]
if !ok {
http.Error(w, "Sensitivity parameter was not specified", 400)
return
}
switch sensitivity {
case "3":
break
case "EARLY_CHILDHOOD":
sensitivity = "3"
case "10":
break
case "YOUNG":
sensitivity = "10"
case "13":
break
case "TEEN":
sensitivity = "13"
case "17":
break
case "MATURE":
sensitivity = "17"
default:
http.Error(w, "Sensitivity must be set to valid value", 400)
return
}
i, err := strconv.Atoi(sensitivity)
if err != nil {
http.Error(w, "Sensitivity must be set to valid value", 400)
return
}
config.DNS.ParentalSensitivity = i
config.DNS.ParentalEnabled = true
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleParentalDisable(w http.ResponseWriter, r *http.Request) {
config.DNS.ParentalEnabled = false
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleParentalStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": config.DNS.ParentalEnabled,
}
if config.DNS.ParentalEnabled {
data["sensitivity"] = config.DNS.ParentalSensitivity
}
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
// ------------
// safebrowsing
// ------------
func handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
config.DNS.SafeSearchEnabled = true
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
config.DNS.SafeSearchEnabled = false
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": config.DNS.SafeSearchEnabled,
}
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
}
// -------------- // --------------
// DNS-over-HTTPS // DNS-over-HTTPS
// -------------- // --------------
@ -543,15 +407,6 @@ func registerControlHandlers() {
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage) httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON))) http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON)))
httpRegister(http.MethodPost, "/control/update", handleUpdate) httpRegister(http.MethodPost, "/control/update", handleUpdate)
httpRegister(http.MethodPost, "/control/safebrowsing/enable", handleSafeBrowsingEnable)
httpRegister(http.MethodPost, "/control/safebrowsing/disable", handleSafeBrowsingDisable)
httpRegister(http.MethodGet, "/control/safebrowsing/status", handleSafeBrowsingStatus)
httpRegister(http.MethodPost, "/control/parental/enable", handleParentalEnable)
httpRegister(http.MethodPost, "/control/parental/disable", handleParentalDisable)
httpRegister(http.MethodGet, "/control/parental/status", handleParentalStatus)
httpRegister(http.MethodPost, "/control/safesearch/enable", handleSafeSearchEnable)
httpRegister(http.MethodPost, "/control/safesearch/disable", handleSafeSearchDisable)
httpRegister(http.MethodGet, "/control/safesearch/status", handleSafeSearchStatus)
httpRegister(http.MethodGet, "/control/dhcp/status", handleDHCPStatus) httpRegister(http.MethodGet, "/control/dhcp/status", handleDHCPStatus)
httpRegister(http.MethodGet, "/control/dhcp/interfaces", handleDHCPInterfaces) httpRegister(http.MethodGet, "/control/dhcp/interfaces", handleDHCPInterfaces)
httpRegister(http.MethodPost, "/control/dhcp/set_config", handleDHCPSetConfig) httpRegister(http.MethodPost, "/control/dhcp/set_config", handleDHCPSetConfig)
@ -565,7 +420,6 @@ func registerControlHandlers() {
RegisterFilteringHandlers() RegisterFilteringHandlers()
RegisterTLSHandlers() RegisterTLSHandlers()
RegisterClientsHandlers() RegisterClientsHandlers()
registerRewritesHandlers()
RegisterBlockedServicesHandlers() RegisterBlockedServicesHandlers()
RegisterAuthHandlers() RegisterAuthHandlers()

View File

@ -86,17 +86,8 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
return return
} }
err = writeAllConfigs() onConfigModified()
if err != nil { enableFilters(true)
httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err)
return
}
err = reconfigureDNSServer()
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't reconfigure the DNS server: %s", err)
return
}
_, err = fmt.Fprintf(w, "OK %d rules\n", f.RulesCount) _, err = fmt.Fprintf(w, "OK %d rules\n", f.RulesCount)
if err != nil { if err != nil {
@ -121,32 +112,28 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
return return
} }
// Stop DNS server:
// we close urlfilter object which in turn closes file descriptors to filter files.
// Otherwise, Windows won't allow us to remove the file which is being currently used.
_ = config.dnsServer.Stop()
// go through each element and delete if url matches // go through each element and delete if url matches
config.Lock() config.Lock()
newFilters := config.Filters[:0] newFilters := []filter{}
for _, filter := range config.Filters { for _, filter := range config.Filters {
if filter.URL != req.URL { if filter.URL != req.URL {
newFilters = append(newFilters, filter) newFilters = append(newFilters, filter)
} else { } else {
// Remove the filter file err := os.Rename(filter.Path(), filter.Path()+".old")
err := os.Remove(filter.Path()) if err != nil {
if err != nil && !os.IsNotExist(err) { log.Error("os.Rename: %s: %s", filter.Path(), err)
config.Unlock()
httpError(w, http.StatusInternalServerError, "Couldn't remove the filter file: %s", err)
return
} }
log.Debug("os.Remove(%s)", filter.Path())
} }
} }
// Update the configuration after removing filter files // Update the configuration after removing filter files
config.Filters = newFilters config.Filters = newFilters
config.Unlock() config.Unlock()
httpUpdateConfigReloadDNSReturnOK(w, r)
onConfigModified()
enableFilters(true)
// Note: the old files "filter.txt.old" aren't deleted - it's not really necessary,
// but will require the additional code to run after enableFilters() is finished: i.e. complicated
} }
type filterURLJSON struct { type filterURLJSON struct {
@ -173,7 +160,8 @@ func handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
return return
} }
httpUpdateConfigReloadDNSReturnOK(w, r) onConfigModified()
enableFilters(true)
} }
func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
@ -184,12 +172,13 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
} }
config.UserRules = strings.Split(string(body), "\n") config.UserRules = strings.Split(string(body), "\n")
httpUpdateConfigReloadDNSReturnOK(w, r) _ = writeAllConfigs()
enableFilters(true)
} }
func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
updated := refreshFiltersIfNecessary(true) beginRefreshFilters()
fmt.Fprintf(w, "OK %d filters updated\n", updated) fmt.Fprintf(w, "OK 0 filters updated\n")
} }
type filterJSON struct { type filterJSON struct {
@ -260,9 +249,8 @@ func handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
config.DNS.FilteringEnabled = req.Enabled config.DNS.FilteringEnabled = req.Enabled
config.DNS.FiltersUpdateIntervalHours = req.Interval config.DNS.FiltersUpdateIntervalHours = req.Interval
httpUpdateConfigReloadDNSReturnOK(w, r) onConfigModified()
enableFilters(true)
returnOK(w)
} }
// RegisterFilteringHandlers - register handlers // RegisterFilteringHandlers - register handlers

View File

@ -55,7 +55,18 @@ func initDNSServer() {
HTTPRegister: httpRegister, HTTPRegister: httpRegister,
} }
config.queryLog = querylog.New(conf) config.queryLog = querylog.New(conf)
config.dnsServer = dnsforward.NewServer(config.stats, config.queryLog)
filterConf := config.DNS.DnsfilterConf
bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" {
bindhost = "127.0.0.1"
}
filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
filterConf.ConfigModified = onConfigModified
filterConf.HTTPRegister = httpRegister
config.dnsFilter = dnsfilter.New(&filterConf, nil)
config.dnsServer = dnsforward.NewServer(config.dnsFilter, config.stats, config.queryLog)
sessFilename := filepath.Join(baseDir, "sessions.db") sessFilename := filepath.Join(baseDir, "sessions.db")
config.auth = InitAuth(sessFilename, config.Users) config.auth = InitAuth(sessFilename, config.Users)
@ -159,34 +170,11 @@ func onDNSRequest(d *proxy.DNSContext) {
} }
func generateServerConfig() (dnsforward.ServerConfig, error) { func generateServerConfig() (dnsforward.ServerConfig, error) {
filters := []dnsfilter.Filter{}
userFilter := userFilter()
filters = append(filters, dnsfilter.Filter{
ID: userFilter.ID,
Data: userFilter.Data,
})
for _, filter := range config.Filters {
if !filter.Enabled {
continue
}
filters = append(filters, dnsfilter.Filter{
ID: filter.ID,
FilePath: filter.Path(),
})
}
newconfig := dnsforward.ServerConfig{ newconfig := dnsforward.ServerConfig{
UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port}, UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port}, TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port},
FilteringConfig: config.DNS.FilteringConfig, FilteringConfig: config.DNS.FilteringConfig,
Filters: filters,
} }
newconfig.AsyncStartup = true
bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" {
bindhost = "127.0.0.1"
}
newconfig.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
if config.TLS.Enabled { if config.TLS.Enabled {
newconfig.TLSConfig = config.TLS.TLSConfig newconfig.TLSConfig = config.TLS.TLSConfig
@ -242,20 +230,18 @@ func startDNSServer() error {
return fmt.Errorf("unable to start forwarding DNS server: Already running") return fmt.Errorf("unable to start forwarding DNS server: Already running")
} }
enableFilters(false)
newconfig, err := generateServerConfig() newconfig, err := generateServerConfig()
if err != nil { if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server") return errorx.Decorate(err, "Couldn't start forwarding DNS server")
} }
err = config.dnsServer.Start(&newconfig) err = config.dnsServer.Start(&newconfig)
if err != nil { if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server") return errorx.Decorate(err, "Couldn't start forwarding DNS server")
} }
if !config.filteringStarted {
config.filteringStarted = true
startRefreshFilters()
}
return nil return nil
} }
@ -285,6 +271,9 @@ func stopDNSServer() error {
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them // DNS forward module must be closed BEFORE stats or queryLog because it depends on them
config.dnsServer.Close() config.dnsServer.Close()
config.dnsFilter.Close()
config.dnsFilter = nil
config.stats.Close() config.stats.Close()
config.stats = nil config.stats = nil

View File

@ -1,104 +0,0 @@
package home
import (
"encoding/json"
"net/http"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/golibs/log"
)
type rewriteEntryJSON struct {
Domain string `json:"domain"`
Answer string `json:"answer"`
}
func handleRewriteList(w http.ResponseWriter, r *http.Request) {
arr := []*rewriteEntryJSON{}
config.RLock()
for _, ent := range config.DNS.Rewrites {
jsent := rewriteEntryJSON{
Domain: ent.Domain,
Answer: ent.Answer,
}
arr = append(arr, &jsent)
}
config.RUnlock()
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(arr)
if err != nil {
httpError(w, http.StatusInternalServerError, "json.Encode: %s", err)
return
}
}
func handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
jsent := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&jsent)
if err != nil {
httpError(w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
ent := dnsfilter.RewriteEntry{
Domain: jsent.Domain,
Answer: jsent.Answer,
}
config.Lock()
config.DNS.Rewrites = append(config.DNS.Rewrites, ent)
config.Unlock()
log.Debug("Rewrites: added element: %s -> %s [%d]",
ent.Domain, ent.Answer, len(config.DNS.Rewrites))
err = writeAllConfigsAndReloadDNS()
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
returnOK(w)
}
func handleRewriteDelete(w http.ResponseWriter, r *http.Request) {
jsent := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&jsent)
if err != nil {
httpError(w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
entDel := dnsfilter.RewriteEntry{
Domain: jsent.Domain,
Answer: jsent.Answer,
}
arr := []dnsfilter.RewriteEntry{}
config.Lock()
for _, ent := range config.DNS.Rewrites {
if ent == entDel {
log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer)
continue
}
arr = append(arr, ent)
}
config.DNS.Rewrites = arr
config.Unlock()
err = writeAllConfigsAndReloadDNS()
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
}
returnOK(w)
}
func registerRewritesHandlers() {
httpRegister(http.MethodGet, "/control/rewrite/list", handleRewriteList)
httpRegister(http.MethodPost, "/control/rewrite/add", handleRewriteAdd)
httpRegister(http.MethodPost, "/control/rewrite/delete", handleRewriteDelete)
}

View File

@ -19,18 +19,13 @@ import (
var ( var (
nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`) filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
forceRefresh bool
) )
func initFiltering() { func initFiltering() {
loadFilters() loadFilters()
deduplicateFilters() deduplicateFilters()
updateUniqueFilterID(config.Filters) updateUniqueFilterID(config.Filters)
}
func startRefreshFilters() {
go func() {
_ = refreshFiltersIfNecessary(false)
}()
go periodicallyRefreshFilters() go periodicallyRefreshFilters()
} }
@ -180,14 +175,25 @@ func assignUniqueFilterID() int64 {
// Sets up a timer that will be checking for filters updates periodically // Sets up a timer that will be checking for filters updates periodically
func periodicallyRefreshFilters() { func periodicallyRefreshFilters() {
nextRefresh := int64(0)
for { for {
time.Sleep(1 * time.Hour) if forceRefresh {
if config.DNS.FiltersUpdateIntervalHours == 0 { _ = refreshFiltersIfNecessary(true)
continue forceRefresh = false
} }
refreshFiltersIfNecessary(false) if config.DNS.FiltersUpdateIntervalHours != 0 && nextRefresh <= time.Now().Unix() {
_ = refreshFiltersIfNecessary(false)
nextRefresh = time.Now().Add(1 * time.Hour).Unix()
} }
time.Sleep(1 * time.Second)
}
}
// Schedule the procedure to refresh filters
func beginRefreshFilters() {
forceRefresh = true
log.Debug("Filters: schedule update")
} }
// Checks filters updates if necessary // Checks filters updates if necessary
@ -196,16 +202,16 @@ func periodicallyRefreshFilters() {
// Algorithm: // Algorithm:
// . Get the list of filters to be updated // . Get the list of filters to be updated
// . For each filter run the download and checksum check operation // . For each filter run the download and checksum check operation
// . Stop server
// . For each filter: // . For each filter:
// . If filter data hasn't changed, just set new update time on file // . If filter data hasn't changed, just set new update time on file
// . If filter data has changed, save it on disk // . If filter data has changed: rename the old file, store the new data on disk
// . Apply changes to the current configuration // . Pass new filters to dnsfilter object
// . Start server
func refreshFiltersIfNecessary(force bool) int { func refreshFiltersIfNecessary(force bool) int {
var updateFilters []filter var updateFilters []filter
var updateFlags []bool // 'true' if filter data has changed var updateFlags []bool // 'true' if filter data has changed
log.Debug("Filters: updating...")
now := time.Now() now := time.Now()
config.RLock() config.RLock()
for i := range config.Filters { for i := range config.Filters {
@ -229,7 +235,6 @@ func refreshFiltersIfNecessary(force bool) int {
} }
config.RUnlock() config.RUnlock()
updateCount := 0
for i := range updateFilters { for i := range updateFilters {
uf := &updateFilters[i] uf := &updateFilters[i]
updated, err := uf.update() updated, err := uf.update()
@ -239,24 +244,14 @@ func refreshFiltersIfNecessary(force bool) int {
continue continue
} }
uf.LastUpdated = now uf.LastUpdated = now
if updated {
updateCount++
}
} }
stopped := false updateCount := 0
if updateCount != 0 {
_ = config.dnsServer.Stop()
stopped = true
}
updateCount = 0
for i := range updateFilters { for i := range updateFilters {
uf := &updateFilters[i] uf := &updateFilters[i]
updated := updateFlags[i] updated := updateFlags[i]
if updated { if updated {
// Saving it to the filters dir now err := uf.saveAndBackupOld()
err := uf.save()
if err != nil { if err != nil {
log.Printf("Failed to save the updated filter %d: %s", uf.ID, err) log.Printf("Failed to save the updated filter %d: %s", uf.ID, err)
continue continue
@ -290,12 +285,20 @@ func refreshFiltersIfNecessary(force bool) int {
config.Unlock() config.Unlock()
} }
if stopped { if updateCount != 0 {
err := reconfigureDNSServer() enableFilters(false)
if err != nil {
log.Error("cannot reconfigure DNS server with the new filters: %s", err) for i := range updateFilters {
uf := &updateFilters[i]
updated := updateFlags[i]
if !updated {
continue
}
_ = os.Remove(uf.Path() + ".old")
} }
} }
log.Debug("Filters: update finished")
return updateCount return updateCount
} }
@ -413,6 +416,12 @@ func (filter *filter) save() error {
return err return err
} }
func (filter *filter) saveAndBackupOld() error {
filterFilePath := filter.Path()
_ = os.Rename(filterFilePath, filterFilePath+".old")
return filter.save()
}
// loads filter contents from the file in dataDir // loads filter contents from the file in dataDir
func (filter *filter) load() error { func (filter *filter) load() error {
filterFilePath := filter.Path() filterFilePath := filter.Path()
@ -467,3 +476,23 @@ func (filter *filter) LastTimeUpdated() time.Time {
// filter file modified time // filter file modified time
return s.ModTime() return s.ModTime()
} }
func enableFilters(async bool) {
var filters map[int]string
if config.DNS.FilteringConfig.FilteringEnabled {
// convert array of filters
filters = make(map[int]string)
userFilter := userFilter()
filters[int(userFilter.ID)] = string(userFilter.Data)
for _, filter := range config.Filters {
if !filter.Enabled {
continue
}
filters[int(filter.ID)] = filter.Path()
}
}
_ = config.dnsFilter.SetFilters(filters, async)
}

View File

@ -1,12 +1,10 @@
package home package home
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -155,29 +153,6 @@ func postInstallHandler(handler http.Handler) http.Handler {
return &postInstallHandlerStruct{handler} return &postInstallHandlerStruct{handler}
} }
// -------------------------------------------------
// helper functions for parsing parameters from body
// -------------------------------------------------
func parseParametersFromBody(r io.Reader) (map[string]string, error) {
parameters := map[string]string{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if len(line) == 0 {
// skip empty lines
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
return parameters, errors.New("Got invalid request body")
}
parameters[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
}
return parameters, nil
}
// ------------------ // ------------------
// network interfaces // network interfaces
// ------------------ // ------------------

View File

@ -143,11 +143,12 @@ func run(args options) {
} }
initDNSServer() initDNSServer()
go func() {
err = startDNSServer() err = startDNSServer()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
}()
err = startDHCPServer() err = startDHCPServer()
if err != nil { if err != nil {