diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 14c12844..2902c75a 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -14,6 +14,7 @@ import ( "net/http" "os" "strings" + "sync" "sync/atomic" "time" @@ -66,7 +67,7 @@ type Config struct { UsePlainHTTP bool `yaml:"-"` // use plain HTTP for requests to parental and safe browsing servers SafeSearchEnabled bool `yaml:"safesearch_enabled"` SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` - ResolverAddress string // DNS server address + ResolverAddress string `yaml:"-"` // DNS server address SafeBrowsingCacheSize uint `yaml:"safebrowsing_cache_size"` // (in bytes) SafeSearchCacheSize uint `yaml:"safesearch_cache_size"` // (in bytes) @@ -75,13 +76,11 @@ type Config struct { Rewrites []RewriteEntry `yaml:"rewrites"` - // Filtering callback function - FilterHandler func(clientAddr string, settings *RequestFilteringSettings) `yaml:"-"` -} + // Called when the configuration is changed by HTTP request + ConfigModified func() `yaml:"-"` -type privateConfig struct { - parentalServer string // access via methods - safeBrowsingServer string // access via methods + // Register an HTTP handler + HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"` } // LookupStats store stats collected during safebrowsing or parental checks @@ -99,17 +98,30 @@ type Stats struct { Safesearch LookupStats } +// Parameters to pass to filters-initializer goroutine +type filtersInitializerParams struct { + filters map[int]string +} + // Dnsfilter holds added rules and performs hostname matches against the rules type Dnsfilter struct { rulesStorage *urlfilter.RuleStorage filteringEngine *urlfilter.DNSEngine + engineLock sync.RWMutex // HTTP lookups for safebrowsing and parental client http.Client // handle for http client -- single instance as recommended by docs transport *http.Transport // handle for http transport used by http client - Config // for direct access by library users, even a = assignment - privateConfig + parentalServer string // access via methods + safeBrowsingServer string // access via methods + + Config // for direct access by library users, even a = assignment + confLock sync.RWMutex + + // Channel for passing data to filters-initializer goroutine + filtersInitializerChan chan filtersInitializerParams + filtersInitializerLock sync.Mutex } // Filter represents a filter list @@ -119,8 +131,6 @@ type Filter struct { FilePath string `yaml:"-"` // Path to a filtering rules file } -//go:generate stringer -type=Reason - // Reason holds an enum detailing why it was filtered or not filtered type Reason int @@ -153,25 +163,99 @@ const ( ReasonRewrite ) +var reasonNames = []string{ + "NotFilteredNotFound", + "NotFilteredWhiteList", + "NotFilteredError", + + "FilteredBlackList", + "FilteredSafeBrowsing", + "FilteredParental", + "FilteredInvalid", + "FilteredSafeSearch", + "FilteredBlockedService", + + "Rewrite", +} + func (r Reason) String() string { - names := []string{ - "NotFilteredNotFound", - "NotFilteredWhiteList", - "NotFilteredError", - - "FilteredBlackList", - "FilteredSafeBrowsing", - "FilteredParental", - "FilteredInvalid", - "FilteredSafeSearch", - "FilteredBlockedService", - - "Rewrite", - } - if uint(r) >= uint(len(names)) { + if uint(r) >= uint(len(reasonNames)) { return "" } - return names[r] + return reasonNames[r] +} + +// GetConfig - get configuration +func (d *Dnsfilter) GetConfig() RequestFilteringSettings { + c := RequestFilteringSettings{} + // d.confLock.RLock() + c.SafeSearchEnabled = d.Config.SafeSearchEnabled + c.SafeBrowsingEnabled = d.Config.SafeBrowsingEnabled + c.ParentalEnabled = d.Config.ParentalEnabled + // d.confLock.RUnlock() + return c +} + +// WriteDiskConfig - write configuration +func (d *Dnsfilter) WriteDiskConfig(c *Config) { + *c = d.Config +} + +// SetFilters - set new filters (synchronously or asynchronously) +// When filters are set asynchronously, the old filters continue working until the new filters are ready. +// In this case the caller must ensure that the old filter files are intact. +func (d *Dnsfilter) SetFilters(filters map[int]string, async bool) error { + if async { + params := filtersInitializerParams{ + filters: filters, + } + + d.filtersInitializerLock.Lock() // prevent multiple writers from adding more than 1 task + // remove all pending tasks + stop := false + for !stop { + select { + case <-d.filtersInitializerChan: + // + default: + stop = true + } + } + + d.filtersInitializerChan <- params + d.filtersInitializerLock.Unlock() + return nil + } + + err := d.initFiltering(filters) + if err != nil { + log.Error("Can't initialize filtering subsystem: %s", err) + return err + } + + return nil +} + +// Starts initializing new filters by signal from channel +func (d *Dnsfilter) filtersInitializer() { + for { + params := <-d.filtersInitializerChan + err := d.initFiltering(params.filters) + if err != nil { + log.Error("Can't initialize filtering subsystem: %s", err) + continue + } + } +} + +// Close - close the object +func (d *Dnsfilter) Close() { + if d != nil && d.transport != nil { + d.transport.CloseIdleConnections() + } + if d.rulesStorage != nil { + d.rulesStorage.Close() + } } type dnsFilterContext struct { @@ -294,6 +378,9 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering func (d *Dnsfilter) processRewrites(host string, qtype uint16) Result { var res Result + d.confLock.RLock() + defer d.confLock.RUnlock() + for _, r := range d.Rewrites { if r.Domain != host { continue @@ -704,17 +791,28 @@ func (d *Dnsfilter) initFiltering(filters map[int]string) error { listArray = append(listArray, list) } - var err error - d.rulesStorage, err = urlfilter.NewRuleStorage(listArray) + rulesStorage, err := urlfilter.NewRuleStorage(listArray) if err != nil { return fmt.Errorf("urlfilter.NewRuleStorage(): %s", err) } - d.filteringEngine = urlfilter.NewDNSEngine(d.rulesStorage) + filteringEngine := urlfilter.NewDNSEngine(rulesStorage) + + d.engineLock.Lock() + if d.rulesStorage != nil { + d.rulesStorage.Close() + } + d.rulesStorage = rulesStorage + d.filteringEngine = filteringEngine + d.engineLock.Unlock() + log.Debug("initialized filtering engine") + return nil } // matchHost is a low-level way to check only if hostname is filtered by rules, skipping expensive safebrowsing and parental lookups func (d *Dnsfilter) matchHost(host string, qtype uint16) (Result, error) { + d.engineLock.RLock() + defer d.engineLock.RUnlock() if d.filteringEngine == nil { return Result{}, nil } @@ -926,27 +1024,21 @@ func New(c *Config, filters map[int]string) *Dnsfilter { err := d.initFiltering(filters) if err != nil { log.Error("Can't initialize filtering subsystem: %s", err) - d.Destroy() + d.Close() return nil } } + d.filtersInitializerChan = make(chan filtersInitializerParams, 1) + go d.filtersInitializer() + + if d.Config.HTTPRegister != nil { // for tests + d.registerSecurityHandlers() + d.registerRewritesHandlers() + } return d } -// Destroy is optional if you want to tidy up goroutines without waiting for them to die off -// right now it closes idle HTTP connections if there are any -func (d *Dnsfilter) Destroy() { - if d != nil && d.transport != nil { - d.transport.CloseIdleConnections() - } - - if d.rulesStorage != nil { - d.rulesStorage.Close() - d.rulesStorage = nil - } -} - // // config manipulation helpers // diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 4c574b57..37255b78 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -108,7 +108,7 @@ func TestEtcHostsMatching(t *testing.T) { filters := make(map[int]string) filters[0] = text d := NewForTest(nil, filters) - defer d.Destroy() + defer d.Close() d.checkMatchIP(t, "google.com", addr, dns.TypeA) d.checkMatchIP(t, "www.google.com", addr, dns.TypeA) @@ -133,7 +133,7 @@ func TestSafeBrowsing(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("%s in %s", tc, _Func()), func(t *testing.T) { d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Destroy() + defer d.Close() gctx.stats.Safebrowsing.Requests = 0 d.checkMatch(t, "wmconvirus.narod.ru") d.checkMatch(t, "wmconvirus.narod.ru") @@ -158,7 +158,7 @@ func TestSafeBrowsing(t *testing.T) { func TestParallelSB(t *testing.T) { d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Destroy() + defer d.Close() t.Run("group", func(t *testing.T) { for i := 0; i < 100; i++ { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { @@ -175,7 +175,7 @@ func TestParallelSB(t *testing.T) { // the only way to verify that custom server option is working is to point it at a server that does serve safebrowsing func TestSafeBrowsingCustomServerFail(t *testing.T) { d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Destroy() + defer d.Close() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // w.Write("Hello, client") fmt.Fprintln(w, "Hello, client") @@ -192,14 +192,14 @@ func TestSafeBrowsingCustomServerFail(t *testing.T) { func TestSafeSearch(t *testing.T) { d := NewForTest(nil, nil) - defer d.Destroy() + defer d.Close() _, ok := d.SafeSearchDomain("www.google.com") if ok { t.Errorf("Expected safesearch to error when disabled") } d = NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Destroy() + defer d.Close() val, ok := d.SafeSearchDomain("www.google.com") if !ok { t.Errorf("Expected safesearch to find result for www.google.com") @@ -211,7 +211,7 @@ func TestSafeSearch(t *testing.T) { func TestCheckHostSafeSearchYandex(t *testing.T) { d := NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Destroy() + defer d.Close() // Slice of yandex domains yandex := []string{"yAndeX.ru", "YANdex.COM", "yandex.ua", "yandex.by", "yandex.kz", "www.yandex.com"} @@ -231,7 +231,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) { func TestCheckHostSafeSearchGoogle(t *testing.T) { d := NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Destroy() + defer d.Close() // Slice of google domains googleDomains := []string{"www.google.com", "www.google.im", "www.google.co.in", "www.google.iq", "www.google.is", "www.google.it", "www.google.je"} @@ -251,7 +251,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) { func TestSafeSearchCacheYandex(t *testing.T) { d := NewForTest(nil, nil) - defer d.Destroy() + defer d.Close() domain := "yandex.ru" var result Result @@ -267,7 +267,7 @@ func TestSafeSearchCacheYandex(t *testing.T) { } d = NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Destroy() + defer d.Close() result, err = d.CheckHost(domain, dns.TypeA, &setts) if err != nil { @@ -293,7 +293,7 @@ func TestSafeSearchCacheYandex(t *testing.T) { func TestSafeSearchCacheGoogle(t *testing.T) { d := NewForTest(nil, nil) - defer d.Destroy() + defer d.Close() domain := "www.google.ru" result, err := d.CheckHost(domain, dns.TypeA, &setts) if err != nil { @@ -304,7 +304,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) { } d = NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Destroy() + defer d.Close() // Let's lookup for safesearch domain safeDomain, ok := d.SafeSearchDomain(domain) @@ -352,7 +352,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) { func TestParentalControl(t *testing.T) { d := NewForTest(&Config{ParentalEnabled: true}, nil) - defer d.Destroy() + defer d.Close() d.ParentalSensitivity = 3 d.checkMatch(t, "pornhub.com") d.checkMatch(t, "pornhub.com") @@ -435,7 +435,7 @@ func TestMatching(t *testing.T) { filters := make(map[int]string) filters[0] = test.rules d := NewForTest(nil, filters) - defer d.Destroy() + defer d.Close() ret, err := d.CheckHost(test.hostname, dns.TypeA, &setts) if err != nil { @@ -472,7 +472,7 @@ func TestClientSettings(t *testing.T) { filters := make(map[int]string) filters[0] = "||example.org^\n" d := NewForTest(&Config{ParentalEnabled: true, SafeBrowsingEnabled: false}, filters) - defer d.Destroy() + defer d.Close() d.ParentalSensitivity = 3 // no client settings: @@ -529,7 +529,7 @@ func TestClientSettings(t *testing.T) { func BenchmarkSafeBrowsing(b *testing.B) { d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Destroy() + defer d.Close() for n := 0; n < b.N; n++ { hostname := "wmconvirus.narod.ru" ret, err := d.CheckHost(hostname, dns.TypeA, &setts) @@ -544,7 +544,7 @@ func BenchmarkSafeBrowsing(b *testing.B) { func BenchmarkSafeBrowsingParallel(b *testing.B) { d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Destroy() + defer d.Close() b.RunParallel(func(pb *testing.PB) { for pb.Next() { hostname := "wmconvirus.narod.ru" @@ -561,7 +561,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) { func BenchmarkSafeSearch(b *testing.B) { d := NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Destroy() + defer d.Close() for n := 0; n < b.N; n++ { val, ok := d.SafeSearchDomain("www.google.com") if !ok { @@ -575,7 +575,7 @@ func BenchmarkSafeSearch(b *testing.B) { func BenchmarkSafeSearchParallel(b *testing.B) { d := NewForTest(&Config{SafeSearchEnabled: true}, nil) - defer d.Destroy() + defer d.Close() b.RunParallel(func(pb *testing.PB) { for pb.Next() { val, ok := d.SafeSearchDomain("www.google.com") diff --git a/dnsfilter/rewrites.go b/dnsfilter/rewrites.go new file mode 100644 index 00000000..6cc18784 --- /dev/null +++ b/dnsfilter/rewrites.go @@ -0,0 +1,93 @@ +// DNS Rewrites + +package dnsfilter + +import ( + "encoding/json" + "net/http" + + "github.com/AdguardTeam/golibs/log" +) + +type rewriteEntryJSON struct { + Domain string `json:"domain"` + Answer string `json:"answer"` +} + +func (d *Dnsfilter) handleRewriteList(w http.ResponseWriter, r *http.Request) { + + arr := []*rewriteEntryJSON{} + + d.confLock.Lock() + for _, ent := range d.Config.Rewrites { + jsent := rewriteEntryJSON{ + Domain: ent.Domain, + Answer: ent.Answer, + } + arr = append(arr, &jsent) + } + d.confLock.Unlock() + + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(arr) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err) + return + } +} + +func (d *Dnsfilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { + + jsent := rewriteEntryJSON{} + err := json.NewDecoder(r.Body).Decode(&jsent) + if err != nil { + httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) + return + } + + ent := RewriteEntry{ + Domain: jsent.Domain, + Answer: jsent.Answer, + } + d.confLock.Lock() + d.Config.Rewrites = append(d.Config.Rewrites, ent) + d.confLock.Unlock() + log.Debug("Rewrites: added element: %s -> %s [%d]", + ent.Domain, ent.Answer, len(d.Config.Rewrites)) + + d.Config.ConfigModified() +} + +func (d *Dnsfilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) { + + jsent := rewriteEntryJSON{} + err := json.NewDecoder(r.Body).Decode(&jsent) + if err != nil { + httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) + return + } + + entDel := RewriteEntry{ + Domain: jsent.Domain, + Answer: jsent.Answer, + } + arr := []RewriteEntry{} + d.confLock.Lock() + for _, ent := range d.Config.Rewrites { + if ent == entDel { + log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer) + continue + } + arr = append(arr, ent) + } + d.Config.Rewrites = arr + d.confLock.Unlock() + + d.Config.ConfigModified() +} + +func (d *Dnsfilter) registerRewritesHandlers() { + d.Config.HTTPRegister("GET", "/control/rewrite/list", d.handleRewriteList) + d.Config.HTTPRegister("POST", "/control/rewrite/add", d.handleRewriteAdd) + d.Config.HTTPRegister("POST", "/control/rewrite/delete", d.handleRewriteDelete) +} diff --git a/dnsfilter/security.go b/dnsfilter/security.go new file mode 100644 index 00000000..c4ce32de --- /dev/null +++ b/dnsfilter/security.go @@ -0,0 +1,179 @@ +// Parental Control, Safe Browsing, Safe Search + +package dnsfilter + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + + "github.com/AdguardTeam/golibs/log" +) + +func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { + text := fmt.Sprintf(format, args...) + log.Info("DNSFilter: %s %s: %s", r.Method, r.URL, text) + http.Error(w, text, code) +} + +func (d *Dnsfilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) { + d.Config.SafeBrowsingEnabled = true + d.Config.ConfigModified() +} + +func (d *Dnsfilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) { + d.Config.SafeBrowsingEnabled = false + d.Config.ConfigModified() +} + +func (d *Dnsfilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) { + data := map[string]interface{}{ + "enabled": d.Config.SafeBrowsingEnabled, + } + jsonVal, err := json.Marshal(data) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(jsonVal) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) + return + } +} + +func parseParametersFromBody(r io.Reader) (map[string]string, error) { + parameters := map[string]string{} + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + if len(line) == 0 { + // skip empty lines + continue + } + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + return parameters, errors.New("Got invalid request body") + } + parameters[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) + } + + return parameters, nil +} + +func (d *Dnsfilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) { + parameters, err := parseParametersFromBody(r.Body) + if err != nil { + httpError(r, w, http.StatusBadRequest, "failed to parse parameters from body: %s", err) + return + } + + sensitivity, ok := parameters["sensitivity"] + if !ok { + http.Error(w, "Sensitivity parameter was not specified", 400) + return + } + + switch sensitivity { + case "3": + break + case "EARLY_CHILDHOOD": + sensitivity = "3" + case "10": + break + case "YOUNG": + sensitivity = "10" + case "13": + break + case "TEEN": + sensitivity = "13" + case "17": + break + case "MATURE": + sensitivity = "17" + default: + http.Error(w, "Sensitivity must be set to valid value", 400) + return + } + i, err := strconv.Atoi(sensitivity) + if err != nil { + http.Error(w, "Sensitivity must be set to valid value", 400) + return + } + d.Config.ParentalSensitivity = i + d.Config.ParentalEnabled = true + d.Config.ConfigModified() +} + +func (d *Dnsfilter) handleParentalDisable(w http.ResponseWriter, r *http.Request) { + d.Config.ParentalEnabled = false + d.Config.ConfigModified() +} + +func (d *Dnsfilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) { + data := map[string]interface{}{ + "enabled": d.Config.ParentalEnabled, + } + if d.Config.ParentalEnabled { + data["sensitivity"] = d.Config.ParentalSensitivity + } + jsonVal, err := json.Marshal(data) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(jsonVal) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) + return + } +} + +func (d *Dnsfilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) { + d.Config.SafeSearchEnabled = true + d.Config.ConfigModified() +} + +func (d *Dnsfilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { + d.Config.SafeSearchEnabled = false + d.Config.ConfigModified() +} + +func (d *Dnsfilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { + data := map[string]interface{}{ + "enabled": d.Config.SafeSearchEnabled, + } + jsonVal, err := json.Marshal(data) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(jsonVal) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) + return + } +} + +func (d *Dnsfilter) registerSecurityHandlers() { + d.Config.HTTPRegister("POST", "/control/safebrowsing/enable", d.handleSafeBrowsingEnable) + d.Config.HTTPRegister("POST", "/control/safebrowsing/disable", d.handleSafeBrowsingDisable) + d.Config.HTTPRegister("GET", "/control/safebrowsing/status", d.handleSafeBrowsingStatus) + d.Config.HTTPRegister("POST", "/control/parental/enable", d.handleParentalEnable) + d.Config.HTTPRegister("POST", "/control/parental/disable", d.handleParentalDisable) + d.Config.HTTPRegister("GET", "/control/parental/status", d.handleParentalStatus) + d.Config.HTTPRegister("POST", "/control/safesearch/enable", d.handleSafeSearchEnable) + d.Config.HTTPRegister("POST", "/control/safesearch/disable", d.handleSafeSearchDisable) + d.Config.HTTPRegister("GET", "/control/safesearch/status", d.handleSafeSearchStatus) +} diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index d28889b3..415b1cef 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -3,7 +3,6 @@ package dnsforward import ( "crypto/tls" "errors" - "fmt" "net" "net/http" "strings" @@ -44,12 +43,6 @@ type Server struct { queryLog querylog.QueryLog // Query log instance stats stats.Stats - // How many times the server was started - // While creating a dnsfilter object, - // we use this value to set s.dnsFilter property only with the most recent settings. - startCounter uint32 - dnsfilterCreatorChan chan dnsfilterCreatorParams - AllowedClients map[string]bool // IP addresses of whitelist clients DisallowedClients map[string]bool // IP addresses of clients that should be blocked AllowedClientsIPNet []net.IPNet // CIDRs of whitelist clients @@ -60,15 +53,11 @@ type Server struct { conf ServerConfig } -type dnsfilterCreatorParams struct { - conf dnsfilter.Config - filters map[int]string -} - // NewServer creates a new instance of the dnsforward.Server // Note: this function must be called only once -func NewServer(stats stats.Stats, queryLog querylog.QueryLog) *Server { +func NewServer(dnsFilter *dnsfilter.Dnsfilter, stats stats.Stats, queryLog querylog.QueryLog) *Server { s := &Server{} + s.dnsFilter = dnsFilter s.stats = stats s.queryLog = queryLog return s @@ -76,6 +65,7 @@ func NewServer(stats stats.Stats, queryLog querylog.QueryLog) *Server { func (s *Server) Close() { s.Lock() + s.dnsFilter = nil s.stats = nil s.queryLog = nil s.Unlock() @@ -84,11 +74,8 @@ func (s *Server) Close() { // FilteringConfig represents the DNS filtering configuration of AdGuard Home // The zero FilteringConfig is empty and ready for use. type FilteringConfig struct { - // Create dnsfilter asynchronously. - // Requests won't be filtered until dnsfilter is created. - // If "restart" command is received while we're creating an old dnsfilter object, - // we delay creation of the new object until the old one is created. - AsyncStartup bool `yaml:"-"` + // Filtering callback function + FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists @@ -116,8 +103,9 @@ type FilteringConfig struct { // Per-client settings can override this configuration. BlockedServices []string `yaml:"blocked_services"` - CacheSize uint `yaml:"cache_size"` // DNS cache size (in bytes) - dnsfilter.Config `yaml:",inline"` + CacheSize uint `yaml:"cache_size"` // DNS cache size (in bytes) + + DnsfilterConf dnsfilter.Config `yaml:",inline"` } // TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS @@ -140,7 +128,6 @@ type ServerConfig struct { TCPListenAddr *net.TCPAddr // TCP listen address Upstreams []upstream.Upstream // Configured upstreams DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams - Filters []dnsfilter.Filter // A list of filters to use OnDNSRequest func(d *proxy.DNSContext) FilteringConfig @@ -204,13 +191,18 @@ func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []strin // startInternal starts without locking func (s *Server) startInternal(config *ServerConfig) error { - if s.dnsFilter != nil || s.dnsProxy != nil { + if s.dnsProxy != nil { return errors.New("DNS server is already started") } - err := s.initDNSFilter(config) - if err != nil { - return err + if config != nil { + s.conf = *config + } + if len(s.conf.ParentalBlockHost) == 0 { + s.conf.ParentalBlockHost = parentalBlockHost + } + if len(s.conf.SafeBrowsingBlockHost) == 0 { + s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost } proxyConfig := proxy.Config{ @@ -228,7 +220,7 @@ func (s *Server) startInternal(config *ServerConfig) error { AllServers: s.conf.AllServers, } - err = processIPCIDRArray(&s.AllowedClients, &s.AllowedClientsIPNet, s.conf.AllowedClients) + err := processIPCIDRArray(&s.AllowedClients, &s.AllowedClientsIPNet, s.conf.AllowedClients) if err != nil { return err } @@ -269,97 +261,6 @@ func (s *Server) startInternal(config *ServerConfig) error { return s.dnsProxy.Start() } -// Initializes the DNS filter -func (s *Server) initDNSFilter(config *ServerConfig) error { - if config != nil { - s.conf = *config - } - - var filters map[int]string - filters = nil - if s.conf.FilteringEnabled { - filters = make(map[int]string) - for _, f := range s.conf.Filters { - if f.ID == 0 { - filters[int(f.ID)] = string(f.Data) - } else { - filters[int(f.ID)] = f.FilePath - } - } - } - - if len(s.conf.ParentalBlockHost) == 0 { - s.conf.ParentalBlockHost = parentalBlockHost - } - if len(s.conf.SafeBrowsingBlockHost) == 0 { - s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost - } - - if s.conf.AsyncStartup { - params := dnsfilterCreatorParams{ - conf: s.conf.Config, - filters: filters, - } - s.startCounter++ - if s.startCounter == 1 { - s.dnsfilterCreatorChan = make(chan dnsfilterCreatorParams, 1) - go s.dnsfilterCreator() - } - - // remove all pending tasks - stop := false - for !stop { - select { - case <-s.dnsfilterCreatorChan: - // - default: - stop = true - } - } - - s.dnsfilterCreatorChan <- params - } else { - log.Debug("creating dnsfilter...") - f := dnsfilter.New(&s.conf.Config, filters) - if f == nil { - return fmt.Errorf("could not initialize dnsfilter") - } - log.Debug("created dnsfilter") - s.dnsFilter = f - } - return nil -} - -func (s *Server) dnsfilterCreator() { - for { - params := <-s.dnsfilterCreatorChan - - s.Lock() - counter := s.startCounter - s.Unlock() - - log.Debug("creating dnsfilter...") - f := dnsfilter.New(¶ms.conf, params.filters) - if f == nil { - log.Error("could not initialize dnsfilter") - continue - } - - set := false - s.Lock() - if counter == s.startCounter { - s.dnsFilter = f - set = true - } - s.Unlock() - if set { - log.Debug("created and activated dnsfilter") - } else { - log.Debug("created dnsfilter") - } - } -} - // Stop stops the DNS server func (s *Server) Stop() error { s.Lock() @@ -377,11 +278,6 @@ func (s *Server) stopInternal() error { } } - if s.dnsFilter != nil { - s.dnsFilter.Destroy() - s.dnsFilter = nil - } - return nil } @@ -607,33 +503,24 @@ func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dns // filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) { - var res dnsfilter.Result - req := d.Req - host := strings.TrimSuffix(req.Question[0].Name, ".") - - dnsFilter := s.dnsFilter - if !s.conf.ProtectionEnabled || s.dnsFilter == nil { return &dnsfilter.Result{}, nil } - var err error - clientAddr := "" if d.Addr != nil { clientAddr, _, _ = net.SplitHostPort(d.Addr.String()) } - var setts dnsfilter.RequestFilteringSettings + setts := s.dnsFilter.GetConfig() setts.FilteringEnabled = true - setts.SafeSearchEnabled = s.conf.SafeSearchEnabled - setts.SafeBrowsingEnabled = s.conf.SafeBrowsingEnabled - setts.ParentalEnabled = s.conf.ParentalEnabled if s.conf.FilterHandler != nil { s.conf.FilterHandler(clientAddr, &setts) } - res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts) + req := d.Req + host := strings.TrimSuffix(req.Question[0].Name, ".") + res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts) if err != nil { // Return immediately if there's an error return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host) diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 94e72d67..2568ef7b 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -148,7 +148,6 @@ func TestServerRace(t *testing.T) { func TestSafeSearch(t *testing.T) { s := createTestServer(t) - s.conf.SafeSearchEnabled = true err := s.Start(nil) if err != nil { t.Fatalf("Failed to start server: %s", err) @@ -376,23 +375,24 @@ func TestBlockedBySafeBrowsing(t *testing.T) { } func createTestServer(t *testing.T) *Server { - s := NewServer(nil, nil) + rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n" + filters := map[int]string{} + filters[0] = rules + c := dnsfilter.Config{} + c.SafeBrowsingEnabled = true + c.SafeBrowsingCacheSize = 1000 + c.SafeSearchEnabled = true + c.SafeSearchCacheSize = 1000 + c.ParentalCacheSize = 1000 + c.CacheTime = 30 + + f := dnsfilter.New(&c, filters) + s := NewServer(f, nil, nil) s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} s.conf.FilteringConfig.FilteringEnabled = true s.conf.FilteringConfig.ProtectionEnabled = true - s.conf.FilteringConfig.SafeBrowsingEnabled = true - s.conf.Filters = make([]dnsfilter.Filter, 0) - - s.conf.SafeBrowsingCacheSize = 1000 - s.conf.SafeSearchCacheSize = 1000 - s.conf.ParentalCacheSize = 1000 - s.conf.CacheTime = 30 - - rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n" - filter := dnsfilter.Filter{ID: 0, Data: []byte(rules)} - s.conf.Filters = append(s.conf.Filters, filter) return s } diff --git a/home/config.go b/home/config.go index 726eda90..fcd5e3dc 100644 --- a/home/config.go +++ b/home/config.go @@ -10,6 +10,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/dhcpd" + "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/stats" @@ -71,7 +72,6 @@ type configuration struct { client *http.Client stats stats.Stats // statistics module queryLog querylog.QueryLog // query log module - filteringStarted bool // TRUE if filtering module is started auth *Auth // HTTP authentication module // cached version.json to avoid hammering github.io for each page reload @@ -79,6 +79,7 @@ type configuration struct { versionCheckLastTime time.Time dnsctx dnsContext + dnsFilter *dnsfilter.Dnsfilter dnsServer *dnsforward.Server dhcpServer dhcpd.Server httpServer *http.Server @@ -217,10 +218,10 @@ func initConfig() { } config.DNS.CacheSize = 4 * 1024 * 1024 - config.DNS.SafeBrowsingCacheSize = 1 * 1024 * 1024 - config.DNS.SafeSearchCacheSize = 1 * 1024 * 1024 - config.DNS.ParentalCacheSize = 1 * 1024 * 1024 - config.DNS.CacheTime = 30 + config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024 + config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024 + config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024 + config.DNS.DnsfilterConf.CacheTime = 30 config.Filters = defaultFilters() } @@ -367,6 +368,12 @@ func (c *configuration) write() error { config.DNS.QueryLogInterval = dc.Interval } + if config.dnsFilter != nil { + c := dnsfilter.Config{} + config.dnsFilter.WriteDiskConfig(&c) + config.DNS.DnsfilterConf = c + } + configFile := config.getConfigFilename() log.Debug("Writing YAML file: %s", configFile) yamlText, err := yaml.Marshal(&config) diff --git a/home/control.go b/home/control.go index d18c2a85..1f2eb1fa 100644 --- a/home/control.go +++ b/home/control.go @@ -377,142 +377,6 @@ func checkDNS(input string, bootstrap []string) error { return nil } -// ------------ -// safebrowsing -// ------------ - -func handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) { - config.DNS.SafeBrowsingEnabled = true - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) { - config.DNS.SafeBrowsingEnabled = false - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) { - data := map[string]interface{}{ - "enabled": config.DNS.SafeBrowsingEnabled, - } - jsonVal, err := json.Marshal(data) - if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err) - return - } -} - -// -------- -// parental -// -------- -func handleParentalEnable(w http.ResponseWriter, r *http.Request) { - parameters, err := parseParametersFromBody(r.Body) - if err != nil { - httpError(w, http.StatusBadRequest, "failed to parse parameters from body: %s", err) - return - } - - sensitivity, ok := parameters["sensitivity"] - if !ok { - http.Error(w, "Sensitivity parameter was not specified", 400) - return - } - - switch sensitivity { - case "3": - break - case "EARLY_CHILDHOOD": - sensitivity = "3" - case "10": - break - case "YOUNG": - sensitivity = "10" - case "13": - break - case "TEEN": - sensitivity = "13" - case "17": - break - case "MATURE": - sensitivity = "17" - default: - http.Error(w, "Sensitivity must be set to valid value", 400) - return - } - i, err := strconv.Atoi(sensitivity) - if err != nil { - http.Error(w, "Sensitivity must be set to valid value", 400) - return - } - config.DNS.ParentalSensitivity = i - config.DNS.ParentalEnabled = true - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleParentalDisable(w http.ResponseWriter, r *http.Request) { - config.DNS.ParentalEnabled = false - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleParentalStatus(w http.ResponseWriter, r *http.Request) { - data := map[string]interface{}{ - "enabled": config.DNS.ParentalEnabled, - } - if config.DNS.ParentalEnabled { - data["sensitivity"] = config.DNS.ParentalSensitivity - } - jsonVal, err := json.Marshal(data) - if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) - return - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err) - return - } -} - -// ------------ -// safebrowsing -// ------------ - -func handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) { - config.DNS.SafeSearchEnabled = true - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { - config.DNS.SafeSearchEnabled = false - httpUpdateConfigReloadDNSReturnOK(w, r) -} - -func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { - data := map[string]interface{}{ - "enabled": config.DNS.SafeSearchEnabled, - } - jsonVal, err := json.Marshal(data) - if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) - return - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err) - return - } -} - // -------------- // DNS-over-HTTPS // -------------- @@ -543,15 +407,6 @@ func registerControlHandlers() { httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage) http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON))) httpRegister(http.MethodPost, "/control/update", handleUpdate) - httpRegister(http.MethodPost, "/control/safebrowsing/enable", handleSafeBrowsingEnable) - httpRegister(http.MethodPost, "/control/safebrowsing/disable", handleSafeBrowsingDisable) - httpRegister(http.MethodGet, "/control/safebrowsing/status", handleSafeBrowsingStatus) - httpRegister(http.MethodPost, "/control/parental/enable", handleParentalEnable) - httpRegister(http.MethodPost, "/control/parental/disable", handleParentalDisable) - httpRegister(http.MethodGet, "/control/parental/status", handleParentalStatus) - httpRegister(http.MethodPost, "/control/safesearch/enable", handleSafeSearchEnable) - httpRegister(http.MethodPost, "/control/safesearch/disable", handleSafeSearchDisable) - httpRegister(http.MethodGet, "/control/safesearch/status", handleSafeSearchStatus) httpRegister(http.MethodGet, "/control/dhcp/status", handleDHCPStatus) httpRegister(http.MethodGet, "/control/dhcp/interfaces", handleDHCPInterfaces) httpRegister(http.MethodPost, "/control/dhcp/set_config", handleDHCPSetConfig) @@ -565,7 +420,6 @@ func registerControlHandlers() { RegisterFilteringHandlers() RegisterTLSHandlers() RegisterClientsHandlers() - registerRewritesHandlers() RegisterBlockedServicesHandlers() RegisterAuthHandlers() diff --git a/home/control_filtering.go b/home/control_filtering.go index 03699953..453e0ec4 100644 --- a/home/control_filtering.go +++ b/home/control_filtering.go @@ -86,17 +86,8 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { return } - err = writeAllConfigs() - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write config file: %s", err) - return - } - - err = reconfigureDNSServer() - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't reconfigure the DNS server: %s", err) - return - } + onConfigModified() + enableFilters(true) _, err = fmt.Fprintf(w, "OK %d rules\n", f.RulesCount) if err != nil { @@ -121,32 +112,28 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { return } - // Stop DNS server: - // we close urlfilter object which in turn closes file descriptors to filter files. - // Otherwise, Windows won't allow us to remove the file which is being currently used. - _ = config.dnsServer.Stop() - // go through each element and delete if url matches config.Lock() - newFilters := config.Filters[:0] + newFilters := []filter{} for _, filter := range config.Filters { if filter.URL != req.URL { newFilters = append(newFilters, filter) } else { - // Remove the filter file - err := os.Remove(filter.Path()) - if err != nil && !os.IsNotExist(err) { - config.Unlock() - httpError(w, http.StatusInternalServerError, "Couldn't remove the filter file: %s", err) - return + err := os.Rename(filter.Path(), filter.Path()+".old") + if err != nil { + log.Error("os.Rename: %s: %s", filter.Path(), err) } - log.Debug("os.Remove(%s)", filter.Path()) } } // Update the configuration after removing filter files config.Filters = newFilters config.Unlock() - httpUpdateConfigReloadDNSReturnOK(w, r) + + onConfigModified() + enableFilters(true) + + // Note: the old files "filter.txt.old" aren't deleted - it's not really necessary, + // but will require the additional code to run after enableFilters() is finished: i.e. complicated } type filterURLJSON struct { @@ -173,7 +160,8 @@ func handleFilteringSetURL(w http.ResponseWriter, r *http.Request) { return } - httpUpdateConfigReloadDNSReturnOK(w, r) + onConfigModified() + enableFilters(true) } func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { @@ -184,12 +172,13 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { } config.UserRules = strings.Split(string(body), "\n") - httpUpdateConfigReloadDNSReturnOK(w, r) + _ = writeAllConfigs() + enableFilters(true) } func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { - updated := refreshFiltersIfNecessary(true) - fmt.Fprintf(w, "OK %d filters updated\n", updated) + beginRefreshFilters() + fmt.Fprintf(w, "OK 0 filters updated\n") } type filterJSON struct { @@ -260,9 +249,8 @@ func handleFilteringConfig(w http.ResponseWriter, r *http.Request) { config.DNS.FilteringEnabled = req.Enabled config.DNS.FiltersUpdateIntervalHours = req.Interval - httpUpdateConfigReloadDNSReturnOK(w, r) - - returnOK(w) + onConfigModified() + enableFilters(true) } // RegisterFilteringHandlers - register handlers diff --git a/home/dns.go b/home/dns.go index 64c9efe4..8e755e06 100644 --- a/home/dns.go +++ b/home/dns.go @@ -55,7 +55,18 @@ func initDNSServer() { HTTPRegister: httpRegister, } config.queryLog = querylog.New(conf) - config.dnsServer = dnsforward.NewServer(config.stats, config.queryLog) + + filterConf := config.DNS.DnsfilterConf + bindhost := config.DNS.BindHost + if config.DNS.BindHost == "0.0.0.0" { + bindhost = "127.0.0.1" + } + filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) + filterConf.ConfigModified = onConfigModified + filterConf.HTTPRegister = httpRegister + config.dnsFilter = dnsfilter.New(&filterConf, nil) + + config.dnsServer = dnsforward.NewServer(config.dnsFilter, config.stats, config.queryLog) sessFilename := filepath.Join(baseDir, "sessions.db") config.auth = InitAuth(sessFilename, config.Users) @@ -159,34 +170,11 @@ func onDNSRequest(d *proxy.DNSContext) { } func generateServerConfig() (dnsforward.ServerConfig, error) { - filters := []dnsfilter.Filter{} - userFilter := userFilter() - filters = append(filters, dnsfilter.Filter{ - ID: userFilter.ID, - Data: userFilter.Data, - }) - for _, filter := range config.Filters { - if !filter.Enabled { - continue - } - filters = append(filters, dnsfilter.Filter{ - ID: filter.ID, - FilePath: filter.Path(), - }) - } - newconfig := dnsforward.ServerConfig{ UDPListenAddr: &net.UDPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port}, TCPListenAddr: &net.TCPAddr{IP: net.ParseIP(config.DNS.BindHost), Port: config.DNS.Port}, FilteringConfig: config.DNS.FilteringConfig, - Filters: filters, } - newconfig.AsyncStartup = true - bindhost := config.DNS.BindHost - if config.DNS.BindHost == "0.0.0.0" { - bindhost = "127.0.0.1" - } - newconfig.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) if config.TLS.Enabled { newconfig.TLSConfig = config.TLS.TLSConfig @@ -242,20 +230,18 @@ func startDNSServer() error { return fmt.Errorf("unable to start forwarding DNS server: Already running") } + enableFilters(false) + newconfig, err := generateServerConfig() if err != nil { return errorx.Decorate(err, "Couldn't start forwarding DNS server") } + err = config.dnsServer.Start(&newconfig) if err != nil { return errorx.Decorate(err, "Couldn't start forwarding DNS server") } - if !config.filteringStarted { - config.filteringStarted = true - startRefreshFilters() - } - return nil } @@ -285,6 +271,9 @@ func stopDNSServer() error { // DNS forward module must be closed BEFORE stats or queryLog because it depends on them config.dnsServer.Close() + config.dnsFilter.Close() + config.dnsFilter = nil + config.stats.Close() config.stats = nil diff --git a/home/dns_rewrites.go b/home/dns_rewrites.go deleted file mode 100644 index e58c50d7..00000000 --- a/home/dns_rewrites.go +++ /dev/null @@ -1,104 +0,0 @@ -package home - -import ( - "encoding/json" - "net/http" - - "github.com/AdguardTeam/AdGuardHome/dnsfilter" - "github.com/AdguardTeam/golibs/log" -) - -type rewriteEntryJSON struct { - Domain string `json:"domain"` - Answer string `json:"answer"` -} - -func handleRewriteList(w http.ResponseWriter, r *http.Request) { - - arr := []*rewriteEntryJSON{} - - config.RLock() - for _, ent := range config.DNS.Rewrites { - jsent := rewriteEntryJSON{ - Domain: ent.Domain, - Answer: ent.Answer, - } - arr = append(arr, &jsent) - } - config.RUnlock() - - w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(arr) - if err != nil { - httpError(w, http.StatusInternalServerError, "json.Encode: %s", err) - return - } -} - -func handleRewriteAdd(w http.ResponseWriter, r *http.Request) { - - jsent := rewriteEntryJSON{} - err := json.NewDecoder(r.Body).Decode(&jsent) - if err != nil { - httpError(w, http.StatusBadRequest, "json.Decode: %s", err) - return - } - - ent := dnsfilter.RewriteEntry{ - Domain: jsent.Domain, - Answer: jsent.Answer, - } - config.Lock() - config.DNS.Rewrites = append(config.DNS.Rewrites, ent) - config.Unlock() - log.Debug("Rewrites: added element: %s -> %s [%d]", - ent.Domain, ent.Answer, len(config.DNS.Rewrites)) - - err = writeAllConfigsAndReloadDNS() - if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) - return - } - - returnOK(w) -} - -func handleRewriteDelete(w http.ResponseWriter, r *http.Request) { - - jsent := rewriteEntryJSON{} - err := json.NewDecoder(r.Body).Decode(&jsent) - if err != nil { - httpError(w, http.StatusBadRequest, "json.Decode: %s", err) - return - } - - entDel := dnsfilter.RewriteEntry{ - Domain: jsent.Domain, - Answer: jsent.Answer, - } - arr := []dnsfilter.RewriteEntry{} - config.Lock() - for _, ent := range config.DNS.Rewrites { - if ent == entDel { - log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer) - continue - } - arr = append(arr, ent) - } - config.DNS.Rewrites = arr - config.Unlock() - - err = writeAllConfigsAndReloadDNS() - if err != nil { - httpError(w, http.StatusBadRequest, "%s", err) - return - } - - returnOK(w) -} - -func registerRewritesHandlers() { - httpRegister(http.MethodGet, "/control/rewrite/list", handleRewriteList) - httpRegister(http.MethodPost, "/control/rewrite/add", handleRewriteAdd) - httpRegister(http.MethodPost, "/control/rewrite/delete", handleRewriteDelete) -} diff --git a/home/filter.go b/home/filter.go index be60e2a3..38425301 100644 --- a/home/filter.go +++ b/home/filter.go @@ -19,18 +19,13 @@ import ( var ( nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`) + forceRefresh bool ) func initFiltering() { loadFilters() deduplicateFilters() updateUniqueFilterID(config.Filters) -} - -func startRefreshFilters() { - go func() { - _ = refreshFiltersIfNecessary(false) - }() go periodicallyRefreshFilters() } @@ -180,32 +175,43 @@ func assignUniqueFilterID() int64 { // Sets up a timer that will be checking for filters updates periodically func periodicallyRefreshFilters() { + nextRefresh := int64(0) for { - time.Sleep(1 * time.Hour) - if config.DNS.FiltersUpdateIntervalHours == 0 { - continue + if forceRefresh { + _ = refreshFiltersIfNecessary(true) + forceRefresh = false } - refreshFiltersIfNecessary(false) + if config.DNS.FiltersUpdateIntervalHours != 0 && nextRefresh <= time.Now().Unix() { + _ = refreshFiltersIfNecessary(false) + nextRefresh = time.Now().Add(1 * time.Hour).Unix() + } + time.Sleep(1 * time.Second) } } +// Schedule the procedure to refresh filters +func beginRefreshFilters() { + forceRefresh = true + log.Debug("Filters: schedule update") +} + // Checks filters updates if necessary // If force is true, it ignores the filter.LastUpdated field value // // Algorithm: // . Get the list of filters to be updated // . For each filter run the download and checksum check operation -// . Stop server // . For each filter: // . If filter data hasn't changed, just set new update time on file -// . If filter data has changed, save it on disk -// . Apply changes to the current configuration -// . Start server +// . If filter data has changed: rename the old file, store the new data on disk +// . Pass new filters to dnsfilter object func refreshFiltersIfNecessary(force bool) int { var updateFilters []filter var updateFlags []bool // 'true' if filter data has changed + log.Debug("Filters: updating...") + now := time.Now() config.RLock() for i := range config.Filters { @@ -229,7 +235,6 @@ func refreshFiltersIfNecessary(force bool) int { } config.RUnlock() - updateCount := 0 for i := range updateFilters { uf := &updateFilters[i] updated, err := uf.update() @@ -239,24 +244,14 @@ func refreshFiltersIfNecessary(force bool) int { continue } uf.LastUpdated = now - if updated { - updateCount++ - } } - stopped := false - if updateCount != 0 { - _ = config.dnsServer.Stop() - stopped = true - } - - updateCount = 0 + updateCount := 0 for i := range updateFilters { uf := &updateFilters[i] updated := updateFlags[i] if updated { - // Saving it to the filters dir now - err := uf.save() + err := uf.saveAndBackupOld() if err != nil { log.Printf("Failed to save the updated filter %d: %s", uf.ID, err) continue @@ -290,12 +285,20 @@ func refreshFiltersIfNecessary(force bool) int { config.Unlock() } - if stopped { - err := reconfigureDNSServer() - if err != nil { - log.Error("cannot reconfigure DNS server with the new filters: %s", err) + if updateCount != 0 { + enableFilters(false) + + for i := range updateFilters { + uf := &updateFilters[i] + updated := updateFlags[i] + if !updated { + continue + } + _ = os.Remove(uf.Path() + ".old") } } + + log.Debug("Filters: update finished") return updateCount } @@ -413,6 +416,12 @@ func (filter *filter) save() error { return err } +func (filter *filter) saveAndBackupOld() error { + filterFilePath := filter.Path() + _ = os.Rename(filterFilePath, filterFilePath+".old") + return filter.save() +} + // loads filter contents from the file in dataDir func (filter *filter) load() error { filterFilePath := filter.Path() @@ -467,3 +476,23 @@ func (filter *filter) LastTimeUpdated() time.Time { // filter file modified time return s.ModTime() } + +func enableFilters(async bool) { + var filters map[int]string + if config.DNS.FilteringConfig.FilteringEnabled { + // convert array of filters + filters = make(map[int]string) + + userFilter := userFilter() + filters[int(userFilter.ID)] = string(userFilter.Data) + + for _, filter := range config.Filters { + if !filter.Enabled { + continue + } + filters[int(filter.ID)] = filter.Path() + } + } + + _ = config.dnsFilter.SetFilters(filters, async) +} diff --git a/home/helpers.go b/home/helpers.go index 756d6b97..6b0f01ed 100644 --- a/home/helpers.go +++ b/home/helpers.go @@ -1,12 +1,10 @@ package home import ( - "bufio" "bytes" "context" "errors" "fmt" - "io" "net" "net/http" "net/url" @@ -155,29 +153,6 @@ func postInstallHandler(handler http.Handler) http.Handler { return &postInstallHandlerStruct{handler} } -// ------------------------------------------------- -// helper functions for parsing parameters from body -// ------------------------------------------------- -func parseParametersFromBody(r io.Reader) (map[string]string, error) { - parameters := map[string]string{} - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - if len(line) == 0 { - // skip empty lines - continue - } - parts := strings.SplitN(line, "=", 2) - if len(parts) != 2 { - return parameters, errors.New("Got invalid request body") - } - parameters[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) - } - - return parameters, nil -} - // ------------------ // network interfaces // ------------------ diff --git a/home/home.go b/home/home.go index 80f9e860..afbb3002 100644 --- a/home/home.go +++ b/home/home.go @@ -143,11 +143,12 @@ func run(args options) { } initDNSServer() - - err = startDNSServer() - if err != nil { - log.Fatal(err) - } + go func() { + err = startDNSServer() + if err != nil { + log.Fatal(err) + } + }() err = startDHCPServer() if err != nil {