diff --git a/internal/aghhttp/aghhttp.go b/internal/aghhttp/aghhttp.go index 57a1c868..23f9f5d3 100644 --- a/internal/aghhttp/aghhttp.go +++ b/internal/aghhttp/aghhttp.go @@ -9,6 +9,12 @@ import ( "github.com/AdguardTeam/golibs/log" ) +// HTTP scheme constants. +const ( + SchemeHTTP = "http" + SchemeHTTPS = "https" +) + // RegisterFunc is the function that sets the handler to handle the URL for the // method. // diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index b11d1f65..7d1ae199 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -67,10 +67,11 @@ func createTestServer( ID: 0, Data: []byte(rules), }} - f := filtering.New(filterConf, filters) + f, err := filtering.New(filterConf, filters) + require.NoError(t, err) + f.SetEnabled(true) - var err error s, err = NewServer(DNSCreateParams{ DHCPServer: testDHCP, DNSFilter: f, @@ -774,7 +775,9 @@ func TestBlockedCustomIP(t *testing.T) { Data: []byte(rules), }} - f := filtering.New(&filtering.Config{}, filters) + f, err := filtering.New(&filtering.Config{}, filters) + require.NoError(t, err) + s, err := NewServer(DNSCreateParams{ DHCPServer: testDHCP, DNSFilter: f, @@ -906,7 +909,9 @@ func TestRewrite(t *testing.T) { Type: dns.TypeCNAME, }}, } - f := filtering.New(c, nil) + f, err := filtering.New(c, nil) + require.NoError(t, err) + f.SetEnabled(true) s, err := NewServer(DNSCreateParams{ @@ -1021,19 +1026,14 @@ var testDHCP = &dhcpd.MockInterface{ OnWriteDiskConfig: func(c *dhcpd.ServerConfig) { panic("not implemented") }, } -// func (*testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) { -// return []*dhcpd.Lease{{ -// IP: net.IP{192, 168, 12, 34}, -// HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, -// Hostname: "myhost", -// }} -// } - func TestPTRResponseFromDHCPLeases(t *testing.T) { const localDomain = "lan" + flt, err := filtering.New(&filtering.Config{}, nil) + require.NoError(t, err) + s, err := NewServer(DNSCreateParams{ - DNSFilter: filtering.New(&filtering.Config{}, nil), + DNSFilter: flt, DHCPServer: testDHCP, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), LocalDomain: localDomain, @@ -1100,9 +1100,11 @@ func TestPTRResponseFromHosts(t *testing.T) { assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter)) }) - flt := filtering.New(&filtering.Config{ + flt, err := filtering.New(&filtering.Config{ EtcHosts: hc, }, nil) + require.NoError(t, err) + flt.SetEnabled(true) var s *Server diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_test.go index 00c04252..7fa0985a 100644 --- a/internal/dnsforward/filter_test.go +++ b/internal/dnsforward/filter_test.go @@ -35,7 +35,8 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) { ID: 0, Data: []byte(rules), }} - f := filtering.New(&filtering.Config{}, filters) + f, err := filtering.New(&filtering.Config{}, filters) + require.NoError(t, err) f.SetEnabled(true) s, err := NewServer(DNSCreateParams{ diff --git a/internal/filtering/blocked.go b/internal/filtering/blocked.go index 08866100..489def36 100644 --- a/internal/filtering/blocked.go +++ b/internal/filtering/blocked.go @@ -421,31 +421,34 @@ func initBlockedServices() { } // BlockedSvcKnown - return TRUE if a blocked service name is known -func BlockedSvcKnown(s string) bool { - _, ok := serviceRules[s] +func BlockedSvcKnown(s string) (ok bool) { + _, ok = serviceRules[s] + return ok } // ApplyBlockedServices - set blocked services settings for this DNS request -func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string, global bool) { +func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string) { setts.ServicesRules = []ServiceEntry{} - if global { + if list == nil { d.confLock.RLock() defer d.confLock.RUnlock() + list = d.Config.BlockedServices } + for _, name := range list { rules, ok := serviceRules[name] - if !ok { log.Error("unknown service name: %s", name) + continue } - s := ServiceEntry{} - s.Name = name - s.Rules = rules - setts.ServicesRules = append(setts.ServicesRules, s) + setts.ServicesRules = append(setts.ServicesRules, ServiceEntry{ + Name: name, + Rules: rules, + }) } } @@ -490,10 +493,3 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ d.ConfigModified() } - -// registerBlockedServicesHandlers - register HTTP handlers -func (d *DNSFilter) registerBlockedServicesHandlers() { - d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices) - d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList) - d.Config.HTTPRegister(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet) -} diff --git a/internal/home/controlfiltering.go b/internal/filtering/controlfiltering.go similarity index 61% rename from internal/home/controlfiltering.go rename to internal/filtering/controlfiltering.go index a4c8651a..1cce8ded 100644 --- a/internal/home/controlfiltering.go +++ b/internal/filtering/controlfiltering.go @@ -1,4 +1,4 @@ -package home +package filtering import ( "encoding/json" @@ -34,7 +34,7 @@ func validateFilterURL(urlStr string) (err error) { return fmt.Errorf("checking filter url: %w", err) } - if s := url.Scheme; s != schemeHTTP && s != schemeHTTPS { + if s := url.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS { return fmt.Errorf("checking filter url: invalid scheme %q", s) } @@ -47,7 +47,7 @@ type filterAddJSON struct { Whitelist bool `json:"whitelist"` } -func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { +func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { fj := filterAddJSON{} err := json.NewDecoder(r.Body).Decode(&fj) if err != nil { @@ -65,14 +65,14 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request } // Check for duplicates - if filterExists(fj.URL) { + if d.filterExists(fj.URL) { aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL) return } // Set necessary properties - filt := filter{ + filt := FilterYAML{ Enabled: true, URL: fj.URL, Name: fj.Name, @@ -81,7 +81,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request filt.ID = assignUniqueFilterID() // Download the filter contents - ok, err := f.update(&filt) + ok, err := d.update(&filt) if err != nil { aghhttp.Error( r, @@ -109,14 +109,14 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request // URL is assumed valid so append it to filters, update config, write new // file and reload it to engines. - if !filterAdd(filt) { + if !d.filterAdd(filt) { aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL) return } - onConfigModified() - enableFilters(true) + d.ConfigModified() + d.EnableFilters(true) _, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount) if err != nil { @@ -124,7 +124,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request } } -func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { +func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { type request struct { URL string `json:"url"` Whitelist bool `json:"whitelist"` @@ -138,23 +138,23 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ return } - config.Lock() - filters := &config.Filters + d.filtersMu.Lock() + filters := &d.Filters if req.Whitelist { - filters = &config.WhitelistFilters + filters = &d.WhitelistFilters } - var deleted filter - var newFilters []filter - for _, f := range *filters { - if f.URL != req.URL { - newFilters = append(newFilters, f) + var deleted FilterYAML + var newFilters []FilterYAML + for _, flt := range *filters { + if flt.URL != req.URL { + newFilters = append(newFilters, flt) continue } - deleted = f - path := f.Path() + deleted = flt + path := flt.Path(d.DataDir) err = os.Rename(path, path+".old") if err != nil { log.Error("deleting filter %q: %s", path, err) @@ -162,10 +162,10 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ } *filters = newFilters - config.Unlock() + d.filtersMu.Unlock() - onConfigModified() - enableFilters(true) + d.ConfigModified() + d.EnableFilters(true) // NOTE: The old files "filter.txt.old" aren't deleted. It's not really // necessary, but will require the additional complicated code to run @@ -191,55 +191,51 @@ type filterURLReq struct { Whitelist bool `json:"whitelist"` } -func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) { +func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) { fj := filterURLReq{} err := json.NewDecoder(r.Body).Decode(&fj) if err != nil { - aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err) return } if fj.Data == nil { - err = errors.Error("data cannot be null") - aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "%s", errors.Error("data is absent")) return } err = validateFilterURL(fj.Data.URL) if err != nil { - err = fmt.Errorf("invalid url: %s", err) - aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "invalid url: %s", err) return } - filt := filter{ + filt := FilterYAML{ Enabled: fj.Data.Enabled, Name: fj.Data.Name, URL: fj.Data.URL, } - status := f.filterSetProperties(fj.URL, filt, fj.Whitelist) + status := d.filterSetProperties(fj.URL, filt, fj.Whitelist) if (status & statusFound) == 0 { - http.Error(w, "URL doesn't exist", http.StatusBadRequest) + aghhttp.Error(r, w, http.StatusBadRequest, "URL doesn't exist") + return } if (status & statusURLExists) != 0 { - http.Error(w, "URL already exists", http.StatusBadRequest) + aghhttp.Error(r, w, http.StatusBadRequest, "URL already exists") + return } - onConfigModified() + d.ConfigModified() restart := (status & statusEnabledChanged) != 0 if (status&statusUpdateRequired) != 0 && fj.Data.Enabled { - // download new filter and apply its rules - flags := filterRefreshBlocklists - if fj.Whitelist { - flags = filterRefreshAllowlists - } - nUpdated, _ := f.refreshFilters(flags, true) + // download new filter and apply its rules. + nUpdated := d.refreshFilters(!fj.Whitelist, fj.Whitelist, false) // if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically // if not - we restart the filtering ourselves restart = false @@ -249,11 +245,11 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request } if restart { - enableFilters(true) + d.EnableFilters(true) } } -func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { +func (d *DNSFilter) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { // This use of ReadAll is safe, because request's body is now limited. body, err := io.ReadAll(r.Body) if err != nil { @@ -262,12 +258,12 @@ func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Reque return } - config.UserRules = strings.Split(string(body), "\n") - onConfigModified() - enableFilters(true) + d.UserRules = strings.Split(string(body), "\n") + d.ConfigModified() + d.EnableFilters(true) } -func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { +func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { type Req struct { White bool `json:"whitelist"` } @@ -285,35 +281,27 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques return } - flags := filterRefreshBlocklists - if req.White { - flags = filterRefreshAllowlists - } - func() { - // Temporarily unlock the Context.controlLock because the - // f.refreshFilters waits for it to be unlocked but it's - // actually locked in ensure wrapper. - // - // TODO(e.burkov): Reconsider this messy syncing process. - Context.controlLock.Unlock() - defer Context.controlLock.Lock() - - resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false) - }() - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) + var ok bool + resp.Updated, _, ok = d.tryRefreshFilters(!req.White, req.White, true) + if !ok { + aghhttp.Error( + r, + w, + http.StatusInternalServerError, + "filters update procedure is already running", + ) return } - js, err := json.Marshal(resp) + w.Header().Set("Content-Type", "application/json") + + err = json.NewEncoder(w).Encode(resp) if err != nil { aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) return } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(js) } type filterJSON struct { @@ -333,7 +321,7 @@ type filteringConfig struct { Enabled bool `json:"enabled"` } -func filterToJSON(f filter) filterJSON { +func filterToJSON(f FilterYAML) filterJSON { fj := filterJSON{ ID: f.ID, Enabled: f.Enabled, @@ -350,21 +338,21 @@ func filterToJSON(f filter) filterJSON { } // Get filtering configuration -func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) { +func (d *DNSFilter) handleFilteringStatus(w http.ResponseWriter, r *http.Request) { resp := filteringConfig{} - config.RLock() - resp.Enabled = config.DNS.FilteringEnabled - resp.Interval = config.DNS.FiltersUpdateIntervalHours - for _, f := range config.Filters { + d.filtersMu.RLock() + resp.Enabled = d.FilteringEnabled + resp.Interval = d.FiltersUpdateIntervalHours + for _, f := range d.Filters { fj := filterToJSON(f) resp.Filters = append(resp.Filters, fj) } - for _, f := range config.WhitelistFilters { + for _, f := range d.WhitelistFilters { fj := filterToJSON(f) resp.WhitelistFilters = append(resp.WhitelistFilters, fj) } - resp.UserRules = config.UserRules - config.RUnlock() + resp.UserRules = d.UserRules + d.filtersMu.RUnlock() jsonVal, err := json.Marshal(resp) if err != nil { @@ -380,7 +368,7 @@ func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request } // Set filtering configuration -func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) { +func (d *DNSFilter) handleFilteringConfig(w http.ResponseWriter, r *http.Request) { req := filteringConfig{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -389,22 +377,22 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request return } - if !checkFiltersUpdateIntervalHours(req.Interval) { + if !ValidateUpdateIvl(req.Interval) { aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval") return } func() { - config.Lock() - defer config.Unlock() + d.filtersMu.Lock() + defer d.filtersMu.Unlock() - config.DNS.FilteringEnabled = req.Enabled - config.DNS.FiltersUpdateIntervalHours = req.Interval + d.FilteringEnabled = req.Enabled + d.FiltersUpdateIntervalHours = req.Interval }() - onConfigModified() - enableFilters(true) + d.ConfigModified() + d.EnableFilters(true) } type checkHostRespRule struct { @@ -435,15 +423,15 @@ type checkHostResp struct { FilterID int64 `json:"filter_id"` } -func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) { - q := r.URL.Query() - host := q.Get("name") +func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) { + host := r.URL.Query().Get("name") - setts := Context.dnsFilter.GetConfig() + setts := d.GetConfig() setts.FilteringEnabled = true setts.ProtectionEnabled = true - Context.dnsFilter.ApplyBlockedServices(&setts, nil, true) - result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) + + d.ApplyBlockedServices(&setts, nil) + result, err := d.CheckHost(host, dns.TypeA, &setts) if err != nil { aghhttp.Error( r, @@ -457,18 +445,20 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) { return } - resp := checkHostResp{} - resp.Reason = result.Reason.String() - resp.SvcName = result.ServiceName - resp.CanonName = result.CanonName - resp.IPList = result.IPList + rulesLen := len(result.Rules) + resp := checkHostResp{ + Reason: result.Reason.String(), + SvcName: result.ServiceName, + CanonName: result.CanonName, + IPList: result.IPList, + Rules: make([]*checkHostRespRule, len(result.Rules)), + } - if len(result.Rules) > 0 { + if rulesLen > 0 { resp.FilterID = result.Rules[0].FilterListID resp.Rule = result.Rules[0].Text } - resp.Rules = make([]*checkHostRespRule, len(result.Rules)) for i, r := range result.Rules { resp.Rules[i] = &checkHostRespRule{ FilterListID: r.FilterListID, @@ -476,28 +466,51 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) { } } - js, err := json.Marshal(resp) - if err != nil { - aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) - - return - } w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(js) + err = json.NewEncoder(w).Encode(resp) + if err != nil { + aghhttp.Error(r, w, http.StatusInternalServerError, "encoding response: %s", err) + } } // RegisterFilteringHandlers - register handlers -func (f *Filtering) RegisterFilteringHandlers() { - httpRegister(http.MethodGet, "/control/filtering/status", f.handleFilteringStatus) - httpRegister(http.MethodPost, "/control/filtering/config", f.handleFilteringConfig) - httpRegister(http.MethodPost, "/control/filtering/add_url", f.handleFilteringAddURL) - httpRegister(http.MethodPost, "/control/filtering/remove_url", f.handleFilteringRemoveURL) - httpRegister(http.MethodPost, "/control/filtering/set_url", f.handleFilteringSetURL) - httpRegister(http.MethodPost, "/control/filtering/refresh", f.handleFilteringRefresh) - httpRegister(http.MethodPost, "/control/filtering/set_rules", f.handleFilteringSetRules) - httpRegister(http.MethodGet, "/control/filtering/check_host", f.handleCheckHost) +func (d *DNSFilter) RegisterFilteringHandlers() { + registerHTTP := d.HTTPRegister + if registerHTTP == nil { + return + } + + registerHTTP(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable) + registerHTTP(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable) + registerHTTP(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus) + + registerHTTP(http.MethodPost, "/control/parental/enable", d.handleParentalEnable) + registerHTTP(http.MethodPost, "/control/parental/disable", d.handleParentalDisable) + registerHTTP(http.MethodGet, "/control/parental/status", d.handleParentalStatus) + + registerHTTP(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable) + registerHTTP(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable) + registerHTTP(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus) + + registerHTTP(http.MethodGet, "/control/rewrite/list", d.handleRewriteList) + registerHTTP(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd) + registerHTTP(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete) + + registerHTTP(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices) + registerHTTP(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList) + registerHTTP(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet) + + registerHTTP(http.MethodGet, "/control/filtering/status", d.handleFilteringStatus) + registerHTTP(http.MethodPost, "/control/filtering/config", d.handleFilteringConfig) + registerHTTP(http.MethodPost, "/control/filtering/add_url", d.handleFilteringAddURL) + registerHTTP(http.MethodPost, "/control/filtering/remove_url", d.handleFilteringRemoveURL) + registerHTTP(http.MethodPost, "/control/filtering/set_url", d.handleFilteringSetURL) + registerHTTP(http.MethodPost, "/control/filtering/refresh", d.handleFilteringRefresh) + registerHTTP(http.MethodPost, "/control/filtering/set_rules", d.handleFilteringSetRules) + registerHTTP(http.MethodGet, "/control/filtering/check_host", d.handleCheckHost) } -func checkFiltersUpdateIntervalHours(i uint32) bool { +// ValidateUpdateIvl returns false if i is not a valid filters update interval. +func ValidateUpdateIvl(i uint32) bool { return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24 } diff --git a/internal/filtering/dnsrewrite_test.go b/internal/filtering/dnsrewrite_test.go index f8415fbf..c75ea2b9 100644 --- a/internal/filtering/dnsrewrite_test.go +++ b/internal/filtering/dnsrewrite_test.go @@ -49,7 +49,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { |1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot. ` - f := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}}) + f, _ := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}}) setts := &Settings{ FilteringEnabled: true, } diff --git a/internal/home/filter.go b/internal/filtering/filter.go similarity index 51% rename from internal/home/filter.go rename to internal/filtering/filter.go index 78abd76a..fcba11aa 100644 --- a/internal/home/filter.go +++ b/internal/filtering/filter.go @@ -1,4 +1,4 @@ -package home +package filtering import ( "bufio" @@ -8,63 +8,29 @@ import ( "net/http" "os" "path/filepath" - "regexp" "strconv" "strings" - "sync" - "sync/atomic" "time" - "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" + "golang.org/x/exp/slices" ) -var nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID +// filterDir is the subdirectory of a data directory to store downloaded +// filters. +const filterDir = "filters" -// Filtering - module object -type Filtering struct { - // conf FilteringConf - refreshStatus uint32 // 0:none; 1:in progress - refreshLock sync.Mutex - filterTitleRegexp *regexp.Regexp -} +// nextFilterID is a way to seed a unique ID generation. +// +// TODO(e.burkov): Use more deterministic approach. +var nextFilterID = time.Now().Unix() -// Init - initialize the module -func (f *Filtering) Init() { - f.filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`) - _ = os.MkdirAll(filepath.Join(Context.getDataDir(), filterDir), 0o755) - f.loadFilters(config.Filters) - f.loadFilters(config.WhitelistFilters) - deduplicateFilters() - updateUniqueFilterID(config.Filters) - updateUniqueFilterID(config.WhitelistFilters) -} - -// Start - start the module -func (f *Filtering) Start() { - f.RegisterFilteringHandlers() - - // Here we should start updating filters, - // but currently we can't wake up the periodic task to do so. - // So for now we just start this periodic task from here. - go f.periodicallyRefreshFilters() -} - -// Close - close the module -func (f *Filtering) Close() { -} - -func defaultFilters() []filter { - return []filter{ - {Filter: filtering.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard DNS filter"}, - {Filter: filtering.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway Default Blocklist"}, - } -} - -// field ordering is important -- yaml fields will mirror ordering from here -type filter struct { +// FilterYAML respresents a filter list in the configuration file. +// +// TODO(e.burkov): Investigate if the field oredering is important. +type FilterYAML struct { Enabled bool URL string // URL or a file path Name string `yaml:"name"` @@ -73,91 +39,108 @@ type filter struct { checksum uint32 // checksum of the file data white bool - filtering.Filter `yaml:",inline"` + Filter `yaml:",inline"` +} + +// Clear filter rules +func (filter *FilterYAML) unload() { + filter.RulesCount = 0 + filter.checksum = 0 +} + +// Path to the filter contents +func (filter *FilterYAML) Path(dataDir string) string { + return filepath.Join(dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt") } const ( - statusFound = 1 - statusEnabledChanged = 2 - statusURLChanged = 4 - statusURLExists = 8 - statusUpdateRequired = 0x10 + statusFound = 1 << iota + statusEnabledChanged + statusURLChanged + statusURLExists + statusUpdateRequired ) // Update properties for a filter specified by its URL // Return status* flags. -func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool) int { +func (d *DNSFilter) filterSetProperties(url string, newf FilterYAML, whitelist bool) int { r := 0 - config.Lock() - defer config.Unlock() + d.filtersMu.Lock() + defer d.filtersMu.Unlock() - filters := &config.Filters + filters := d.Filters if whitelist { - filters = &config.WhitelistFilters + filters = d.WhitelistFilters } - for i := range *filters { - filt := &(*filters)[i] - if filt.URL != url { - continue + i := slices.IndexFunc(filters, func(filt FilterYAML) bool { + return filt.URL == url + }) + if i == -1 { + return 0 + } + + filt := &filters[i] + + log.Debug("filter: set properties: %s: {%s %s %v}", filt.URL, newf.Name, newf.URL, newf.Enabled) + filt.Name = newf.Name + + if filt.URL != newf.URL { + r |= statusURLChanged | statusUpdateRequired + if d.filterExistsNoLock(newf.URL) { + return statusURLExists } - log.Debug("filter: set properties: %s: {%s %s %v}", - filt.URL, newf.Name, newf.URL, newf.Enabled) - filt.Name = newf.Name + filt.URL = newf.URL + filt.unload() + filt.LastUpdated = time.Time{} + filt.checksum = 0 + filt.RulesCount = 0 + } - if filt.URL != newf.URL { - r |= statusURLChanged | statusUpdateRequired - if filterExistsNoLock(newf.URL) { - return statusURLExists - } - filt.URL = newf.URL - filt.unload() - filt.LastUpdated = time.Time{} - filt.checksum = 0 - filt.RulesCount = 0 - } + if filt.Enabled != newf.Enabled { + r |= statusEnabledChanged + filt.Enabled = newf.Enabled + if filt.Enabled { + if (r & statusURLChanged) == 0 { + err := d.load(filt) + if err != nil { + // TODO(e.burkov): It seems the error is only returned when + // the file exists and couldn't be open. Investigate and + // improve. + log.Error("loading filter %d: %s", filt.ID, err) - if filt.Enabled != newf.Enabled { - r |= statusEnabledChanged - filt.Enabled = newf.Enabled - if filt.Enabled { - if (r & statusURLChanged) == 0 { - e := f.load(filt) - if e != nil { - // This isn't a fatal error, - // because it may occur when someone removes the file from disk. - filt.LastUpdated = time.Time{} - filt.checksum = 0 - filt.RulesCount = 0 - r |= statusUpdateRequired - } + filt.LastUpdated = time.Time{} + filt.checksum = 0 + filt.RulesCount = 0 + r |= statusUpdateRequired } - } else { - filt.unload() } + } else { + filt.unload() } - - return r | statusFound } - return 0 + + return r | statusFound } // Return TRUE if a filter with this URL exists -func filterExists(url string) bool { - config.RLock() - r := filterExistsNoLock(url) - config.RUnlock() +func (d *DNSFilter) filterExists(url string) bool { + d.filtersMu.RLock() + defer d.filtersMu.RUnlock() + + r := d.filterExistsNoLock(url) + return r } -func filterExistsNoLock(url string) bool { - for _, f := range config.Filters { +func (d *DNSFilter) filterExistsNoLock(url string) bool { + for _, f := range d.Filters { if f.URL == url { return true } } - for _, f := range config.WhitelistFilters { + for _, f := range d.WhitelistFilters { if f.URL == url { return true } @@ -167,26 +150,26 @@ func filterExistsNoLock(url string) bool { // Add a filter // Return FALSE if a filter with this URL exists -func filterAdd(f filter) bool { - config.Lock() - defer config.Unlock() +func (d *DNSFilter) filterAdd(flt FilterYAML) bool { + d.filtersMu.Lock() + defer d.filtersMu.Unlock() // Check for duplicates - if filterExistsNoLock(f.URL) { + if d.filterExistsNoLock(flt.URL) { return false } - if f.white { - config.WhitelistFilters = append(config.WhitelistFilters, f) + if flt.white { + d.WhitelistFilters = append(d.WhitelistFilters, flt) } else { - config.Filters = append(config.Filters, f) + d.Filters = append(d.Filters, flt) } return true } // Load filters from the disk // And if any filter has zero ID, assign a new one -func (f *Filtering) loadFilters(array []filter) { +func (d *DNSFilter) loadFilters(array []FilterYAML) { for i := range array { filter := &array[i] // otherwise we're operating on a copy if filter.ID == 0 { @@ -198,32 +181,30 @@ func (f *Filtering) loadFilters(array []filter) { continue } - err := f.load(filter) + err := d.load(filter) if err != nil { log.Error("Couldn't load filter %d contents due to %s", filter.ID, err) } } } -func deduplicateFilters() { - // Deduplicate filters - i := 0 // output index, used for deletion later - urls := map[string]bool{} - for _, filter := range config.Filters { - if _, ok := urls[filter.URL]; !ok { - // we didn't see it before, keep it - urls[filter.URL] = true // remember the URL - config.Filters[i] = filter - i++ +func deduplicateFilters(filters []FilterYAML) (deduplicated []FilterYAML) { + urls := stringutil.NewSet() + lastIdx := 0 + + for _, filter := range filters { + if !urls.Has(filter.URL) { + urls.Add(filter.URL) + filters[lastIdx] = filter + lastIdx++ } } - // all entries we want to keep are at front, delete the rest - config.Filters = config.Filters[:i] + return filters[:lastIdx] } // Set the next filter ID to max(filter.ID) + 1 -func updateUniqueFilterID(filters []filter) { +func updateUniqueFilterID(filters []FilterYAML) { for _, filter := range filters { if nextFilterID < filter.ID { nextFilterID = filter.ID + 1 @@ -238,22 +219,19 @@ func assignUniqueFilterID() int64 { } // Sets up a timer that will be checking for filters updates periodically -func (f *Filtering) periodicallyRefreshFilters() { +func (d *DNSFilter) periodicallyRefreshFilters() { const maxInterval = 1 * 60 * 60 intval := 5 // use a dynamically increasing time interval for { - isNetworkErr := false - if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) { - f.refreshLock.Lock() - _, isNetworkErr = f.refreshFiltersIfNecessary(filterRefreshBlocklists | filterRefreshAllowlists) - f.refreshLock.Unlock() - f.refreshStatus = 0 - if !isNetworkErr { + isNetErr, ok := false, false + if d.FiltersUpdateIntervalHours != 0 { + _, isNetErr, ok = d.tryRefreshFilters(true, true, false) + if ok && !isNetErr { intval = maxInterval } } - if isNetworkErr { + if isNetErr { intval *= 2 if intval > maxInterval { intval = maxInterval @@ -264,51 +242,73 @@ func (f *Filtering) periodicallyRefreshFilters() { } } -// Refresh filters -// flags: filterRefresh* -// important: +// tryRefreshFilters is like [refreshFilters], but backs down if the update is +// already going on. // -// TRUE: ignore the fact that we're currently updating the filters -func (f *Filtering) refreshFilters(flags int, important bool) (int, error) { - set := atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) - if !important && !set { - return 0, fmt.Errorf("filters update procedure is already running") +// TODO(e.burkov): Get rid of the concurrency pattern which requires the +// sync.Mutex.TryLock. +func (d *DNSFilter) tryRefreshFilters(block, allow, force bool) (updated int, isNetworkErr, ok bool) { + if ok = d.refreshLock.TryLock(); !ok { + return 0, false, ok } + defer d.refreshLock.Unlock() - f.refreshLock.Lock() - nUpdated, _ := f.refreshFiltersIfNecessary(flags) - f.refreshLock.Unlock() - f.refreshStatus = 0 - return nUpdated, nil + updated, isNetworkErr = d.refreshFiltersIntl(block, allow, force) + + return updated, isNetworkErr, ok } -func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) { - var updateFilters []filter +// refreshFilters updates the lists and returns the number of updated ones. +// It's safe for concurrent use, but blocks at least until the previous +// refreshing is finished. +func (d *DNSFilter) refreshFilters(block, allow, force bool) (updated int) { + d.refreshLock.Lock() + defer d.refreshLock.Unlock() + + updated, _ = d.refreshFiltersIntl(block, allow, force) + + return updated +} + +// listsToUpdate returns the slice of filter lists that could be updated. +func (d *DNSFilter) listsToUpdate(filters *[]FilterYAML, force bool) (toUpd []FilterYAML) { + now := time.Now() + + d.filtersMu.RLock() + defer d.filtersMu.RUnlock() + + for i := range *filters { + flt := &(*filters)[i] // otherwise we will be operating on a copy + log.Debug("checking list at index %d: %v", i, flt) + + if !flt.Enabled { + continue + } + + if !force { + exp := flt.LastUpdated.Add(time.Duration(d.FiltersUpdateIntervalHours) * time.Hour) + if now.Before(exp) { + continue + } + } + + toUpd = append(toUpd, FilterYAML{ + Filter: Filter{ + ID: flt.ID, + }, + URL: flt.URL, + Name: flt.Name, + checksum: flt.checksum, + }) + } + + return toUpd +} + +func (d *DNSFilter) refreshFiltersArray(filters *[]FilterYAML, force bool) (int, []FilterYAML, []bool, bool) { var updateFlags []bool // 'true' if filter data has changed - now := time.Now() - config.RLock() - for i := range *filters { - f := &(*filters)[i] // otherwise we will be operating on a copy - - if !f.Enabled { - continue - } - - expireTime := f.LastUpdated.Unix() + int64(config.DNS.FiltersUpdateIntervalHours)*60*60 - if !force && expireTime > now.Unix() { - continue - } - - var uf filter - uf.ID = f.ID - uf.URL = f.URL - uf.Name = f.Name - uf.checksum = f.checksum - updateFilters = append(updateFilters, uf) - } - config.RUnlock() - + updateFilters := d.listsToUpdate(filters, force) if len(updateFilters) == 0 { return 0, nil, nil, false } @@ -316,7 +316,7 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f nfail := 0 for i := range updateFilters { uf := &updateFilters[i] - updated, err := f.update(uf) + updated, err := d.update(uf) updateFlags = append(updateFlags, updated) if err != nil { nfail++ @@ -334,7 +334,7 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f uf := &updateFilters[i] updated := updateFlags[i] - config.Lock() + d.filtersMu.Lock() for k := range *filters { f := &(*filters)[k] if f.ID != uf.ID || f.URL != uf.URL { @@ -352,20 +352,14 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f f.checksum = uf.checksum updateCount++ } - config.Unlock() + d.filtersMu.Unlock() } return updateCount, updateFilters, updateFlags, false } -const ( - filterRefreshForce = 1 // ignore last file modification date - filterRefreshAllowlists = 2 // update allow-lists - filterRefreshBlocklists = 4 // update block-lists -) - -// refreshFiltersIfNecessary checks filters and updates them if necessary. If -// force is true, it ignores the filter.LastUpdated field value. +// refreshFiltersIntl checks filters and updates them if necessary. If force is +// true, it ignores the filter.LastUpdated field value. // // Algorithm: // @@ -378,53 +372,49 @@ const ( // that this method works only on Unix systems. On Windows, don't pass // files to filtering, pass the whole data. // -// refreshFiltersIfNecessary returns the number of updated filters. It also -// returns true if there was a network error and nothing could be updated. +// refreshFiltersIntl returns the number of updated filters. It also returns +// true if there was a network error and nothing could be updated. // // TODO(a.garipov, e.burkov): What the hell? -func (f *Filtering) refreshFiltersIfNecessary(flags int) (int, bool) { - log.Debug("Filters: updating...") +func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) { + log.Debug("filtering: updating...") - updateCount := 0 - var updateFilters []filter - var updateFlags []bool - netError := false - netErrorW := false - force := false - if (flags & filterRefreshForce) != 0 { - force = true + updNum := 0 + var lists []FilterYAML + var toUpd []bool + isNetErr := false + + if block { + updNum, lists, toUpd, isNetErr = d.refreshFiltersArray(&d.Filters, force) } - if (flags & filterRefreshBlocklists) != 0 { - updateCount, updateFilters, updateFlags, netError = f.refreshFiltersArray(&config.Filters, force) + if allow { + updNumAl, listsAl, toUpdAl, isNetErrAl := d.refreshFiltersArray(&d.WhitelistFilters, force) + + updNum += updNumAl + lists = append(lists, listsAl...) + toUpd = append(toUpd, toUpdAl...) + isNetErr = isNetErr || isNetErrAl } - if (flags & filterRefreshAllowlists) != 0 { - updateCountW := 0 - var updateFiltersW []filter - var updateFlagsW []bool - updateCountW, updateFiltersW, updateFlagsW, netErrorW = f.refreshFiltersArray(&config.WhitelistFilters, force) - updateCount += updateCountW - updateFilters = append(updateFilters, updateFiltersW...) - updateFlags = append(updateFlags, updateFlagsW...) - } - if netError && netErrorW { + if isNetErr { return 0, true } - if updateCount != 0 { - enableFilters(false) + if updNum != 0 { + d.EnableFilters(false) - for i := range updateFilters { - uf := &updateFilters[i] - updated := updateFlags[i] + for i := range lists { + uf := &lists[i] + updated := toUpd[i] if !updated { continue } - _ = os.Remove(uf.Path() + ".old") + _ = os.Remove(uf.Path(d.DataDir) + ".old") } } - log.Debug("Filters: update finished") - return updateCount, false + log.Debug("filtering: update finished") + + return updNum, false } // Allows printable UTF-8 text with CR, LF, TAB characters @@ -440,7 +430,7 @@ func isPrintableText(data []byte, len int) bool { } // A helper function that parses filter contents and returns a number of rules and a filter name (if there's any) -func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) { +func (d *DNSFilter) parseFilterContents(file io.Reader) (int, uint32, string) { rulesCount := 0 name := "" seenTitle := false @@ -455,7 +445,7 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) { if len(line) == 0 { // } else if line[0] == '!' { - m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1) + m := d.filterTitleRegexp.FindAllStringSubmatch(line, -1) if len(m) > 0 && len(m[0]) >= 2 && !seenTitle { name = m[0][1] seenTitle = true @@ -476,11 +466,11 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) { } // Perform upgrade on a filter and update LastUpdated value -func (f *Filtering) update(filter *filter) (bool, error) { - b, err := f.updateIntl(filter) +func (d *DNSFilter) update(filter *FilterYAML) (bool, error) { + b, err := d.updateIntl(filter) filter.LastUpdated = time.Now() if !b { - e := os.Chtimes(filter.Path(), filter.LastUpdated, filter.LastUpdated) + e := os.Chtimes(filter.Path(d.DataDir), filter.LastUpdated, filter.LastUpdated) if e != nil { log.Error("os.Chtimes(): %v", e) } @@ -488,7 +478,7 @@ func (f *Filtering) update(filter *filter) (bool, error) { return b, err } -func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (int, error) { +func (d *DNSFilter) read(reader io.Reader, tmpFile *os.File, filter *FilterYAML) (int, error) { htmlTest := true firstChunk := make([]byte, 4*1024) firstChunkLen := 0 @@ -539,20 +529,20 @@ func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (in // finalizeUpdate closes and gets rid of temporary file f with filter's content // according to updated. It also saves new values of flt's name, rules number // and checksum if sucсeeded. -func finalizeUpdate( - f *os.File, - flt *filter, +func (d *DNSFilter) finalizeUpdate( + file *os.File, + flt *FilterYAML, updated bool, name string, rnum int, cs uint32, ) (err error) { - tmpFileName := f.Name() + tmpFileName := file.Name() // Close the file before renaming it because it's required on Windows. // // See https://github.com/adguardTeam/adGuardHome/issues/1553. - if err = f.Close(); err != nil { + if err = file.Close(); err != nil { return fmt.Errorf("closing temporary file: %w", err) } @@ -562,9 +552,9 @@ func finalizeUpdate( return os.Remove(tmpFileName) } - log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path()) + log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path(d.DataDir)) - if err = os.Rename(tmpFileName, flt.Path()); err != nil { + if err = os.Rename(tmpFileName, flt.Path(d.DataDir)); err != nil { return errors.WithDeferred(err, os.Remove(tmpFileName)) } @@ -578,12 +568,12 @@ func finalizeUpdate( // processUpdate copies filter's content from src to dst and returns the name, // rules number, and checksum for it. It also returns the number of bytes read // from src. -func (f *Filtering) processUpdate( +func (d *DNSFilter) processUpdate( src io.Reader, dst *os.File, - flt *filter, + flt *FilterYAML, ) (name string, rnum int, cs uint32, n int, err error) { - if n, err = f.read(src, dst, flt); err != nil { + if n, err = d.read(src, dst, flt); err != nil { return "", 0, 0, 0, err } @@ -591,14 +581,14 @@ func (f *Filtering) processUpdate( return "", 0, 0, 0, err } - rnum, cs, name = f.parseFilterContents(dst) + rnum, cs, name = d.parseFilterContents(dst) return name, rnum, cs, n, nil } // updateIntl updates the flt rewriting it's actual file. It returns true if // the actual update has been performed. -func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) { +func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) { log.Tracef("downloading update for filter %d from %s", flt.ID, flt.URL) var name string @@ -606,12 +596,12 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) { var cs uint32 var tmpFile *os.File - tmpFile, err = os.CreateTemp(filepath.Join(Context.getDataDir(), filterDir), "") + tmpFile, err = os.CreateTemp(filepath.Join(d.DataDir, filterDir), "") if err != nil { return false, err } defer func() { - err = errors.WithDeferred(err, finalizeUpdate(tmpFile, flt, ok, name, rnum, cs)) + err = errors.WithDeferred(err, d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs)) ok = ok && err == nil if ok { log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum) @@ -638,7 +628,7 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) { r = file } else { var resp *http.Response - resp, err = Context.client.Get(flt.URL) + resp, err = d.HTTPClient.Get(flt.URL) if err != nil { log.Printf("requesting filter from %s, skip: %s", flt.URL, err) @@ -655,16 +645,16 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) { r = resp.Body } - name, rnum, cs, n, err = f.processUpdate(r, tmpFile, flt) + name, rnum, cs, n, err = d.processUpdate(r, tmpFile, flt) return cs != flt.checksum, err } // loads filter contents from the file in dataDir -func (f *Filtering) load(filter *filter) (err error) { - filterFilePath := filter.Path() +func (d *DNSFilter) load(filter *FilterYAML) (err error) { + filterFilePath := filter.Path(d.DataDir) - log.Tracef("filtering: loading filter %d contents to: %s", filter.ID, filterFilePath) + log.Tracef("filtering: loading filter %d from %s", filter.ID, filterFilePath) file, err := os.Open(filterFilePath) if errors.Is(err, os.ErrNotExist) { @@ -682,7 +672,7 @@ func (f *Filtering) load(filter *filter) (err error) { log.Tracef("filtering: File %s, id %d, length %d", filterFilePath, filter.ID, st.Size()) - rulesCount, checksum, _ := f.parseFilterContents(file) + rulesCount, checksum, _ := d.parseFilterContents(file) filter.RulesCount = rulesCount filter.checksum = checksum @@ -691,56 +681,45 @@ func (f *Filtering) load(filter *filter) (err error) { return nil } -// Clear filter rules -func (filter *filter) unload() { - filter.RulesCount = 0 - filter.checksum = 0 +func (d *DNSFilter) EnableFilters(async bool) { + d.filtersMu.RLock() + defer d.filtersMu.RUnlock() + + d.enableFiltersLocked(async) } -// Path to the filter contents -func (filter *filter) Path() string { - return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt") -} - -func enableFilters(async bool) { - config.RLock() - defer config.RUnlock() - - enableFiltersLocked(async) -} - -func enableFiltersLocked(async bool) { - filters := []filtering.Filter{{ - ID: filtering.CustomListID, - Data: []byte(strings.Join(config.UserRules, "\n")), +func (d *DNSFilter) enableFiltersLocked(async bool) { + filters := []Filter{{ + ID: CustomListID, + Data: []byte(strings.Join(d.UserRules, "\n")), }} - for _, filter := range config.Filters { + for _, filter := range d.Filters { if !filter.Enabled { continue } - filters = append(filters, filtering.Filter{ + filters = append(filters, Filter{ ID: filter.ID, - FilePath: filter.Path(), + FilePath: filter.Path(d.DataDir), }) } - var allowFilters []filtering.Filter - for _, filter := range config.WhitelistFilters { + var allowFilters []Filter + for _, filter := range d.WhitelistFilters { if !filter.Enabled { continue } - allowFilters = append(allowFilters, filtering.Filter{ + allowFilters = append(allowFilters, Filter{ ID: filter.ID, - FilePath: filter.Path(), + FilePath: filter.Path(d.DataDir), }) } - if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil { + if err := d.SetFilters(filters, allowFilters, async); err != nil { log.Debug("enabling filters: %s", err) } - Context.dnsFilter.SetEnabled(config.DNS.FilteringEnabled) + d.SetEnabled(d.FilteringEnabled) } diff --git a/internal/home/filter_test.go b/internal/filtering/filter_test.go similarity index 83% rename from internal/home/filter_test.go rename to internal/filtering/filter_test.go index 08290562..b37dd10e 100644 --- a/internal/home/filter_test.go +++ b/internal/filtering/filter_test.go @@ -1,4 +1,4 @@ -package home +package filtering import ( "io/fs" @@ -51,15 +51,17 @@ func TestFilters(t *testing.T) { l := testStartFilterListener(t, &fltContent) - Context = homeContext{ - workDir: t.TempDir(), - client: &http.Client{ + tempDir := t.TempDir() + + filters, err := New(&Config{ + DataDir: tempDir, + HTTPClient: &http.Client{ Timeout: 5 * time.Second, }, - } - Context.filters.Init() + }, nil) + require.NoError(t, err) - f := &filter{ + f := &FilterYAML{ URL: (&url.URL{ Scheme: "http", Host: (&netutil.IPPort{ @@ -71,21 +73,22 @@ func TestFilters(t *testing.T) { } updateAndAssert := func(t *testing.T, want require.BoolAssertionFunc, wantRulesCount int) { - ok, err := Context.filters.update(f) + var ok bool + ok, err = filters.update(f) require.NoError(t, err) want(t, ok) assert.Equal(t, wantRulesCount, f.RulesCount) var dir []fs.DirEntry - dir, err = os.ReadDir(filepath.Join(Context.getDataDir(), filterDir)) + dir, err = os.ReadDir(filepath.Join(tempDir, filterDir)) require.NoError(t, err) assert.Len(t, dir, 1) - require.FileExists(t, f.Path()) + require.FileExists(t, f.Path(tempDir)) - err = Context.filters.load(f) + err = filters.load(f) require.NoError(t, err) } @@ -105,11 +108,9 @@ func TestFilters(t *testing.T) { }) t.Run("load_unload", func(t *testing.T) { - err := Context.filters.load(f) + err = filters.load(f) require.NoError(t, err) f.unload() }) - - require.NoError(t, os.Remove(f.Path())) } diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 446ad4ac..ab884056 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -6,7 +6,10 @@ import ( "fmt" "io/fs" "net" + "net/http" "os" + "path/filepath" + "regexp" "runtime" "runtime/debug" "strings" @@ -24,6 +27,7 @@ import ( "github.com/AdguardTeam/urlfilter/filterlist" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" + "golang.org/x/exp/slices" ) // The IDs of built-in filter lists. @@ -69,8 +73,13 @@ type Config struct { // enabled is used to be returned within Settings. // // It is of type uint32 to be accessed by atomic. + // + // TODO(e.burkov): Use atomic.Bool in Go 1.19. enabled uint32 + FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists + FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) + ParentalEnabled bool `yaml:"parental_enabled"` SafeSearchEnabled bool `yaml:"safesearch_enabled"` SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` @@ -98,6 +107,24 @@ type Config struct { // CustomResolver is the resolver used by DNSFilter. CustomResolver Resolver `yaml:"-"` + + // HTTPClient is the client to use for updating the remote filters. + HTTPClient *http.Client `yaml:"-"` + + // DataDir is used to store filters' contents. + DataDir string `yaml:"-"` + + // filtersMu protects filter lists. + filtersMu *sync.RWMutex + + // Filters are the blocking filter lists. + Filters []FilterYAML `yaml:"-"` + + // WhitelistFilters are the allowing filter lists. + WhitelistFilters []FilterYAML `yaml:"-"` + + // UserRules is the global list of custom rules. + UserRules []string `yaml:"-"` } // LookupStats store stats collected during safebrowsing or parental checks @@ -128,11 +155,13 @@ type hostChecker struct { // DNSFilter matches hostnames and DNS requests against filtering rules. type DNSFilter struct { - rulesStorage *filterlist.RuleStorage - filteringEngine *urlfilter.DNSEngine + rulesStorage *filterlist.RuleStorage + filteringEngine *urlfilter.DNSEngine + rulesStorageAllow *filterlist.RuleStorage filteringEngineAllow *urlfilter.DNSEngine - engineLock sync.RWMutex + + engineLock sync.RWMutex parentalServer string // access via methods safeBrowsingServer string // access via methods @@ -156,6 +185,12 @@ type DNSFilter struct { // TODO(e.burkov): Use upstream that configured in dnsforward instead. resolver Resolver + refreshLock *sync.Mutex + + // filterTitleRegexp is the regular expression to retrieve a name of a + // filter list. + filterTitleRegexp *regexp.Regexp + hostCheckers []hostChecker } @@ -168,7 +203,7 @@ type Filter struct { Data []byte `yaml:"-"` // ID is automatically assigned when filter is added using nextFilterID. - ID int64 + ID int64 `yaml:"id"` } // Reason holds an enum detailing why it was filtered or not filtered @@ -245,15 +280,7 @@ func (r Reason) String() string { } // In returns true if reasons include r. -func (r Reason) In(reasons ...Reason) (ok bool) { - for _, reason := range reasons { - if r == reason { - return true - } - } - - return false -} +func (r Reason) In(reasons ...Reason) (ok bool) { return slices.Contains(reasons, r) } // SetEnabled sets the status of the *DNSFilter. func (d *DNSFilter) SetEnabled(enabled bool) { @@ -261,6 +288,7 @@ func (d *DNSFilter) SetEnabled(enabled bool) { if enabled { i = 1 } + atomic.StoreUint32(&d.enabled, uint32(i)) } @@ -279,11 +307,20 @@ func (d *DNSFilter) GetConfig() (s Settings) { // WriteDiskConfig - write configuration func (d *DNSFilter) WriteDiskConfig(c *Config) { - d.confLock.Lock() - defer d.confLock.Unlock() + func() { + d.confLock.Lock() + defer d.confLock.Unlock() - *c = d.Config - c.Rewrites = cloneRewrites(c.Rewrites) + *c = d.Config + c.Rewrites = cloneRewrites(c.Rewrites) + }() + + d.filtersMu.RLock() + defer d.filtersMu.RUnlock() + + c.Filters = slices.Clone(d.Filters) + c.WhitelistFilters = slices.Clone(d.WhitelistFilters) + c.UserRules = slices.Clone(d.UserRules) } // cloneRewrites returns a deep copy of entries. @@ -309,6 +346,8 @@ func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool) } d.filtersInitializerLock.Lock() // prevent multiple writers from adding more than 1 task + defer d.filtersInitializerLock.Unlock() + // remove all pending tasks stop := false for !stop { @@ -321,7 +360,6 @@ func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool) } d.filtersInitializerChan <- params - d.filtersInitializerLock.Unlock() return nil } @@ -350,22 +388,19 @@ func (d *DNSFilter) filtersInitializer() { func (d *DNSFilter) Close() { d.engineLock.Lock() defer d.engineLock.Unlock() + d.reset() } func (d *DNSFilter) reset() { - var err error - if d.rulesStorage != nil { - err = d.rulesStorage.Close() - if err != nil { + if err := d.rulesStorage.Close(); err != nil { log.Error("filtering: rulesStorage.Close: %s", err) } } if d.rulesStorageAllow != nil { - err = d.rulesStorageAllow.Close() - if err != nil { + if err := d.rulesStorageAllow.Close(); err != nil { log.Error("filtering: rulesStorageAllow.Close: %s", err) } } @@ -885,29 +920,30 @@ func InitModule() { initBlockedServices() } -// New creates properly initialized DNS Filter that is ready to be used. -func New(c *Config, blockFilters []Filter) (d *DNSFilter) { +// New creates properly initialized DNS Filter that is ready to be used. c must +// be non-nil. +func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { d = &DNSFilter{ - resolver: net.DefaultResolver, + resolver: net.DefaultResolver, + refreshLock: &sync.Mutex{}, + filterTitleRegexp: regexp.MustCompile(`^! Title: +(.*)$`), } - if c != nil { - d.safebrowsingCache = cache.New(cache.Config{ - EnableLRU: true, - MaxSize: c.SafeBrowsingCacheSize, - }) - d.safeSearchCache = cache.New(cache.Config{ - EnableLRU: true, - MaxSize: c.SafeSearchCacheSize, - }) - d.parentalCache = cache.New(cache.Config{ - EnableLRU: true, - MaxSize: c.ParentalCacheSize, - }) + d.safebrowsingCache = cache.New(cache.Config{ + EnableLRU: true, + MaxSize: c.SafeBrowsingCacheSize, + }) + d.safeSearchCache = cache.New(cache.Config{ + EnableLRU: true, + MaxSize: c.SafeSearchCacheSize, + }) + d.parentalCache = cache.New(cache.Config{ + EnableLRU: true, + MaxSize: c.ParentalCacheSize, + }) - if c.CustomResolver != nil { - d.resolver = c.CustomResolver - } + if r := c.CustomResolver; r != nil { + d.resolver = r } d.hostCheckers = []hostChecker{{ @@ -930,27 +966,26 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) { name: "safe search", }} - err := d.initSecurityServices() - if err != nil { - log.Error("filtering: initialize services: %s", err) + defer func() { err = errors.Annotate(err, "filtering: %w") }() - return nil + err = d.initSecurityServices() + if err != nil { + return nil, fmt.Errorf("initializing services: %s", err) } - if c != nil { - d.Config = *c - err = d.prepareRewrites() - if err != nil { - log.Error("rewrites: preparing: %s", err) + d.Config = *c + d.filtersMu = &sync.RWMutex{} - return nil - } + err = d.prepareRewrites() + if err != nil { + return nil, fmt.Errorf("rewrites: preparing: %s", err) } bsvcs := []string{} for _, s := range d.BlockedServices { if !BlockedSvcKnown(s) { log.Debug("skipping unknown blocked-service %q", s) + continue } bsvcs = append(bsvcs, s) @@ -960,13 +995,24 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) { if blockFilters != nil { err = d.initFiltering(nil, blockFilters) if err != nil { - log.Error("Can't initialize filtering subsystem: %s", err) d.Close() - return nil + + return nil, fmt.Errorf("initializing filtering subsystem: %s", err) } } - return d + _ = os.MkdirAll(filepath.Join(d.DataDir, filterDir), 0o755) + + d.loadFilters(d.Filters) + d.loadFilters(d.WhitelistFilters) + + d.Filters = deduplicateFilters(d.Filters) + d.WhitelistFilters = deduplicateFilters(d.WhitelistFilters) + + updateUniqueFilterID(d.Filters) + updateUniqueFilterID(d.WhitelistFilters) + + return d, nil } // Start - start the module: @@ -976,9 +1022,10 @@ func (d *DNSFilter) Start() { d.filtersInitializerChan = make(chan filtersInitializerParams, 1) go d.filtersInitializer() - if d.Config.HTTPRegister != nil { // for tests - d.registerSecurityHandlers() - d.registerRewritesHandlers() - d.registerBlockedServicesHandlers() - } + d.RegisterFilteringHandlers() + + // Here we should start updating filters, + // but currently we can't wake up the periodic task to do so. + // So for now we just start this periodic task from here. + go d.periodicallyRefreshFilters() } diff --git a/internal/filtering/filtering_test.go b/internal/filtering/filtering_test.go index 95554b07..4fc9182d 100644 --- a/internal/filtering/filtering_test.go +++ b/internal/filtering/filtering_test.go @@ -26,10 +26,6 @@ const ( pcBlocked = "pornhub.com" ) -var setts = Settings{ - ProtectionEnabled: true, -} - // Helpers. func purgeCaches(d *DNSFilter) { @@ -44,8 +40,8 @@ func purgeCaches(d *DNSFilter) { } } -func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter { - setts = Settings{ +func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts *Settings) { + setts = &Settings{ ProtectionEnabled: true, FilteringEnabled: true, } @@ -57,26 +53,31 @@ func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter { setts.SafeSearchEnabled = c.SafeSearchEnabled setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled setts.ParentalEnabled = c.ParentalEnabled + } else { + // It must not be nil. + c = &Config{} } - d := New(c, filters) - purgeCaches(d) + f, err := New(c, filters) + require.NoError(t, err) - return d + purgeCaches(f) + + return f, setts } -func (d *DNSFilter) checkMatch(t *testing.T, hostname string) { +func (d *DNSFilter) checkMatch(t *testing.T, hostname string, setts *Settings) { t.Helper() - res, err := d.CheckHost(hostname, dns.TypeA, &setts) + res, err := d.CheckHost(hostname, dns.TypeA, setts) require.NoErrorf(t, err, "host %q", hostname) assert.Truef(t, res.IsFiltered, "host %q", hostname) } -func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16) { +func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16, setts *Settings) { t.Helper() - res, err := d.CheckHost(hostname, qtype, &setts) + res, err := d.CheckHost(hostname, qtype, setts) require.NoErrorf(t, err, "host %q", hostname, err) require.NotEmpty(t, res.Rules, "host %q", hostname) @@ -88,10 +89,10 @@ func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16 assert.Equalf(t, ip, r.IP.String(), "host %q", hostname) } -func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) { +func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string, setts *Settings) { t.Helper() - res, err := d.CheckHost(hostname, dns.TypeA, &setts) + res, err := d.CheckHost(hostname, dns.TypeA, setts) require.NoErrorf(t, err, "host %q", hostname) assert.Falsef(t, res.IsFiltered, "host %q", hostname) @@ -111,19 +112,19 @@ func TestEtcHostsMatching(t *testing.T) { filters := []Filter{{ ID: 0, Data: []byte(text), }} - d := newForTest(t, nil, filters) + d, setts := newForTest(t, nil, filters) t.Cleanup(d.Close) - d.checkMatchIP(t, "google.com", addr, dns.TypeA) - d.checkMatchIP(t, "www.google.com", addr, dns.TypeA) - d.checkMatchEmpty(t, "subdomain.google.com") - d.checkMatchEmpty(t, "example.org") + d.checkMatchIP(t, "google.com", addr, dns.TypeA, setts) + d.checkMatchIP(t, "www.google.com", addr, dns.TypeA, setts) + d.checkMatchEmpty(t, "subdomain.google.com", setts) + d.checkMatchEmpty(t, "example.org", setts) // IPv4 match. - d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA) + d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA, setts) // Empty IPv6. - res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts) + res, err := d.CheckHost("block.com", dns.TypeAAAA, setts) require.NoError(t, err) assert.True(t, res.IsFiltered) @@ -134,10 +135,10 @@ func TestEtcHostsMatching(t *testing.T) { assert.Empty(t, res.Rules[0].IP) // IPv6 match. - d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA) + d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA, setts) // Empty IPv4. - res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts) + res, err = d.CheckHost("ipv6.com", dns.TypeA, setts) require.NoError(t, err) assert.True(t, res.IsFiltered) @@ -148,7 +149,7 @@ func TestEtcHostsMatching(t *testing.T) { assert.Empty(t, res.Rules[0].IP) // Two IPv4, both must be returned. - res, err = d.CheckHost("host2", dns.TypeA, &setts) + res, err = d.CheckHost("host2", dns.TypeA, setts) require.NoError(t, err) assert.True(t, res.IsFiltered) @@ -159,7 +160,7 @@ func TestEtcHostsMatching(t *testing.T) { assert.Equal(t, res.Rules[1].IP, net.IP{0, 0, 0, 2}) // One IPv6 address. - res, err = d.CheckHost("host2", dns.TypeAAAA, &setts) + res, err = d.CheckHost("host2", dns.TypeAAAA, setts) require.NoError(t, err) assert.True(t, res.IsFiltered) @@ -176,27 +177,27 @@ func TestSafeBrowsing(t *testing.T) { aghtest.ReplaceLogWriter(t, logOutput) aghtest.ReplaceLogLevel(t, log.DEBUG) - d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) + d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) - d.checkMatch(t, sbBlocked) + d.checkMatch(t, sbBlocked, setts) require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked)) - d.checkMatch(t, "test."+sbBlocked) - d.checkMatchEmpty(t, "yandex.ru") - d.checkMatchEmpty(t, pcBlocked) + d.checkMatch(t, "test."+sbBlocked, setts) + d.checkMatchEmpty(t, "yandex.ru", setts) + d.checkMatchEmpty(t, pcBlocked, setts) // Cached result. d.safeBrowsingServer = "127.0.0.1" - d.checkMatch(t, sbBlocked) - d.checkMatchEmpty(t, pcBlocked) + d.checkMatch(t, sbBlocked, setts) + d.checkMatchEmpty(t, pcBlocked, setts) d.safeBrowsingServer = defaultSafebrowsingServer } func TestParallelSB(t *testing.T) { - d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) + d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) @@ -205,10 +206,10 @@ func TestParallelSB(t *testing.T) { for i := 0; i < 100; i++ { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { t.Parallel() - d.checkMatch(t, sbBlocked) - d.checkMatch(t, "test."+sbBlocked) - d.checkMatchEmpty(t, "yandex.ru") - d.checkMatchEmpty(t, pcBlocked) + d.checkMatch(t, sbBlocked, setts) + d.checkMatch(t, "test."+sbBlocked, setts) + d.checkMatchEmpty(t, "yandex.ru", setts) + d.checkMatchEmpty(t, pcBlocked, setts) }) } }) @@ -217,7 +218,7 @@ func TestParallelSB(t *testing.T) { // Safe Search. func TestSafeSearch(t *testing.T) { - d := newForTest(t, &Config{SafeSearchEnabled: true}, nil) + d, _ := newForTest(t, &Config{SafeSearchEnabled: true}, nil) t.Cleanup(d.Close) val, ok := d.SafeSearchDomain("www.google.com") require.True(t, ok) @@ -226,7 +227,7 @@ func TestSafeSearch(t *testing.T) { } func TestCheckHostSafeSearchYandex(t *testing.T) { - d := newForTest(t, &Config{ + d, setts := newForTest(t, &Config{ SafeSearchEnabled: true, }, nil) t.Cleanup(d.Close) @@ -243,7 +244,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) { "www.yandex.com", } { t.Run(strings.ToLower(host), func(t *testing.T) { - res, err := d.CheckHost(host, dns.TypeA, &setts) + res, err := d.CheckHost(host, dns.TypeA, setts) require.NoError(t, err) assert.True(t, res.IsFiltered) @@ -258,7 +259,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) { func TestCheckHostSafeSearchGoogle(t *testing.T) { resolver := &aghtest.TestResolver{} - d := newForTest(t, &Config{ + d, setts := newForTest(t, &Config{ SafeSearchEnabled: true, CustomResolver: resolver, }, nil) @@ -277,7 +278,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) { "www.google.je", } { t.Run(host, func(t *testing.T) { - res, err := d.CheckHost(host, dns.TypeA, &setts) + res, err := d.CheckHost(host, dns.TypeA, setts) require.NoError(t, err) assert.True(t, res.IsFiltered) @@ -291,12 +292,12 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) { } func TestSafeSearchCacheYandex(t *testing.T) { - d := newForTest(t, nil, nil) + d, setts := newForTest(t, nil, nil) t.Cleanup(d.Close) const domain = "yandex.ru" // Check host with disabled safesearch. - res, err := d.CheckHost(domain, dns.TypeA, &setts) + res, err := d.CheckHost(domain, dns.TypeA, setts) require.NoError(t, err) assert.False(t, res.IsFiltered) @@ -305,10 +306,10 @@ func TestSafeSearchCacheYandex(t *testing.T) { yandexIP := net.IPv4(213, 180, 193, 56) - d = newForTest(t, &Config{SafeSearchEnabled: true}, nil) + d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil) t.Cleanup(d.Close) - res, err = d.CheckHost(domain, dns.TypeA, &setts) + res, err = d.CheckHost(domain, dns.TypeA, setts) require.NoError(t, err) // For yandex we already know valid IP. @@ -325,20 +326,20 @@ func TestSafeSearchCacheYandex(t *testing.T) { func TestSafeSearchCacheGoogle(t *testing.T) { resolver := &aghtest.TestResolver{} - d := newForTest(t, &Config{ + d, setts := newForTest(t, &Config{ CustomResolver: resolver, }, nil) t.Cleanup(d.Close) const domain = "www.google.ru" - res, err := d.CheckHost(domain, dns.TypeA, &setts) + res, err := d.CheckHost(domain, dns.TypeA, setts) require.NoError(t, err) assert.False(t, res.IsFiltered) require.Empty(t, res.Rules) - d = newForTest(t, &Config{SafeSearchEnabled: true}, nil) + d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil) t.Cleanup(d.Close) d.resolver = resolver @@ -358,7 +359,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) { } } - res, err = d.CheckHost(domain, dns.TypeA, &setts) + res, err = d.CheckHost(domain, dns.TypeA, setts) require.NoError(t, err) require.Len(t, res.Rules, 1) @@ -379,22 +380,22 @@ func TestParentalControl(t *testing.T) { aghtest.ReplaceLogWriter(t, logOutput) aghtest.ReplaceLogLevel(t, log.DEBUG) - d := newForTest(t, &Config{ParentalEnabled: true}, nil) + d, setts := newForTest(t, &Config{ParentalEnabled: true}, nil) t.Cleanup(d.Close) d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true)) - d.checkMatch(t, pcBlocked) + d.checkMatch(t, pcBlocked, setts) require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked)) - d.checkMatch(t, "www."+pcBlocked) - d.checkMatchEmpty(t, "www.yandex.ru") - d.checkMatchEmpty(t, "yandex.ru") - d.checkMatchEmpty(t, "api.jquery.com") + d.checkMatch(t, "www."+pcBlocked, setts) + d.checkMatchEmpty(t, "www.yandex.ru", setts) + d.checkMatchEmpty(t, "yandex.ru", setts) + d.checkMatchEmpty(t, "api.jquery.com", setts) // Test cached result. d.parentalServer = "127.0.0.1" - d.checkMatch(t, pcBlocked) - d.checkMatchEmpty(t, "yandex.ru") + d.checkMatch(t, pcBlocked, setts) + d.checkMatchEmpty(t, "yandex.ru", setts) } // Filtering. @@ -679,10 +680,10 @@ func TestMatching(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) { filters := []Filter{{ID: 0, Data: []byte(tc.rules)}} - d := newForTest(t, nil, filters) + d, setts := newForTest(t, nil, filters) t.Cleanup(d.Close) - res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts) + res, err := d.CheckHost(tc.host, tc.wantDNSType, setts) require.NoError(t, err) assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered) @@ -705,7 +706,7 @@ func TestWhitelist(t *testing.T) { whiteFilters := []Filter{{ ID: 0, Data: []byte(whiteRules), }} - d := newForTest(t, nil, filters) + d, setts := newForTest(t, nil, filters) err := d.SetFilters(filters, whiteFilters, false) require.NoError(t, err) @@ -713,7 +714,7 @@ func TestWhitelist(t *testing.T) { t.Cleanup(d.Close) // Matched by white filter. - res, err := d.CheckHost("host1", dns.TypeA, &setts) + res, err := d.CheckHost("host1", dns.TypeA, setts) require.NoError(t, err) assert.False(t, res.IsFiltered) @@ -724,7 +725,7 @@ func TestWhitelist(t *testing.T) { assert.Equal(t, "||host1^", res.Rules[0].Text) // Not matched by white filter, but matched by block filter. - res, err = d.CheckHost("host2", dns.TypeA, &setts) + res, err = d.CheckHost("host2", dns.TypeA, setts) require.NoError(t, err) assert.True(t, res.IsFiltered) @@ -750,7 +751,7 @@ func applyClientSettings(setts *Settings) { } func TestClientSettings(t *testing.T) { - d := newForTest(t, + d, setts := newForTest(t, &Config{ ParentalEnabled: true, SafeBrowsingEnabled: false, @@ -796,7 +797,7 @@ func TestClientSettings(t *testing.T) { return func(t *testing.T) { t.Helper() - r, err := d.CheckHost(tc.host, dns.TypeA, &setts) + r, err := d.CheckHost(tc.host, dns.TypeA, setts) require.NoError(t, err) if before { @@ -814,7 +815,7 @@ func TestClientSettings(t *testing.T) { t.Run(tc.name, makeTester(tc, tc.before)) } - applyClientSettings(&setts) + applyClientSettings(setts) for _, tc := range testCases { t.Run(tc.name, makeTester(tc, !tc.before)) @@ -824,13 +825,13 @@ func TestClientSettings(t *testing.T) { // Benchmarks. func BenchmarkSafeBrowsing(b *testing.B) { - d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) + d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) b.Cleanup(d.Close) d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) for n := 0; n < b.N; n++ { - res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts) + res, err := d.CheckHost(sbBlocked, dns.TypeA, setts) require.NoError(b, err) assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked) @@ -838,14 +839,14 @@ func BenchmarkSafeBrowsing(b *testing.B) { } func BenchmarkSafeBrowsingParallel(b *testing.B) { - d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) + d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) b.Cleanup(d.Close) d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) b.RunParallel(func(pb *testing.PB) { for pb.Next() { - res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts) + res, err := d.CheckHost(sbBlocked, dns.TypeA, setts) require.NoError(b, err) assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked) @@ -854,7 +855,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) { } func BenchmarkSafeSearch(b *testing.B) { - d := newForTest(b, &Config{SafeSearchEnabled: true}, nil) + d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil) b.Cleanup(d.Close) for n := 0; n < b.N; n++ { val, ok := d.SafeSearchDomain("www.google.com") @@ -865,7 +866,7 @@ func BenchmarkSafeSearch(b *testing.B) { } func BenchmarkSafeSearchParallel(b *testing.B) { - d := newForTest(b, &Config{SafeSearchEnabled: true}, nil) + d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil) b.Cleanup(d.Close) b.RunParallel(func(pb *testing.PB) { for pb.Next() { diff --git a/internal/filtering/rewrites.go b/internal/filtering/rewrites.go index c1557158..8f0d5ebf 100644 --- a/internal/filtering/rewrites.go +++ b/internal/filtering/rewrites.go @@ -133,34 +133,31 @@ func matchDomainWildcard(host, wildcard string) (ok bool) { // 1. A and AAAA > CNAME // 2. wildcard > exact // 3. lower level wildcard > higher level wildcard +// +// TODO(a.garipov): Replace with slices.Sort. type rewritesSorted []*LegacyRewrite -// Len implements the sort.Interface interface for legacyRewritesSorted. +// Len implements the sort.Interface interface for rewritesSorted. func (a rewritesSorted) Len() (l int) { return len(a) } -// Swap implements the sort.Interface interface for legacyRewritesSorted. +// Swap implements the sort.Interface interface for rewritesSorted. func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -// Less implements the sort.Interface interface for legacyRewritesSorted. +// Less implements the sort.Interface interface for rewritesSorted. func (a rewritesSorted) Less(i, j int) (less bool) { - if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME { + ith, jth := a[i], a[j] + if ith.Type == dns.TypeCNAME && jth.Type != dns.TypeCNAME { return true - } else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME { + } else if ith.Type != dns.TypeCNAME && jth.Type == dns.TypeCNAME { return false } - if isWildcard(a[i].Domain) { - if !isWildcard(a[j].Domain) { - return false - } - } else { - if isWildcard(a[j].Domain) { - return true - } + if iw, jw := isWildcard(ith.Domain), isWildcard(jth.Domain); iw != jw { + return jw } - // Both are wildcards. - return len(a[i].Domain) > len(a[j].Domain) + // Both are either wildcards or not. + return len(ith.Domain) > len(jth.Domain) } // prepareRewrites normalizes and validates all legacy DNS rewrites. @@ -313,9 +310,3 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) d.Config.ConfigModified() } - -func (d *DNSFilter) registerRewritesHandlers() { - d.Config.HTTPRegister(http.MethodGet, "/control/rewrite/list", d.handleRewriteList) - d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd) - d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete) -} diff --git a/internal/filtering/rewrites_test.go b/internal/filtering/rewrites_test.go index 5c3de110..17caa167 100644 --- a/internal/filtering/rewrites_test.go +++ b/internal/filtering/rewrites_test.go @@ -12,7 +12,7 @@ import ( // TODO(e.burkov): All the tests in this file may and should me merged together. func TestRewrites(t *testing.T) { - d := newForTest(t, nil, nil) + d, _ := newForTest(t, nil, nil) t.Cleanup(d.Close) d.Rewrites = []*LegacyRewrite{{ @@ -188,7 +188,7 @@ func TestRewrites(t *testing.T) { } func TestRewritesLevels(t *testing.T) { - d := newForTest(t, nil, nil) + d, _ := newForTest(t, nil, nil) t.Cleanup(d.Close) // Exact host, wildcard L2, wildcard L3. d.Rewrites = []*LegacyRewrite{{ @@ -235,7 +235,7 @@ func TestRewritesLevels(t *testing.T) { } func TestRewritesExceptionCNAME(t *testing.T) { - d := newForTest(t, nil, nil) + d, _ := newForTest(t, nil, nil) t.Cleanup(d.Close) // Wildcard and exception for a sub-domain. d.Rewrites = []*LegacyRewrite{{ @@ -286,7 +286,7 @@ func TestRewritesExceptionCNAME(t *testing.T) { } func TestRewritesExceptionIP(t *testing.T) { - d := newForTest(t, nil, nil) + d, _ := newForTest(t, nil, nil) t.Cleanup(d.Close) // Exception for AAAA record. d.Rewrites = []*LegacyRewrite{{ diff --git a/internal/filtering/safebrowsing.go b/internal/filtering/safebrowsing.go index 9d1d0fa4..fe844977 100644 --- a/internal/filtering/safebrowsing.go +++ b/internal/filtering/safebrowsing.go @@ -415,17 +415,3 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) } } - -func (d *DNSFilter) registerSecurityHandlers() { - d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable) - d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable) - d.Config.HTTPRegister(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus) - - d.Config.HTTPRegister(http.MethodPost, "/control/parental/enable", d.handleParentalEnable) - d.Config.HTTPRegister(http.MethodPost, "/control/parental/disable", d.handleParentalDisable) - d.Config.HTTPRegister(http.MethodGet, "/control/parental/status", d.handleParentalStatus) - - d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable) - d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable) - d.Config.HTTPRegister(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus) -} diff --git a/internal/filtering/safebrowsing_test.go b/internal/filtering/safebrowsing_test.go index f2cc846c..a7abf878 100644 --- a/internal/filtering/safebrowsing_test.go +++ b/internal/filtering/safebrowsing_test.go @@ -107,7 +107,7 @@ func TestSafeBrowsingCache(t *testing.T) { } func TestSBPC_checkErrorUpstream(t *testing.T) { - d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) + d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) ups := aghtest.NewErrorUpstream() @@ -128,7 +128,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) { } func TestSBPC(t *testing.T) { - d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) + d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) const hostname = "example.org" diff --git a/internal/home/config.go b/internal/home/config.go index 47027692..ff597761 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -14,7 +14,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/stats" - "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/dnsproxy/fastip" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" @@ -23,10 +22,9 @@ import ( yaml "gopkg.in/yaml.v3" ) -const ( - dataDir = "data" // data storage - filterDir = "filters" // cache location for downloaded filters, it's under DataDir -) +// dataDir is the name of a directory under the working one to store some +// persistent data. +const dataDir = "data" // logSettings are the logging settings part of the configuration file. // @@ -108,9 +106,16 @@ type configuration struct { DNS dnsConfig `yaml:"dns"` TLS tlsConfigSettings `yaml:"tls"` - Filters []filter `yaml:"filters"` - WhitelistFilters []filter `yaml:"whitelist_filters"` - UserRules []string `yaml:"user_rules"` + // Filters reflects the filters from [filtering.Config]. It's cloned to the + // config used in the filtering module at the startup. Afterwards it's + // cloned from the filtering module back here. + // + // TODO(e.burkov): Move all the filtering configuration fields into the + // only configuration subsection covering the changes with a single + // migration. + Filters []filtering.FilterYAML `yaml:"filters"` + WhitelistFilters []filtering.FilterYAML `yaml:"whitelist_filters"` + UserRules []string `yaml:"user_rules"` DHCP *dhcpd.ServerConfig `yaml:"dhcp"` @@ -145,9 +150,7 @@ type dnsConfig struct { dnsforward.FilteringConfig `yaml:",inline"` - FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists - FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) - DnsfilterConf filtering.Config `yaml:",inline"` + DnsfilterConf *filtering.Config `yaml:",inline"` // UpstreamTimeout is the timeout for querying upstream servers. UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"` @@ -193,15 +196,20 @@ type tlsConfigSettings struct { // // TODO(a.garipov, e.burkov): This global is awful and must be removed. var config = &configuration{ - BindPort: 3000, - BetaBindPort: 0, - BindHost: net.IP{0, 0, 0, 0}, - AuthAttempts: 5, - AuthBlockMin: 15, + BindPort: 3000, + BetaBindPort: 0, + BindHost: net.IP{0, 0, 0, 0}, + AuthAttempts: 5, + AuthBlockMin: 15, + WebSessionTTLHours: 30 * 24, DNS: dnsConfig{ - BindHosts: []net.IP{{0, 0, 0, 0}}, - Port: defaultPortDNS, - StatsInterval: 1, + BindHosts: []net.IP{{0, 0, 0, 0}}, + Port: defaultPortDNS, + StatsInterval: 1, + QueryLogEnabled: true, + QueryLogFileEnabled: true, + QueryLogInterval: timeutil.Duration{Duration: 90 * timeutil.Day}, + QueryLogMemSize: 1000, FilteringConfig: dnsforward.FilteringConfig{ ProtectionEnabled: true, // whether or not use any of filtering features BlockingMode: dnsforward.BlockingModeDefault, @@ -222,18 +230,42 @@ var config = &configuration{ // was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257 MaxGoroutines: 300, }, - FilteringEnabled: true, // whether or not use filter lists - FiltersUpdateIntervalHours: 24, - UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout}, - UsePrivateRDNS: true, + DnsfilterConf: &filtering.Config{ + SafeBrowsingCacheSize: 1 * 1024 * 1024, + SafeSearchCacheSize: 1 * 1024 * 1024, + ParentalCacheSize: 1 * 1024 * 1024, + CacheTime: 30, + FilteringEnabled: true, + FiltersUpdateIntervalHours: 24, + }, + UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout}, + UsePrivateRDNS: true, }, TLS: tlsConfigSettings{ PortHTTPS: defaultPortHTTPS, PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy PortDNSOverQUIC: defaultPortQUIC, }, + Filters: []filtering.FilterYAML{{ + Filter: filtering.Filter{ID: 1}, + Enabled: true, + URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", + Name: "AdGuard DNS filter", + }, { + Filter: filtering.Filter{ID: 2}, + Enabled: false, + URL: "https://adaway.org/hosts.txt", + Name: "AdAway Default Blocklist", + }}, DHCP: &dhcpd.ServerConfig{ LocalDomainName: "lan", + Conf4: dhcpd.V4ServerConf{ + LeaseDuration: dhcpd.DefaultDHCPLeaseTTL, + ICMPTimeout: dhcpd.DefaultDHCPTimeoutICMP, + }, + Conf6: dhcpd.V6ServerConf{ + LeaseDuration: dhcpd.DefaultDHCPLeaseTTL, + }, }, Clients: &clientsConfig{ Sources: &clientSourcesConf{ @@ -255,31 +287,6 @@ var config = &configuration{ SchemaVersion: currentSchemaVersion, } -// initConfig initializes default configuration for the current OS&ARCH -func initConfig() { - config.WebSessionTTLHours = 30 * 24 - - config.DNS.QueryLogEnabled = true - config.DNS.QueryLogFileEnabled = true - config.DNS.QueryLogInterval = timeutil.Duration{Duration: 90 * timeutil.Day} - config.DNS.QueryLogMemSize = 1000 - - config.DNS.CacheSize = 4 * 1024 * 1024 - config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024 - config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024 - config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024 - config.DNS.DnsfilterConf.CacheTime = 30 - config.Filters = defaultFilters() - - config.DHCP.Conf4.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL - config.DHCP.Conf4.ICMPTimeout = dhcpd.DefaultDHCPTimeoutICMP - config.DHCP.Conf6.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL - - if ch := version.Channel(); ch == version.ChannelEdge || ch == version.ChannelDevelopment { - config.BetaBindPort = 3001 - } -} - // getConfigFilename returns path to the current config file func (c *configuration) getConfigFilename() string { configFile, err := filepath.EvalSymlinks(Context.configFilename) @@ -348,8 +355,8 @@ func parseConfig() (err error) { return fmt.Errorf("validating udp ports: %w", err) } - if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) { - config.DNS.FiltersUpdateIntervalHours = 24 + if !filtering.ValidateUpdateIvl(config.DNS.DnsfilterConf.FiltersUpdateIntervalHours) { + config.DNS.DnsfilterConf.FiltersUpdateIntervalHours = 24 } if config.DNS.UpstreamTimeout.Duration == 0 { @@ -418,10 +425,11 @@ func (c *configuration) write() (err error) { config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP } - if Context.dnsFilter != nil { - c := filtering.Config{} - Context.dnsFilter.WriteDiskConfig(&c) - config.DNS.DnsfilterConf = c + if Context.filters != nil { + Context.filters.WriteDiskConfig(config.DNS.DnsfilterConf) + config.Filters = config.DNS.DnsfilterConf.Filters + config.WhitelistFilters = config.DNS.DnsfilterConf.WhitelistFilters + config.UserRules = config.DNS.DnsfilterConf.UserRules } if s := Context.dnsServer; s != nil { diff --git a/internal/home/control.go b/internal/home/control.go index 54d1652a..829063e9 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -291,7 +291,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) { } httpsURL := &url.URL{ - Scheme: schemeHTTPS, + Scheme: aghhttp.SchemeHTTPS, Host: hostPort, Path: r.URL.Path, RawQuery: r.URL.RawQuery, @@ -307,7 +307,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) { // // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin. originURL := &url.URL{ - Scheme: schemeHTTP, + Scheme: aghhttp.SchemeHTTP, Host: r.Host, } w.Header().Set("Access-Control-Allow-Origin", originURL.String()) diff --git a/internal/home/dns.go b/internal/home/dns.go index 88ae8ef2..06c38bcc 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -31,7 +31,10 @@ const ( // Called by other modules when configuration is changed func onConfigModified() { - _ = config.write() + err := config.write() + if err != nil { + log.Error("writing config: %s", err) + } } // initDNSServer creates an instance of the dnsforward.Server @@ -71,11 +74,11 @@ func initDNSServer() (err error) { } Context.queryLog = querylog.New(conf) - filterConf := config.DNS.DnsfilterConf - filterConf.EtcHosts = Context.etcHosts - filterConf.ConfigModified = onConfigModified - filterConf.HTTPRegister = httpRegister - Context.dnsFilter = filtering.New(&filterConf, nil) + Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil) + if err != nil { + // Don't wrap the error, since it's informative enough as is. + return err + } var privateNets netutil.SubnetSet switch len(config.DNS.PrivateNets) { @@ -83,13 +86,10 @@ func initDNSServer() (err error) { // Use an optimized locally-served matcher. privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) case 1: - var n *net.IPNet - n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0]) + privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0]) if err != nil { return fmt.Errorf("preparing the set of private subnets: %w", err) } - - privateNets = n default: var nets []*net.IPNet nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...) @@ -101,15 +101,13 @@ func initDNSServer() (err error) { } p := dnsforward.DNSCreateParams{ - DNSFilter: Context.dnsFilter, + DNSFilter: Context.filters, Stats: Context.stats, QueryLog: Context.queryLog, PrivateNets: privateNets, Anonymizer: anonymizer, LocalDomain: config.DHCP.LocalDomainName, - } - if Context.dhcpServer != nil { - p.DHCPServer = Context.dhcpServer + DHCPServer: Context.dhcpServer, } Context.dnsServer, err = dnsforward.NewServer(p) @@ -143,7 +141,6 @@ func initDNSServer() (err error) { Context.whois = initWHOIS(&Context.clients) } - Context.filters.Init() return nil } @@ -335,9 +332,12 @@ func getDNSEncryption() (de dnsEncryption) { // applyAdditionalFiltering adds additional client information and settings if // the client has them. func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) { - Context.dnsFilter.ApplyBlockedServices(setts, nil, true) + // pref is a prefix for logging messages around the scope. + const pref = "applying filters" - log.Debug("looking up settings for client with ip %s and clientid %q", clientIP, clientID) + Context.filters.ApplyBlockedServices(setts, nil) + + log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID) if clientIP == nil { return @@ -349,16 +349,16 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering if !ok { c, ok = Context.clients.Find(clientIP.String()) if !ok { - log.Debug("client with ip %s and clientid %q not found", clientIP, clientID) + log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID) return } } - log.Debug("using settings for client %q with ip %s and clientid %q", c.Name, clientIP, clientID) + log.Debug("%s: using settings for client %q (%s; %q)", pref, c.Name, clientIP, clientID) if c.UseOwnBlockedServices { - Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false) + Context.filters.ApplyBlockedServices(setts, c.BlockedServices) } setts.ClientName = c.Name @@ -381,7 +381,7 @@ func startDNSServer() error { return fmt.Errorf("unable to start forwarding DNS server: Already running") } - enableFiltersLocked(false) + Context.filters.EnableFilters(false) Context.clients.Start() @@ -390,7 +390,6 @@ func startDNSServer() error { return fmt.Errorf("couldn't start forwarding DNS server: %w", err) } - Context.dnsFilter.Start() Context.filters.Start() Context.stats.Start() Context.queryLog.Start() @@ -449,10 +448,7 @@ func closeDNSServer() { Context.dnsServer = nil } - if Context.dnsFilter != nil { - Context.dnsFilter.Close() - Context.dnsFilter = nil - } + Context.filters.Close() if Context.stats != nil { err := Context.stats.Close() @@ -469,7 +465,5 @@ func closeDNSServer() { Context.queryLog = nil } - Context.filters.Close() - - log.Debug("Closed all DNS modules") + log.Debug("all dns modules are closed") } diff --git a/internal/home/home.go b/internal/home/home.go index 4b200bc1..76f4ac82 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -20,6 +20,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghtls" @@ -33,6 +34,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" + "golang.org/x/exp/slices" "gopkg.in/natefinch/lumberjack.v2" ) @@ -52,10 +54,9 @@ type homeContext struct { dnsServer *dnsforward.Server // DNS module rdns *RDNS // rDNS module whois *WHOIS // WHOIS module - dnsFilter *filtering.DNSFilter // DNS filtering module dhcpServer dhcpd.Interface // DHCP module auth *Auth // HTTP authentication module - filters Filtering // DNS filtering module + filters *filtering.DNSFilter // DNS filtering module web *Web // Web (HTTP, HTTPS) module tls *TLSMod // TLS module // etcHosts is an IP-hostname pairs set taken from system configuration @@ -140,7 +141,12 @@ func setupContext(args options) { checkPermissions() } - initConfig() + switch version.Channel() { + case version.ChannelEdge, version.ChannelDevelopment: + config.BetaBindPort = 3001 + default: + // Go on. + } Context.tlsRoots = LoadSystemRootCAs() Context.transport = &http.Transport{ @@ -265,6 +271,14 @@ func setupHostsContainer() (err error) { } func setupConfig(args options) (err error) { + config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts + config.DNS.DnsfilterConf.ConfigModified = onConfigModified + config.DNS.DnsfilterConf.HTTPRegister = httpRegister + config.DNS.DnsfilterConf.DataDir = Context.getDataDir() + config.DNS.DnsfilterConf.Filters = slices.Clone(config.Filters) + config.DNS.DnsfilterConf.WhitelistFilters = slices.Clone(config.WhitelistFilters) + config.DNS.DnsfilterConf.HTTPClient = Context.client + config.DHCP.WorkDir = Context.workDir config.DHCP.HTTPRegister = httpRegister config.DHCP.ConfigModified = onConfigModified @@ -384,8 +398,6 @@ func fatalOnError(err error) { // run configures and starts AdGuard Home. func run(args options, clientBuildFS fs.FS) { - var err error - // configure config filename initConfigFilename(args) @@ -404,7 +416,7 @@ func run(args options, clientBuildFS fs.FS) { setupContext(args) - err = configureOS(config) + err := configureOS(config) fatalOnError(err) // clients package uses filtering package's static data (filtering.BlockedSvcKnown()), @@ -763,12 +775,12 @@ func printHTTPAddresses(proto string) { } port := config.BindPort - if proto == schemeHTTPS { + if proto == aghhttp.SchemeHTTPS { port = tlsConf.PortHTTPS } // TODO(e.burkov): Inspect and perhaps merge with the previous condition. - if proto == schemeHTTPS && tlsConf.ServerName != "" { + if proto == aghhttp.SchemeHTTPS && tlsConf.ServerName != "" { printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0) return diff --git a/internal/home/mobileconfig.go b/internal/home/mobileconfig.go index 40094a6a..e2f7283f 100644 --- a/internal/home/mobileconfig.go +++ b/internal/home/mobileconfig.go @@ -8,6 +8,7 @@ import ( "net/url" "path" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" @@ -82,7 +83,7 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) { case dnsProtoHTTPS: dspName = fmt.Sprintf("%s DoH", d.ServerName) u := &url.URL{ - Scheme: schemeHTTPS, + Scheme: aghhttp.SchemeHTTPS, Host: d.ServerName, Path: path.Join("/dns-query", clientID), } diff --git a/internal/home/service.go b/internal/home/service.go index 20367718..c670ebe2 100644 --- a/internal/home/service.go +++ b/internal/home/service.go @@ -11,6 +11,7 @@ import ( "syscall" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/golibs/errors" @@ -277,7 +278,7 @@ AdGuard Home is successfully installed and will automatically start on boot. There are a few more things that must be configured before you can use it. Click on the link below and follow the Installation Wizard steps to finish setup. AdGuard Home is now available at the following addresses:`) - printHTTPAddresses(schemeHTTP) + printHTTPAddresses(aghhttp.SchemeHTTP) } } diff --git a/internal/home/upgrade_test.go b/internal/home/upgrade_test.go index a5267032..949dac5f 100644 --- a/internal/home/upgrade_test.go +++ b/internal/home/upgrade_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/stretchr/testify/assert" @@ -160,7 +161,7 @@ func assertEqualExcept(t *testing.T, oldConf, newConf yobj, oldKeys, newKeys []s } func testDiskConf(schemaVersion int) (diskConf yobj) { - filters := []filter{{ + filters := []filtering.FilterYAML{{ URL: "https://filters.adtidy.org/android/filters/111_optimized.txt", Name: "Latvian filter", RulesCount: 100, diff --git a/internal/home/web.go b/internal/home/web.go index 2052df55..5a26de59 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/golibs/errors" @@ -19,12 +20,6 @@ import ( "golang.org/x/net/http2/h2c" ) -// HTTP scheme constants. -const ( - schemeHTTP = "http" - schemeHTTPS = "https" -) - const ( // readTimeout is the maximum duration for reading the entire request, // including the body. @@ -166,7 +161,7 @@ func (web *Web) Start() { // this loop is used as an ability to change listening host and/or port for !web.httpsServer.shutdown { - printHTTPAddresses(schemeHTTP) + printHTTPAddresses(aghhttp.SchemeHTTP) errs := make(chan error, 2) // Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies. @@ -286,7 +281,7 @@ func (web *Web) tlsServerLoop() { WriteTimeout: web.conf.WriteTimeout, } - printHTTPAddresses(schemeHTTPS) + printHTTPAddresses(aghhttp.SchemeHTTPS) err := web.httpsServer.server.ListenAndServeTLS("", "") if err != http.ErrServerClosed { cleanupAlways()