Pull request: 4871 imp filtering
Merge in DNS/adguard-home from 4871-imp-filtering to master
Closes #4871.
Squashed commit of the following:
commit 618e7c558447703c114332708c94ef1b34362cf9
Merge: 41ff8ab7 11e4f091
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Sep 22 19:27:08 2022 +0300
Merge branch 'master' into 4871-imp-filtering
commit 41ff8ab755a87170e7334dedcae00f01dcca238a
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Sep 22 19:26:11 2022 +0300
filtering: imp code, log
commit e4ae1d1788406ffd7ef0fcc6df896a22b0c2db37
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Sep 22 14:11:07 2022 +0300
filtering: move handlers into single func
commit f7a340b4c10980f512ae935a156f02b0133a1627
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Sep 21 19:21:09 2022 +0300
all: imp code
commit e064bf4d3de0283e4bda2aaf5b9822bb8a08f4a6
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Sep 20 20:12:16 2022 +0300
all: imp name
commit e7eda3905762f0821e1be1ac3cf77e0ecbedeff4
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Sep 20 17:51:23 2022 +0300
all: finally get rid of filtering
commit 188550d873e625cc2951583bb3a2eaad036745f5
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Sep 20 17:36:03 2022 +0300
filtering: merge refresh
commit e54ed9c7952b17e66b790c835269b28fbc26f9ca
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Sep 20 17:16:23 2022 +0300
filtering: merge filters
commit 32da31b754a319487d5f9d5e81e607d349b90180
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Sep 20 14:48:13 2022 +0300
filtering: imp docs
commit 43b0cafa7a27bb9b620c2ba50ccdddcf32cfcecc
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Sep 20 14:38:04 2022 +0300
all: imp code
commit 253a2ea6c92815d364546e34d631e406dd604644
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Mon Sep 19 20:43:15 2022 +0300
filtering: rm important flag
commit 1b87f08f946389d410f13412c7e486290d5e752d
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Mon Sep 19 17:05:40 2022 +0300
all: move filtering to the package
commit daa13499f1dd4fe475c4b75769e34f1eb0915bdf
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Mon Sep 19 15:13:55 2022 +0300
all: finish merging
commit d6db75eb2e1f23528e9200ea51507eb793eefa3c
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Fri Sep 16 18:18:14 2022 +0300
all: continue merging
commit 45b4c484deb7198a469aa18d719bb9dbe81e5d22
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Sep 14 15:44:22 2022 +0300
all: merge filtering types
This commit is contained in:
parent
11e4f09165
commit
47c9c946a3
|
@ -9,6 +9,12 @@ import (
|
|||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// HTTP scheme constants.
|
||||
const (
|
||||
SchemeHTTP = "http"
|
||||
SchemeHTTPS = "https"
|
||||
)
|
||||
|
||||
// RegisterFunc is the function that sets the handler to handle the URL for the
|
||||
// method.
|
||||
//
|
||||
|
|
|
@ -67,10 +67,11 @@ func createTestServer(
|
|||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f := filtering.New(filterConf, filters)
|
||||
f, err := filtering.New(filterConf, filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.SetEnabled(true)
|
||||
|
||||
var err error
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: testDHCP,
|
||||
DNSFilter: f,
|
||||
|
@ -774,7 +775,9 @@ func TestBlockedCustomIP(t *testing.T) {
|
|||
Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: testDHCP,
|
||||
DNSFilter: f,
|
||||
|
@ -906,7 +909,9 @@ func TestRewrite(t *testing.T) {
|
|||
Type: dns.TypeCNAME,
|
||||
}},
|
||||
}
|
||||
f := filtering.New(c, nil)
|
||||
f, err := filtering.New(c, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.SetEnabled(true)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
|
@ -1021,19 +1026,14 @@ var testDHCP = &dhcpd.MockInterface{
|
|||
OnWriteDiskConfig: func(c *dhcpd.ServerConfig) { panic("not implemented") },
|
||||
}
|
||||
|
||||
// func (*testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
|
||||
// return []*dhcpd.Lease{{
|
||||
// IP: net.IP{192, 168, 12, 34},
|
||||
// HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
// Hostname: "myhost",
|
||||
// }}
|
||||
// }
|
||||
|
||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
const localDomain = "lan"
|
||||
|
||||
flt, err := filtering.New(&filtering.Config{}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DNSFilter: filtering.New(&filtering.Config{}, nil),
|
||||
DNSFilter: flt,
|
||||
DHCPServer: testDHCP,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
LocalDomain: localDomain,
|
||||
|
@ -1100,9 +1100,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
|||
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
|
||||
})
|
||||
|
||||
flt := filtering.New(&filtering.Config{
|
||||
flt, err := filtering.New(&filtering.Config{
|
||||
EtcHosts: hc,
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
flt.SetEnabled(true)
|
||||
|
||||
var s *Server
|
||||
|
|
|
@ -35,7 +35,8 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
|||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
require.NoError(t, err)
|
||||
f.SetEnabled(true)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
|
|
|
@ -421,31 +421,34 @@ func initBlockedServices() {
|
|||
}
|
||||
|
||||
// BlockedSvcKnown - return TRUE if a blocked service name is known
|
||||
func BlockedSvcKnown(s string) bool {
|
||||
_, ok := serviceRules[s]
|
||||
func BlockedSvcKnown(s string) (ok bool) {
|
||||
_, ok = serviceRules[s]
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// ApplyBlockedServices - set blocked services settings for this DNS request
|
||||
func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string, global bool) {
|
||||
func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string) {
|
||||
setts.ServicesRules = []ServiceEntry{}
|
||||
if global {
|
||||
if list == nil {
|
||||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
list = d.Config.BlockedServices
|
||||
}
|
||||
|
||||
for _, name := range list {
|
||||
rules, ok := serviceRules[name]
|
||||
|
||||
if !ok {
|
||||
log.Error("unknown service name: %s", name)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
s := ServiceEntry{}
|
||||
s.Name = name
|
||||
s.Rules = rules
|
||||
setts.ServicesRules = append(setts.ServicesRules, s)
|
||||
setts.ServicesRules = append(setts.ServicesRules, ServiceEntry{
|
||||
Name: name,
|
||||
Rules: rules,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -490,10 +493,3 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
|
|||
|
||||
d.ConfigModified()
|
||||
}
|
||||
|
||||
// registerBlockedServicesHandlers - register HTTP handlers
|
||||
func (d *DNSFilter) registerBlockedServicesHandlers() {
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package home
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
@ -34,7 +34,7 @@ func validateFilterURL(urlStr string) (err error) {
|
|||
return fmt.Errorf("checking filter url: %w", err)
|
||||
}
|
||||
|
||||
if s := url.Scheme; s != schemeHTTP && s != schemeHTTPS {
|
||||
if s := url.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS {
|
||||
return fmt.Errorf("checking filter url: invalid scheme %q", s)
|
||||
}
|
||||
|
||||
|
@ -47,7 +47,7 @@ type filterAddJSON struct {
|
|||
Whitelist bool `json:"whitelist"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||
fj := filterAddJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&fj)
|
||||
if err != nil {
|
||||
|
@ -65,14 +65,14 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
|||
}
|
||||
|
||||
// Check for duplicates
|
||||
if filterExists(fj.URL) {
|
||||
if d.filterExists(fj.URL) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Set necessary properties
|
||||
filt := filter{
|
||||
filt := FilterYAML{
|
||||
Enabled: true,
|
||||
URL: fj.URL,
|
||||
Name: fj.Name,
|
||||
|
@ -81,7 +81,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
|||
filt.ID = assignUniqueFilterID()
|
||||
|
||||
// Download the filter contents
|
||||
ok, err := f.update(&filt)
|
||||
ok, err := d.update(&filt)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
|
@ -109,14 +109,14 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
|||
|
||||
// URL is assumed valid so append it to filters, update config, write new
|
||||
// file and reload it to engines.
|
||||
if !filterAdd(filt) {
|
||||
if !d.filterAdd(filt) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
d.ConfigModified()
|
||||
d.EnableFilters(true)
|
||||
|
||||
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
|
||||
if err != nil {
|
||||
|
@ -124,7 +124,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
|||
}
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
|
||||
type request struct {
|
||||
URL string `json:"url"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
|
@ -138,23 +138,23 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
|||
return
|
||||
}
|
||||
|
||||
config.Lock()
|
||||
filters := &config.Filters
|
||||
d.filtersMu.Lock()
|
||||
filters := &d.Filters
|
||||
if req.Whitelist {
|
||||
filters = &config.WhitelistFilters
|
||||
filters = &d.WhitelistFilters
|
||||
}
|
||||
|
||||
var deleted filter
|
||||
var newFilters []filter
|
||||
for _, f := range *filters {
|
||||
if f.URL != req.URL {
|
||||
newFilters = append(newFilters, f)
|
||||
var deleted FilterYAML
|
||||
var newFilters []FilterYAML
|
||||
for _, flt := range *filters {
|
||||
if flt.URL != req.URL {
|
||||
newFilters = append(newFilters, flt)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
deleted = f
|
||||
path := f.Path()
|
||||
deleted = flt
|
||||
path := flt.Path(d.DataDir)
|
||||
err = os.Rename(path, path+".old")
|
||||
if err != nil {
|
||||
log.Error("deleting filter %q: %s", path, err)
|
||||
|
@ -162,10 +162,10 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
|||
}
|
||||
|
||||
*filters = newFilters
|
||||
config.Unlock()
|
||||
d.filtersMu.Unlock()
|
||||
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
d.ConfigModified()
|
||||
d.EnableFilters(true)
|
||||
|
||||
// NOTE: The old files "filter.txt.old" aren't deleted. It's not really
|
||||
// necessary, but will require the additional complicated code to run
|
||||
|
@ -191,55 +191,51 @@ type filterURLReq struct {
|
|||
Whitelist bool `json:"whitelist"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
|
||||
fj := filterURLReq{}
|
||||
err := json.NewDecoder(r.Body).Decode(&fj)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if fj.Data == nil {
|
||||
err = errors.Error("data cannot be null")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", errors.Error("data is absent"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = validateFilterURL(fj.Data.URL)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid url: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "invalid url: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
filt := filter{
|
||||
filt := FilterYAML{
|
||||
Enabled: fj.Data.Enabled,
|
||||
Name: fj.Data.Name,
|
||||
URL: fj.Data.URL,
|
||||
}
|
||||
status := f.filterSetProperties(fj.URL, filt, fj.Whitelist)
|
||||
status := d.filterSetProperties(fj.URL, filt, fj.Whitelist)
|
||||
if (status & statusFound) == 0 {
|
||||
http.Error(w, "URL doesn't exist", http.StatusBadRequest)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "URL doesn't exist")
|
||||
|
||||
return
|
||||
}
|
||||
if (status & statusURLExists) != 0 {
|
||||
http.Error(w, "URL already exists", http.StatusBadRequest)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "URL already exists")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
d.ConfigModified()
|
||||
|
||||
restart := (status & statusEnabledChanged) != 0
|
||||
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
|
||||
// download new filter and apply its rules
|
||||
flags := filterRefreshBlocklists
|
||||
if fj.Whitelist {
|
||||
flags = filterRefreshAllowlists
|
||||
}
|
||||
nUpdated, _ := f.refreshFilters(flags, true)
|
||||
// download new filter and apply its rules.
|
||||
nUpdated := d.refreshFilters(!fj.Whitelist, fj.Whitelist, false)
|
||||
// if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
|
||||
// if not - we restart the filtering ourselves
|
||||
restart = false
|
||||
|
@ -249,11 +245,11 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
|||
}
|
||||
|
||||
if restart {
|
||||
enableFilters(true)
|
||||
d.EnableFilters(true)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
|
||||
// This use of ReadAll is safe, because request's body is now limited.
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
|
@ -262,12 +258,12 @@ func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Reque
|
|||
return
|
||||
}
|
||||
|
||||
config.UserRules = strings.Split(string(body), "\n")
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
d.UserRules = strings.Split(string(body), "\n")
|
||||
d.ConfigModified()
|
||||
d.EnableFilters(true)
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
type Req struct {
|
||||
White bool `json:"whitelist"`
|
||||
}
|
||||
|
@ -285,35 +281,27 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
|||
return
|
||||
}
|
||||
|
||||
flags := filterRefreshBlocklists
|
||||
if req.White {
|
||||
flags = filterRefreshAllowlists
|
||||
}
|
||||
func() {
|
||||
// Temporarily unlock the Context.controlLock because the
|
||||
// f.refreshFilters waits for it to be unlocked but it's
|
||||
// actually locked in ensure wrapper.
|
||||
//
|
||||
// TODO(e.burkov): Reconsider this messy syncing process.
|
||||
Context.controlLock.Unlock()
|
||||
defer Context.controlLock.Lock()
|
||||
|
||||
resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false)
|
||||
}()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
var ok bool
|
||||
resp.Updated, _, ok = d.tryRefreshFilters(!req.White, req.White, true)
|
||||
if !ok {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"filters update procedure is already running",
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
js, err := json.Marshal(resp)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(js)
|
||||
}
|
||||
|
||||
type filterJSON struct {
|
||||
|
@ -333,7 +321,7 @@ type filteringConfig struct {
|
|||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
func filterToJSON(f filter) filterJSON {
|
||||
func filterToJSON(f FilterYAML) filterJSON {
|
||||
fj := filterJSON{
|
||||
ID: f.ID,
|
||||
Enabled: f.Enabled,
|
||||
|
@ -350,21 +338,21 @@ func filterToJSON(f filter) filterJSON {
|
|||
}
|
||||
|
||||
// Get filtering configuration
|
||||
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
|
||||
resp := filteringConfig{}
|
||||
config.RLock()
|
||||
resp.Enabled = config.DNS.FilteringEnabled
|
||||
resp.Interval = config.DNS.FiltersUpdateIntervalHours
|
||||
for _, f := range config.Filters {
|
||||
d.filtersMu.RLock()
|
||||
resp.Enabled = d.FilteringEnabled
|
||||
resp.Interval = d.FiltersUpdateIntervalHours
|
||||
for _, f := range d.Filters {
|
||||
fj := filterToJSON(f)
|
||||
resp.Filters = append(resp.Filters, fj)
|
||||
}
|
||||
for _, f := range config.WhitelistFilters {
|
||||
for _, f := range d.WhitelistFilters {
|
||||
fj := filterToJSON(f)
|
||||
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
|
||||
}
|
||||
resp.UserRules = config.UserRules
|
||||
config.RUnlock()
|
||||
resp.UserRules = d.UserRules
|
||||
d.filtersMu.RUnlock()
|
||||
|
||||
jsonVal, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
|
@ -380,7 +368,7 @@ func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request
|
|||
}
|
||||
|
||||
// Set filtering configuration
|
||||
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
|
||||
req := filteringConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
|
@ -389,22 +377,22 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request
|
|||
return
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(req.Interval) {
|
||||
if !ValidateUpdateIvl(req.Interval) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func() {
|
||||
config.Lock()
|
||||
defer config.Unlock()
|
||||
d.filtersMu.Lock()
|
||||
defer d.filtersMu.Unlock()
|
||||
|
||||
config.DNS.FilteringEnabled = req.Enabled
|
||||
config.DNS.FiltersUpdateIntervalHours = req.Interval
|
||||
d.FilteringEnabled = req.Enabled
|
||||
d.FiltersUpdateIntervalHours = req.Interval
|
||||
}()
|
||||
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
d.ConfigModified()
|
||||
d.EnableFilters(true)
|
||||
}
|
||||
|
||||
type checkHostRespRule struct {
|
||||
|
@ -435,15 +423,15 @@ type checkHostResp struct {
|
|||
FilterID int64 `json:"filter_id"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
host := q.Get("name")
|
||||
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.URL.Query().Get("name")
|
||||
|
||||
setts := Context.dnsFilter.GetConfig()
|
||||
setts := d.GetConfig()
|
||||
setts.FilteringEnabled = true
|
||||
setts.ProtectionEnabled = true
|
||||
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
|
||||
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
|
||||
|
||||
d.ApplyBlockedServices(&setts, nil)
|
||||
result, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
|
@ -457,18 +445,20 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
resp := checkHostResp{}
|
||||
resp.Reason = result.Reason.String()
|
||||
resp.SvcName = result.ServiceName
|
||||
resp.CanonName = result.CanonName
|
||||
resp.IPList = result.IPList
|
||||
rulesLen := len(result.Rules)
|
||||
resp := checkHostResp{
|
||||
Reason: result.Reason.String(),
|
||||
SvcName: result.ServiceName,
|
||||
CanonName: result.CanonName,
|
||||
IPList: result.IPList,
|
||||
Rules: make([]*checkHostRespRule, len(result.Rules)),
|
||||
}
|
||||
|
||||
if len(result.Rules) > 0 {
|
||||
if rulesLen > 0 {
|
||||
resp.FilterID = result.Rules[0].FilterListID
|
||||
resp.Rule = result.Rules[0].Text
|
||||
}
|
||||
|
||||
resp.Rules = make([]*checkHostRespRule, len(result.Rules))
|
||||
for i, r := range result.Rules {
|
||||
resp.Rules[i] = &checkHostRespRule{
|
||||
FilterListID: r.FilterListID,
|
||||
|
@ -476,28 +466,51 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
js, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(js)
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding response: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterFilteringHandlers - register handlers
|
||||
func (f *Filtering) RegisterFilteringHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/filtering/status", f.handleFilteringStatus)
|
||||
httpRegister(http.MethodPost, "/control/filtering/config", f.handleFilteringConfig)
|
||||
httpRegister(http.MethodPost, "/control/filtering/add_url", f.handleFilteringAddURL)
|
||||
httpRegister(http.MethodPost, "/control/filtering/remove_url", f.handleFilteringRemoveURL)
|
||||
httpRegister(http.MethodPost, "/control/filtering/set_url", f.handleFilteringSetURL)
|
||||
httpRegister(http.MethodPost, "/control/filtering/refresh", f.handleFilteringRefresh)
|
||||
httpRegister(http.MethodPost, "/control/filtering/set_rules", f.handleFilteringSetRules)
|
||||
httpRegister(http.MethodGet, "/control/filtering/check_host", f.handleCheckHost)
|
||||
func (d *DNSFilter) RegisterFilteringHandlers() {
|
||||
registerHTTP := d.HTTPRegister
|
||||
if registerHTTP == nil {
|
||||
return
|
||||
}
|
||||
|
||||
registerHTTP(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
|
||||
registerHTTP(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
|
||||
registerHTTP(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
|
||||
|
||||
registerHTTP(http.MethodPost, "/control/parental/enable", d.handleParentalEnable)
|
||||
registerHTTP(http.MethodPost, "/control/parental/disable", d.handleParentalDisable)
|
||||
registerHTTP(http.MethodGet, "/control/parental/status", d.handleParentalStatus)
|
||||
|
||||
registerHTTP(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
|
||||
registerHTTP(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
|
||||
registerHTTP(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
|
||||
|
||||
registerHTTP(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
|
||||
registerHTTP(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
|
||||
registerHTTP(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
|
||||
|
||||
registerHTTP(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices)
|
||||
registerHTTP(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
|
||||
registerHTTP(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
|
||||
|
||||
registerHTTP(http.MethodGet, "/control/filtering/status", d.handleFilteringStatus)
|
||||
registerHTTP(http.MethodPost, "/control/filtering/config", d.handleFilteringConfig)
|
||||
registerHTTP(http.MethodPost, "/control/filtering/add_url", d.handleFilteringAddURL)
|
||||
registerHTTP(http.MethodPost, "/control/filtering/remove_url", d.handleFilteringRemoveURL)
|
||||
registerHTTP(http.MethodPost, "/control/filtering/set_url", d.handleFilteringSetURL)
|
||||
registerHTTP(http.MethodPost, "/control/filtering/refresh", d.handleFilteringRefresh)
|
||||
registerHTTP(http.MethodPost, "/control/filtering/set_rules", d.handleFilteringSetRules)
|
||||
registerHTTP(http.MethodGet, "/control/filtering/check_host", d.handleCheckHost)
|
||||
}
|
||||
|
||||
func checkFiltersUpdateIntervalHours(i uint32) bool {
|
||||
// ValidateUpdateIvl returns false if i is not a valid filters update interval.
|
||||
func ValidateUpdateIvl(i uint32) bool {
|
||||
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
|
||||
}
|
|
@ -49,7 +49,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
|||
|1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot.
|
||||
`
|
||||
|
||||
f := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
|
||||
f, _ := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
|
||||
setts := &Settings{
|
||||
FilteringEnabled: true,
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package home
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
|
@ -8,63 +8,29 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
var nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
|
||||
// filterDir is the subdirectory of a data directory to store downloaded
|
||||
// filters.
|
||||
const filterDir = "filters"
|
||||
|
||||
// Filtering - module object
|
||||
type Filtering struct {
|
||||
// conf FilteringConf
|
||||
refreshStatus uint32 // 0:none; 1:in progress
|
||||
refreshLock sync.Mutex
|
||||
filterTitleRegexp *regexp.Regexp
|
||||
}
|
||||
// nextFilterID is a way to seed a unique ID generation.
|
||||
//
|
||||
// TODO(e.burkov): Use more deterministic approach.
|
||||
var nextFilterID = time.Now().Unix()
|
||||
|
||||
// Init - initialize the module
|
||||
func (f *Filtering) Init() {
|
||||
f.filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
|
||||
_ = os.MkdirAll(filepath.Join(Context.getDataDir(), filterDir), 0o755)
|
||||
f.loadFilters(config.Filters)
|
||||
f.loadFilters(config.WhitelistFilters)
|
||||
deduplicateFilters()
|
||||
updateUniqueFilterID(config.Filters)
|
||||
updateUniqueFilterID(config.WhitelistFilters)
|
||||
}
|
||||
|
||||
// Start - start the module
|
||||
func (f *Filtering) Start() {
|
||||
f.RegisterFilteringHandlers()
|
||||
|
||||
// Here we should start updating filters,
|
||||
// but currently we can't wake up the periodic task to do so.
|
||||
// So for now we just start this periodic task from here.
|
||||
go f.periodicallyRefreshFilters()
|
||||
}
|
||||
|
||||
// Close - close the module
|
||||
func (f *Filtering) Close() {
|
||||
}
|
||||
|
||||
func defaultFilters() []filter {
|
||||
return []filter{
|
||||
{Filter: filtering.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard DNS filter"},
|
||||
{Filter: filtering.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway Default Blocklist"},
|
||||
}
|
||||
}
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type filter struct {
|
||||
// FilterYAML respresents a filter list in the configuration file.
|
||||
//
|
||||
// TODO(e.burkov): Investigate if the field oredering is important.
|
||||
type FilterYAML struct {
|
||||
Enabled bool
|
||||
URL string // URL or a file path
|
||||
Name string `yaml:"name"`
|
||||
|
@ -73,44 +39,58 @@ type filter struct {
|
|||
checksum uint32 // checksum of the file data
|
||||
white bool
|
||||
|
||||
filtering.Filter `yaml:",inline"`
|
||||
Filter `yaml:",inline"`
|
||||
}
|
||||
|
||||
// Clear filter rules
|
||||
func (filter *FilterYAML) unload() {
|
||||
filter.RulesCount = 0
|
||||
filter.checksum = 0
|
||||
}
|
||||
|
||||
// Path to the filter contents
|
||||
func (filter *FilterYAML) Path(dataDir string) string {
|
||||
return filepath.Join(dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
|
||||
}
|
||||
|
||||
const (
|
||||
statusFound = 1
|
||||
statusEnabledChanged = 2
|
||||
statusURLChanged = 4
|
||||
statusURLExists = 8
|
||||
statusUpdateRequired = 0x10
|
||||
statusFound = 1 << iota
|
||||
statusEnabledChanged
|
||||
statusURLChanged
|
||||
statusURLExists
|
||||
statusUpdateRequired
|
||||
)
|
||||
|
||||
// Update properties for a filter specified by its URL
|
||||
// Return status* flags.
|
||||
func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool) int {
|
||||
func (d *DNSFilter) filterSetProperties(url string, newf FilterYAML, whitelist bool) int {
|
||||
r := 0
|
||||
config.Lock()
|
||||
defer config.Unlock()
|
||||
d.filtersMu.Lock()
|
||||
defer d.filtersMu.Unlock()
|
||||
|
||||
filters := &config.Filters
|
||||
filters := d.Filters
|
||||
if whitelist {
|
||||
filters = &config.WhitelistFilters
|
||||
filters = d.WhitelistFilters
|
||||
}
|
||||
|
||||
for i := range *filters {
|
||||
filt := &(*filters)[i]
|
||||
if filt.URL != url {
|
||||
continue
|
||||
i := slices.IndexFunc(filters, func(filt FilterYAML) bool {
|
||||
return filt.URL == url
|
||||
})
|
||||
if i == -1 {
|
||||
return 0
|
||||
}
|
||||
|
||||
log.Debug("filter: set properties: %s: {%s %s %v}",
|
||||
filt.URL, newf.Name, newf.URL, newf.Enabled)
|
||||
filt := &filters[i]
|
||||
|
||||
log.Debug("filter: set properties: %s: {%s %s %v}", filt.URL, newf.Name, newf.URL, newf.Enabled)
|
||||
filt.Name = newf.Name
|
||||
|
||||
if filt.URL != newf.URL {
|
||||
r |= statusURLChanged | statusUpdateRequired
|
||||
if filterExistsNoLock(newf.URL) {
|
||||
if d.filterExistsNoLock(newf.URL) {
|
||||
return statusURLExists
|
||||
}
|
||||
|
||||
filt.URL = newf.URL
|
||||
filt.unload()
|
||||
filt.LastUpdated = time.Time{}
|
||||
|
@ -123,10 +103,13 @@ func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool)
|
|||
filt.Enabled = newf.Enabled
|
||||
if filt.Enabled {
|
||||
if (r & statusURLChanged) == 0 {
|
||||
e := f.load(filt)
|
||||
if e != nil {
|
||||
// This isn't a fatal error,
|
||||
// because it may occur when someone removes the file from disk.
|
||||
err := d.load(filt)
|
||||
if err != nil {
|
||||
// TODO(e.burkov): It seems the error is only returned when
|
||||
// the file exists and couldn't be open. Investigate and
|
||||
// improve.
|
||||
log.Error("loading filter %d: %s", filt.ID, err)
|
||||
|
||||
filt.LastUpdated = time.Time{}
|
||||
filt.checksum = 0
|
||||
filt.RulesCount = 0
|
||||
|
@ -139,25 +122,25 @@ func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool)
|
|||
}
|
||||
|
||||
return r | statusFound
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Return TRUE if a filter with this URL exists
|
||||
func filterExists(url string) bool {
|
||||
config.RLock()
|
||||
r := filterExistsNoLock(url)
|
||||
config.RUnlock()
|
||||
func (d *DNSFilter) filterExists(url string) bool {
|
||||
d.filtersMu.RLock()
|
||||
defer d.filtersMu.RUnlock()
|
||||
|
||||
r := d.filterExistsNoLock(url)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func filterExistsNoLock(url string) bool {
|
||||
for _, f := range config.Filters {
|
||||
func (d *DNSFilter) filterExistsNoLock(url string) bool {
|
||||
for _, f := range d.Filters {
|
||||
if f.URL == url {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, f := range config.WhitelistFilters {
|
||||
for _, f := range d.WhitelistFilters {
|
||||
if f.URL == url {
|
||||
return true
|
||||
}
|
||||
|
@ -167,26 +150,26 @@ func filterExistsNoLock(url string) bool {
|
|||
|
||||
// Add a filter
|
||||
// Return FALSE if a filter with this URL exists
|
||||
func filterAdd(f filter) bool {
|
||||
config.Lock()
|
||||
defer config.Unlock()
|
||||
func (d *DNSFilter) filterAdd(flt FilterYAML) bool {
|
||||
d.filtersMu.Lock()
|
||||
defer d.filtersMu.Unlock()
|
||||
|
||||
// Check for duplicates
|
||||
if filterExistsNoLock(f.URL) {
|
||||
if d.filterExistsNoLock(flt.URL) {
|
||||
return false
|
||||
}
|
||||
|
||||
if f.white {
|
||||
config.WhitelistFilters = append(config.WhitelistFilters, f)
|
||||
if flt.white {
|
||||
d.WhitelistFilters = append(d.WhitelistFilters, flt)
|
||||
} else {
|
||||
config.Filters = append(config.Filters, f)
|
||||
d.Filters = append(d.Filters, flt)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Load filters from the disk
|
||||
// And if any filter has zero ID, assign a new one
|
||||
func (f *Filtering) loadFilters(array []filter) {
|
||||
func (d *DNSFilter) loadFilters(array []FilterYAML) {
|
||||
for i := range array {
|
||||
filter := &array[i] // otherwise we're operating on a copy
|
||||
if filter.ID == 0 {
|
||||
|
@ -198,32 +181,30 @@ func (f *Filtering) loadFilters(array []filter) {
|
|||
continue
|
||||
}
|
||||
|
||||
err := f.load(filter)
|
||||
err := d.load(filter)
|
||||
if err != nil {
|
||||
log.Error("Couldn't load filter %d contents due to %s", filter.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func deduplicateFilters() {
|
||||
// Deduplicate filters
|
||||
i := 0 // output index, used for deletion later
|
||||
urls := map[string]bool{}
|
||||
for _, filter := range config.Filters {
|
||||
if _, ok := urls[filter.URL]; !ok {
|
||||
// we didn't see it before, keep it
|
||||
urls[filter.URL] = true // remember the URL
|
||||
config.Filters[i] = filter
|
||||
i++
|
||||
func deduplicateFilters(filters []FilterYAML) (deduplicated []FilterYAML) {
|
||||
urls := stringutil.NewSet()
|
||||
lastIdx := 0
|
||||
|
||||
for _, filter := range filters {
|
||||
if !urls.Has(filter.URL) {
|
||||
urls.Add(filter.URL)
|
||||
filters[lastIdx] = filter
|
||||
lastIdx++
|
||||
}
|
||||
}
|
||||
|
||||
// all entries we want to keep are at front, delete the rest
|
||||
config.Filters = config.Filters[:i]
|
||||
return filters[:lastIdx]
|
||||
}
|
||||
|
||||
// Set the next filter ID to max(filter.ID) + 1
|
||||
func updateUniqueFilterID(filters []filter) {
|
||||
func updateUniqueFilterID(filters []FilterYAML) {
|
||||
for _, filter := range filters {
|
||||
if nextFilterID < filter.ID {
|
||||
nextFilterID = filter.ID + 1
|
||||
|
@ -238,22 +219,19 @@ func assignUniqueFilterID() int64 {
|
|||
}
|
||||
|
||||
// Sets up a timer that will be checking for filters updates periodically
|
||||
func (f *Filtering) periodicallyRefreshFilters() {
|
||||
func (d *DNSFilter) periodicallyRefreshFilters() {
|
||||
const maxInterval = 1 * 60 * 60
|
||||
intval := 5 // use a dynamically increasing time interval
|
||||
for {
|
||||
isNetworkErr := false
|
||||
if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) {
|
||||
f.refreshLock.Lock()
|
||||
_, isNetworkErr = f.refreshFiltersIfNecessary(filterRefreshBlocklists | filterRefreshAllowlists)
|
||||
f.refreshLock.Unlock()
|
||||
f.refreshStatus = 0
|
||||
if !isNetworkErr {
|
||||
isNetErr, ok := false, false
|
||||
if d.FiltersUpdateIntervalHours != 0 {
|
||||
_, isNetErr, ok = d.tryRefreshFilters(true, true, false)
|
||||
if ok && !isNetErr {
|
||||
intval = maxInterval
|
||||
}
|
||||
}
|
||||
|
||||
if isNetworkErr {
|
||||
if isNetErr {
|
||||
intval *= 2
|
||||
if intval > maxInterval {
|
||||
intval = maxInterval
|
||||
|
@ -264,51 +242,73 @@ func (f *Filtering) periodicallyRefreshFilters() {
|
|||
}
|
||||
}
|
||||
|
||||
// Refresh filters
|
||||
// flags: filterRefresh*
|
||||
// important:
|
||||
// tryRefreshFilters is like [refreshFilters], but backs down if the update is
|
||||
// already going on.
|
||||
//
|
||||
// TRUE: ignore the fact that we're currently updating the filters
|
||||
func (f *Filtering) refreshFilters(flags int, important bool) (int, error) {
|
||||
set := atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1)
|
||||
if !important && !set {
|
||||
return 0, fmt.Errorf("filters update procedure is already running")
|
||||
// TODO(e.burkov): Get rid of the concurrency pattern which requires the
|
||||
// sync.Mutex.TryLock.
|
||||
func (d *DNSFilter) tryRefreshFilters(block, allow, force bool) (updated int, isNetworkErr, ok bool) {
|
||||
if ok = d.refreshLock.TryLock(); !ok {
|
||||
return 0, false, ok
|
||||
}
|
||||
defer d.refreshLock.Unlock()
|
||||
|
||||
f.refreshLock.Lock()
|
||||
nUpdated, _ := f.refreshFiltersIfNecessary(flags)
|
||||
f.refreshLock.Unlock()
|
||||
f.refreshStatus = 0
|
||||
return nUpdated, nil
|
||||
updated, isNetworkErr = d.refreshFiltersIntl(block, allow, force)
|
||||
|
||||
return updated, isNetworkErr, ok
|
||||
}
|
||||
|
||||
func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) {
|
||||
var updateFilters []filter
|
||||
// refreshFilters updates the lists and returns the number of updated ones.
|
||||
// It's safe for concurrent use, but blocks at least until the previous
|
||||
// refreshing is finished.
|
||||
func (d *DNSFilter) refreshFilters(block, allow, force bool) (updated int) {
|
||||
d.refreshLock.Lock()
|
||||
defer d.refreshLock.Unlock()
|
||||
|
||||
updated, _ = d.refreshFiltersIntl(block, allow, force)
|
||||
|
||||
return updated
|
||||
}
|
||||
|
||||
// listsToUpdate returns the slice of filter lists that could be updated.
|
||||
func (d *DNSFilter) listsToUpdate(filters *[]FilterYAML, force bool) (toUpd []FilterYAML) {
|
||||
now := time.Now()
|
||||
|
||||
d.filtersMu.RLock()
|
||||
defer d.filtersMu.RUnlock()
|
||||
|
||||
for i := range *filters {
|
||||
flt := &(*filters)[i] // otherwise we will be operating on a copy
|
||||
log.Debug("checking list at index %d: %v", i, flt)
|
||||
|
||||
if !flt.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if !force {
|
||||
exp := flt.LastUpdated.Add(time.Duration(d.FiltersUpdateIntervalHours) * time.Hour)
|
||||
if now.Before(exp) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
toUpd = append(toUpd, FilterYAML{
|
||||
Filter: Filter{
|
||||
ID: flt.ID,
|
||||
},
|
||||
URL: flt.URL,
|
||||
Name: flt.Name,
|
||||
checksum: flt.checksum,
|
||||
})
|
||||
}
|
||||
|
||||
return toUpd
|
||||
}
|
||||
|
||||
func (d *DNSFilter) refreshFiltersArray(filters *[]FilterYAML, force bool) (int, []FilterYAML, []bool, bool) {
|
||||
var updateFlags []bool // 'true' if filter data has changed
|
||||
|
||||
now := time.Now()
|
||||
config.RLock()
|
||||
for i := range *filters {
|
||||
f := &(*filters)[i] // otherwise we will be operating on a copy
|
||||
|
||||
if !f.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
expireTime := f.LastUpdated.Unix() + int64(config.DNS.FiltersUpdateIntervalHours)*60*60
|
||||
if !force && expireTime > now.Unix() {
|
||||
continue
|
||||
}
|
||||
|
||||
var uf filter
|
||||
uf.ID = f.ID
|
||||
uf.URL = f.URL
|
||||
uf.Name = f.Name
|
||||
uf.checksum = f.checksum
|
||||
updateFilters = append(updateFilters, uf)
|
||||
}
|
||||
config.RUnlock()
|
||||
|
||||
updateFilters := d.listsToUpdate(filters, force)
|
||||
if len(updateFilters) == 0 {
|
||||
return 0, nil, nil, false
|
||||
}
|
||||
|
@ -316,7 +316,7 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
|
|||
nfail := 0
|
||||
for i := range updateFilters {
|
||||
uf := &updateFilters[i]
|
||||
updated, err := f.update(uf)
|
||||
updated, err := d.update(uf)
|
||||
updateFlags = append(updateFlags, updated)
|
||||
if err != nil {
|
||||
nfail++
|
||||
|
@ -334,7 +334,7 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
|
|||
uf := &updateFilters[i]
|
||||
updated := updateFlags[i]
|
||||
|
||||
config.Lock()
|
||||
d.filtersMu.Lock()
|
||||
for k := range *filters {
|
||||
f := &(*filters)[k]
|
||||
if f.ID != uf.ID || f.URL != uf.URL {
|
||||
|
@ -352,20 +352,14 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
|
|||
f.checksum = uf.checksum
|
||||
updateCount++
|
||||
}
|
||||
config.Unlock()
|
||||
d.filtersMu.Unlock()
|
||||
}
|
||||
|
||||
return updateCount, updateFilters, updateFlags, false
|
||||
}
|
||||
|
||||
const (
|
||||
filterRefreshForce = 1 // ignore last file modification date
|
||||
filterRefreshAllowlists = 2 // update allow-lists
|
||||
filterRefreshBlocklists = 4 // update block-lists
|
||||
)
|
||||
|
||||
// refreshFiltersIfNecessary checks filters and updates them if necessary. If
|
||||
// force is true, it ignores the filter.LastUpdated field value.
|
||||
// refreshFiltersIntl checks filters and updates them if necessary. If force is
|
||||
// true, it ignores the filter.LastUpdated field value.
|
||||
//
|
||||
// Algorithm:
|
||||
//
|
||||
|
@ -378,53 +372,49 @@ const (
|
|||
// that this method works only on Unix systems. On Windows, don't pass
|
||||
// files to filtering, pass the whole data.
|
||||
//
|
||||
// refreshFiltersIfNecessary returns the number of updated filters. It also
|
||||
// returns true if there was a network error and nothing could be updated.
|
||||
// refreshFiltersIntl returns the number of updated filters. It also returns
|
||||
// true if there was a network error and nothing could be updated.
|
||||
//
|
||||
// TODO(a.garipov, e.burkov): What the hell?
|
||||
func (f *Filtering) refreshFiltersIfNecessary(flags int) (int, bool) {
|
||||
log.Debug("Filters: updating...")
|
||||
func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) {
|
||||
log.Debug("filtering: updating...")
|
||||
|
||||
updateCount := 0
|
||||
var updateFilters []filter
|
||||
var updateFlags []bool
|
||||
netError := false
|
||||
netErrorW := false
|
||||
force := false
|
||||
if (flags & filterRefreshForce) != 0 {
|
||||
force = true
|
||||
updNum := 0
|
||||
var lists []FilterYAML
|
||||
var toUpd []bool
|
||||
isNetErr := false
|
||||
|
||||
if block {
|
||||
updNum, lists, toUpd, isNetErr = d.refreshFiltersArray(&d.Filters, force)
|
||||
}
|
||||
if (flags & filterRefreshBlocklists) != 0 {
|
||||
updateCount, updateFilters, updateFlags, netError = f.refreshFiltersArray(&config.Filters, force)
|
||||
if allow {
|
||||
updNumAl, listsAl, toUpdAl, isNetErrAl := d.refreshFiltersArray(&d.WhitelistFilters, force)
|
||||
|
||||
updNum += updNumAl
|
||||
lists = append(lists, listsAl...)
|
||||
toUpd = append(toUpd, toUpdAl...)
|
||||
isNetErr = isNetErr || isNetErrAl
|
||||
}
|
||||
if (flags & filterRefreshAllowlists) != 0 {
|
||||
updateCountW := 0
|
||||
var updateFiltersW []filter
|
||||
var updateFlagsW []bool
|
||||
updateCountW, updateFiltersW, updateFlagsW, netErrorW = f.refreshFiltersArray(&config.WhitelistFilters, force)
|
||||
updateCount += updateCountW
|
||||
updateFilters = append(updateFilters, updateFiltersW...)
|
||||
updateFlags = append(updateFlags, updateFlagsW...)
|
||||
}
|
||||
if netError && netErrorW {
|
||||
if isNetErr {
|
||||
return 0, true
|
||||
}
|
||||
|
||||
if updateCount != 0 {
|
||||
enableFilters(false)
|
||||
if updNum != 0 {
|
||||
d.EnableFilters(false)
|
||||
|
||||
for i := range updateFilters {
|
||||
uf := &updateFilters[i]
|
||||
updated := updateFlags[i]
|
||||
for i := range lists {
|
||||
uf := &lists[i]
|
||||
updated := toUpd[i]
|
||||
if !updated {
|
||||
continue
|
||||
}
|
||||
_ = os.Remove(uf.Path() + ".old")
|
||||
_ = os.Remove(uf.Path(d.DataDir) + ".old")
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Filters: update finished")
|
||||
return updateCount, false
|
||||
log.Debug("filtering: update finished")
|
||||
|
||||
return updNum, false
|
||||
}
|
||||
|
||||
// Allows printable UTF-8 text with CR, LF, TAB characters
|
||||
|
@ -440,7 +430,7 @@ func isPrintableText(data []byte, len int) bool {
|
|||
}
|
||||
|
||||
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
|
||||
func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||
func (d *DNSFilter) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||
rulesCount := 0
|
||||
name := ""
|
||||
seenTitle := false
|
||||
|
@ -455,7 +445,7 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
|
|||
if len(line) == 0 {
|
||||
//
|
||||
} else if line[0] == '!' {
|
||||
m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1)
|
||||
m := d.filterTitleRegexp.FindAllStringSubmatch(line, -1)
|
||||
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
|
||||
name = m[0][1]
|
||||
seenTitle = true
|
||||
|
@ -476,11 +466,11 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
|
|||
}
|
||||
|
||||
// Perform upgrade on a filter and update LastUpdated value
|
||||
func (f *Filtering) update(filter *filter) (bool, error) {
|
||||
b, err := f.updateIntl(filter)
|
||||
func (d *DNSFilter) update(filter *FilterYAML) (bool, error) {
|
||||
b, err := d.updateIntl(filter)
|
||||
filter.LastUpdated = time.Now()
|
||||
if !b {
|
||||
e := os.Chtimes(filter.Path(), filter.LastUpdated, filter.LastUpdated)
|
||||
e := os.Chtimes(filter.Path(d.DataDir), filter.LastUpdated, filter.LastUpdated)
|
||||
if e != nil {
|
||||
log.Error("os.Chtimes(): %v", e)
|
||||
}
|
||||
|
@ -488,7 +478,7 @@ func (f *Filtering) update(filter *filter) (bool, error) {
|
|||
return b, err
|
||||
}
|
||||
|
||||
func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (int, error) {
|
||||
func (d *DNSFilter) read(reader io.Reader, tmpFile *os.File, filter *FilterYAML) (int, error) {
|
||||
htmlTest := true
|
||||
firstChunk := make([]byte, 4*1024)
|
||||
firstChunkLen := 0
|
||||
|
@ -539,20 +529,20 @@ func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (in
|
|||
// finalizeUpdate closes and gets rid of temporary file f with filter's content
|
||||
// according to updated. It also saves new values of flt's name, rules number
|
||||
// and checksum if sucсeeded.
|
||||
func finalizeUpdate(
|
||||
f *os.File,
|
||||
flt *filter,
|
||||
func (d *DNSFilter) finalizeUpdate(
|
||||
file *os.File,
|
||||
flt *FilterYAML,
|
||||
updated bool,
|
||||
name string,
|
||||
rnum int,
|
||||
cs uint32,
|
||||
) (err error) {
|
||||
tmpFileName := f.Name()
|
||||
tmpFileName := file.Name()
|
||||
|
||||
// Close the file before renaming it because it's required on Windows.
|
||||
//
|
||||
// See https://github.com/adguardTeam/adGuardHome/issues/1553.
|
||||
if err = f.Close(); err != nil {
|
||||
if err = file.Close(); err != nil {
|
||||
return fmt.Errorf("closing temporary file: %w", err)
|
||||
}
|
||||
|
||||
|
@ -562,9 +552,9 @@ func finalizeUpdate(
|
|||
return os.Remove(tmpFileName)
|
||||
}
|
||||
|
||||
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path())
|
||||
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path(d.DataDir))
|
||||
|
||||
if err = os.Rename(tmpFileName, flt.Path()); err != nil {
|
||||
if err = os.Rename(tmpFileName, flt.Path(d.DataDir)); err != nil {
|
||||
return errors.WithDeferred(err, os.Remove(tmpFileName))
|
||||
}
|
||||
|
||||
|
@ -578,12 +568,12 @@ func finalizeUpdate(
|
|||
// processUpdate copies filter's content from src to dst and returns the name,
|
||||
// rules number, and checksum for it. It also returns the number of bytes read
|
||||
// from src.
|
||||
func (f *Filtering) processUpdate(
|
||||
func (d *DNSFilter) processUpdate(
|
||||
src io.Reader,
|
||||
dst *os.File,
|
||||
flt *filter,
|
||||
flt *FilterYAML,
|
||||
) (name string, rnum int, cs uint32, n int, err error) {
|
||||
if n, err = f.read(src, dst, flt); err != nil {
|
||||
if n, err = d.read(src, dst, flt); err != nil {
|
||||
return "", 0, 0, 0, err
|
||||
}
|
||||
|
||||
|
@ -591,14 +581,14 @@ func (f *Filtering) processUpdate(
|
|||
return "", 0, 0, 0, err
|
||||
}
|
||||
|
||||
rnum, cs, name = f.parseFilterContents(dst)
|
||||
rnum, cs, name = d.parseFilterContents(dst)
|
||||
|
||||
return name, rnum, cs, n, nil
|
||||
}
|
||||
|
||||
// updateIntl updates the flt rewriting it's actual file. It returns true if
|
||||
// the actual update has been performed.
|
||||
func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
|
||||
func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
log.Tracef("downloading update for filter %d from %s", flt.ID, flt.URL)
|
||||
|
||||
var name string
|
||||
|
@ -606,12 +596,12 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
|
|||
var cs uint32
|
||||
|
||||
var tmpFile *os.File
|
||||
tmpFile, err = os.CreateTemp(filepath.Join(Context.getDataDir(), filterDir), "")
|
||||
tmpFile, err = os.CreateTemp(filepath.Join(d.DataDir, filterDir), "")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
err = errors.WithDeferred(err, finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
|
||||
err = errors.WithDeferred(err, d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
|
||||
ok = ok && err == nil
|
||||
if ok {
|
||||
log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum)
|
||||
|
@ -638,7 +628,7 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
|
|||
r = file
|
||||
} else {
|
||||
var resp *http.Response
|
||||
resp, err = Context.client.Get(flt.URL)
|
||||
resp, err = d.HTTPClient.Get(flt.URL)
|
||||
if err != nil {
|
||||
log.Printf("requesting filter from %s, skip: %s", flt.URL, err)
|
||||
|
||||
|
@ -655,16 +645,16 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
|
|||
r = resp.Body
|
||||
}
|
||||
|
||||
name, rnum, cs, n, err = f.processUpdate(r, tmpFile, flt)
|
||||
name, rnum, cs, n, err = d.processUpdate(r, tmpFile, flt)
|
||||
|
||||
return cs != flt.checksum, err
|
||||
}
|
||||
|
||||
// loads filter contents from the file in dataDir
|
||||
func (f *Filtering) load(filter *filter) (err error) {
|
||||
filterFilePath := filter.Path()
|
||||
func (d *DNSFilter) load(filter *FilterYAML) (err error) {
|
||||
filterFilePath := filter.Path(d.DataDir)
|
||||
|
||||
log.Tracef("filtering: loading filter %d contents to: %s", filter.ID, filterFilePath)
|
||||
log.Tracef("filtering: loading filter %d from %s", filter.ID, filterFilePath)
|
||||
|
||||
file, err := os.Open(filterFilePath)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
|
@ -682,7 +672,7 @@ func (f *Filtering) load(filter *filter) (err error) {
|
|||
|
||||
log.Tracef("filtering: File %s, id %d, length %d", filterFilePath, filter.ID, st.Size())
|
||||
|
||||
rulesCount, checksum, _ := f.parseFilterContents(file)
|
||||
rulesCount, checksum, _ := d.parseFilterContents(file)
|
||||
|
||||
filter.RulesCount = rulesCount
|
||||
filter.checksum = checksum
|
||||
|
@ -691,56 +681,45 @@ func (f *Filtering) load(filter *filter) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Clear filter rules
|
||||
func (filter *filter) unload() {
|
||||
filter.RulesCount = 0
|
||||
filter.checksum = 0
|
||||
func (d *DNSFilter) EnableFilters(async bool) {
|
||||
d.filtersMu.RLock()
|
||||
defer d.filtersMu.RUnlock()
|
||||
|
||||
d.enableFiltersLocked(async)
|
||||
}
|
||||
|
||||
// Path to the filter contents
|
||||
func (filter *filter) Path() string {
|
||||
return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
|
||||
}
|
||||
|
||||
func enableFilters(async bool) {
|
||||
config.RLock()
|
||||
defer config.RUnlock()
|
||||
|
||||
enableFiltersLocked(async)
|
||||
}
|
||||
|
||||
func enableFiltersLocked(async bool) {
|
||||
filters := []filtering.Filter{{
|
||||
ID: filtering.CustomListID,
|
||||
Data: []byte(strings.Join(config.UserRules, "\n")),
|
||||
func (d *DNSFilter) enableFiltersLocked(async bool) {
|
||||
filters := []Filter{{
|
||||
ID: CustomListID,
|
||||
Data: []byte(strings.Join(d.UserRules, "\n")),
|
||||
}}
|
||||
|
||||
for _, filter := range config.Filters {
|
||||
for _, filter := range d.Filters {
|
||||
if !filter.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
filters = append(filters, filtering.Filter{
|
||||
filters = append(filters, Filter{
|
||||
ID: filter.ID,
|
||||
FilePath: filter.Path(),
|
||||
FilePath: filter.Path(d.DataDir),
|
||||
})
|
||||
}
|
||||
|
||||
var allowFilters []filtering.Filter
|
||||
for _, filter := range config.WhitelistFilters {
|
||||
var allowFilters []Filter
|
||||
for _, filter := range d.WhitelistFilters {
|
||||
if !filter.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
allowFilters = append(allowFilters, filtering.Filter{
|
||||
allowFilters = append(allowFilters, Filter{
|
||||
ID: filter.ID,
|
||||
FilePath: filter.Path(),
|
||||
FilePath: filter.Path(d.DataDir),
|
||||
})
|
||||
}
|
||||
|
||||
if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil {
|
||||
if err := d.SetFilters(filters, allowFilters, async); err != nil {
|
||||
log.Debug("enabling filters: %s", err)
|
||||
}
|
||||
|
||||
Context.dnsFilter.SetEnabled(config.DNS.FilteringEnabled)
|
||||
d.SetEnabled(d.FilteringEnabled)
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package home
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
|
@ -51,15 +51,17 @@ func TestFilters(t *testing.T) {
|
|||
|
||||
l := testStartFilterListener(t, &fltContent)
|
||||
|
||||
Context = homeContext{
|
||||
workDir: t.TempDir(),
|
||||
client: &http.Client{
|
||||
tempDir := t.TempDir()
|
||||
|
||||
filters, err := New(&Config{
|
||||
DataDir: tempDir,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
}
|
||||
Context.filters.Init()
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
f := &filter{
|
||||
f := &FilterYAML{
|
||||
URL: (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: (&netutil.IPPort{
|
||||
|
@ -71,21 +73,22 @@ func TestFilters(t *testing.T) {
|
|||
}
|
||||
|
||||
updateAndAssert := func(t *testing.T, want require.BoolAssertionFunc, wantRulesCount int) {
|
||||
ok, err := Context.filters.update(f)
|
||||
var ok bool
|
||||
ok, err = filters.update(f)
|
||||
require.NoError(t, err)
|
||||
want(t, ok)
|
||||
|
||||
assert.Equal(t, wantRulesCount, f.RulesCount)
|
||||
|
||||
var dir []fs.DirEntry
|
||||
dir, err = os.ReadDir(filepath.Join(Context.getDataDir(), filterDir))
|
||||
dir, err = os.ReadDir(filepath.Join(tempDir, filterDir))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, dir, 1)
|
||||
|
||||
require.FileExists(t, f.Path())
|
||||
require.FileExists(t, f.Path(tempDir))
|
||||
|
||||
err = Context.filters.load(f)
|
||||
err = filters.load(f)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
@ -105,11 +108,9 @@ func TestFilters(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("load_unload", func(t *testing.T) {
|
||||
err := Context.filters.load(f)
|
||||
err = filters.load(f)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.unload()
|
||||
})
|
||||
|
||||
require.NoError(t, os.Remove(f.Path()))
|
||||
}
|
|
@ -6,7 +6,10 @@ import (
|
|||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
@ -24,6 +27,7 @@ import (
|
|||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// The IDs of built-in filter lists.
|
||||
|
@ -69,8 +73,13 @@ type Config struct {
|
|||
// enabled is used to be returned within Settings.
|
||||
//
|
||||
// It is of type uint32 to be accessed by atomic.
|
||||
//
|
||||
// TODO(e.burkov): Use atomic.Bool in Go 1.19.
|
||||
enabled uint32
|
||||
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
||||
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
||||
|
||||
ParentalEnabled bool `yaml:"parental_enabled"`
|
||||
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
||||
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
||||
|
@ -98,6 +107,24 @@ type Config struct {
|
|||
|
||||
// CustomResolver is the resolver used by DNSFilter.
|
||||
CustomResolver Resolver `yaml:"-"`
|
||||
|
||||
// HTTPClient is the client to use for updating the remote filters.
|
||||
HTTPClient *http.Client `yaml:"-"`
|
||||
|
||||
// DataDir is used to store filters' contents.
|
||||
DataDir string `yaml:"-"`
|
||||
|
||||
// filtersMu protects filter lists.
|
||||
filtersMu *sync.RWMutex
|
||||
|
||||
// Filters are the blocking filter lists.
|
||||
Filters []FilterYAML `yaml:"-"`
|
||||
|
||||
// WhitelistFilters are the allowing filter lists.
|
||||
WhitelistFilters []FilterYAML `yaml:"-"`
|
||||
|
||||
// UserRules is the global list of custom rules.
|
||||
UserRules []string `yaml:"-"`
|
||||
}
|
||||
|
||||
// LookupStats store stats collected during safebrowsing or parental checks
|
||||
|
@ -130,8 +157,10 @@ type hostChecker struct {
|
|||
type DNSFilter struct {
|
||||
rulesStorage *filterlist.RuleStorage
|
||||
filteringEngine *urlfilter.DNSEngine
|
||||
|
||||
rulesStorageAllow *filterlist.RuleStorage
|
||||
filteringEngineAllow *urlfilter.DNSEngine
|
||||
|
||||
engineLock sync.RWMutex
|
||||
|
||||
parentalServer string // access via methods
|
||||
|
@ -156,6 +185,12 @@ type DNSFilter struct {
|
|||
// TODO(e.burkov): Use upstream that configured in dnsforward instead.
|
||||
resolver Resolver
|
||||
|
||||
refreshLock *sync.Mutex
|
||||
|
||||
// filterTitleRegexp is the regular expression to retrieve a name of a
|
||||
// filter list.
|
||||
filterTitleRegexp *regexp.Regexp
|
||||
|
||||
hostCheckers []hostChecker
|
||||
}
|
||||
|
||||
|
@ -168,7 +203,7 @@ type Filter struct {
|
|||
Data []byte `yaml:"-"`
|
||||
|
||||
// ID is automatically assigned when filter is added using nextFilterID.
|
||||
ID int64
|
||||
ID int64 `yaml:"id"`
|
||||
}
|
||||
|
||||
// Reason holds an enum detailing why it was filtered or not filtered
|
||||
|
@ -245,15 +280,7 @@ func (r Reason) String() string {
|
|||
}
|
||||
|
||||
// In returns true if reasons include r.
|
||||
func (r Reason) In(reasons ...Reason) (ok bool) {
|
||||
for _, reason := range reasons {
|
||||
if r == reason {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
func (r Reason) In(reasons ...Reason) (ok bool) { return slices.Contains(reasons, r) }
|
||||
|
||||
// SetEnabled sets the status of the *DNSFilter.
|
||||
func (d *DNSFilter) SetEnabled(enabled bool) {
|
||||
|
@ -261,6 +288,7 @@ func (d *DNSFilter) SetEnabled(enabled bool) {
|
|||
if enabled {
|
||||
i = 1
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&d.enabled, uint32(i))
|
||||
}
|
||||
|
||||
|
@ -279,11 +307,20 @@ func (d *DNSFilter) GetConfig() (s Settings) {
|
|||
|
||||
// WriteDiskConfig - write configuration
|
||||
func (d *DNSFilter) WriteDiskConfig(c *Config) {
|
||||
func() {
|
||||
d.confLock.Lock()
|
||||
defer d.confLock.Unlock()
|
||||
|
||||
*c = d.Config
|
||||
c.Rewrites = cloneRewrites(c.Rewrites)
|
||||
}()
|
||||
|
||||
d.filtersMu.RLock()
|
||||
defer d.filtersMu.RUnlock()
|
||||
|
||||
c.Filters = slices.Clone(d.Filters)
|
||||
c.WhitelistFilters = slices.Clone(d.WhitelistFilters)
|
||||
c.UserRules = slices.Clone(d.UserRules)
|
||||
}
|
||||
|
||||
// cloneRewrites returns a deep copy of entries.
|
||||
|
@ -309,6 +346,8 @@ func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool)
|
|||
}
|
||||
|
||||
d.filtersInitializerLock.Lock() // prevent multiple writers from adding more than 1 task
|
||||
defer d.filtersInitializerLock.Unlock()
|
||||
|
||||
// remove all pending tasks
|
||||
stop := false
|
||||
for !stop {
|
||||
|
@ -321,7 +360,6 @@ func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool)
|
|||
}
|
||||
|
||||
d.filtersInitializerChan <- params
|
||||
d.filtersInitializerLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -350,22 +388,19 @@ func (d *DNSFilter) filtersInitializer() {
|
|||
func (d *DNSFilter) Close() {
|
||||
d.engineLock.Lock()
|
||||
defer d.engineLock.Unlock()
|
||||
|
||||
d.reset()
|
||||
}
|
||||
|
||||
func (d *DNSFilter) reset() {
|
||||
var err error
|
||||
|
||||
if d.rulesStorage != nil {
|
||||
err = d.rulesStorage.Close()
|
||||
if err != nil {
|
||||
if err := d.rulesStorage.Close(); err != nil {
|
||||
log.Error("filtering: rulesStorage.Close: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if d.rulesStorageAllow != nil {
|
||||
err = d.rulesStorageAllow.Close()
|
||||
if err != nil {
|
||||
if err := d.rulesStorageAllow.Close(); err != nil {
|
||||
log.Error("filtering: rulesStorageAllow.Close: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -885,12 +920,14 @@ func InitModule() {
|
|||
initBlockedServices()
|
||||
}
|
||||
|
||||
// New creates properly initialized DNS Filter that is ready to be used.
|
||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
||||
// New creates properly initialized DNS Filter that is ready to be used. c must
|
||||
// be non-nil.
|
||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
d = &DNSFilter{
|
||||
resolver: net.DefaultResolver,
|
||||
refreshLock: &sync.Mutex{},
|
||||
filterTitleRegexp: regexp.MustCompile(`^! Title: +(.*)$`),
|
||||
}
|
||||
if c != nil {
|
||||
|
||||
d.safebrowsingCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
|
@ -905,9 +942,8 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
|||
MaxSize: c.ParentalCacheSize,
|
||||
})
|
||||
|
||||
if c.CustomResolver != nil {
|
||||
d.resolver = c.CustomResolver
|
||||
}
|
||||
if r := c.CustomResolver; r != nil {
|
||||
d.resolver = r
|
||||
}
|
||||
|
||||
d.hostCheckers = []hostChecker{{
|
||||
|
@ -930,27 +966,26 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
|||
name: "safe search",
|
||||
}}
|
||||
|
||||
err := d.initSecurityServices()
|
||||
if err != nil {
|
||||
log.Error("filtering: initialize services: %s", err)
|
||||
defer func() { err = errors.Annotate(err, "filtering: %w") }()
|
||||
|
||||
return nil
|
||||
err = d.initSecurityServices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing services: %s", err)
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
d.Config = *c
|
||||
d.filtersMu = &sync.RWMutex{}
|
||||
|
||||
err = d.prepareRewrites()
|
||||
if err != nil {
|
||||
log.Error("rewrites: preparing: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
return nil, fmt.Errorf("rewrites: preparing: %s", err)
|
||||
}
|
||||
|
||||
bsvcs := []string{}
|
||||
for _, s := range d.BlockedServices {
|
||||
if !BlockedSvcKnown(s) {
|
||||
log.Debug("skipping unknown blocked-service %q", s)
|
||||
|
||||
continue
|
||||
}
|
||||
bsvcs = append(bsvcs, s)
|
||||
|
@ -960,13 +995,24 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
|||
if blockFilters != nil {
|
||||
err = d.initFiltering(nil, blockFilters)
|
||||
if err != nil {
|
||||
log.Error("Can't initialize filtering subsystem: %s", err)
|
||||
d.Close()
|
||||
return nil
|
||||
|
||||
return nil, fmt.Errorf("initializing filtering subsystem: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return d
|
||||
_ = os.MkdirAll(filepath.Join(d.DataDir, filterDir), 0o755)
|
||||
|
||||
d.loadFilters(d.Filters)
|
||||
d.loadFilters(d.WhitelistFilters)
|
||||
|
||||
d.Filters = deduplicateFilters(d.Filters)
|
||||
d.WhitelistFilters = deduplicateFilters(d.WhitelistFilters)
|
||||
|
||||
updateUniqueFilterID(d.Filters)
|
||||
updateUniqueFilterID(d.WhitelistFilters)
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// Start - start the module:
|
||||
|
@ -976,9 +1022,10 @@ func (d *DNSFilter) Start() {
|
|||
d.filtersInitializerChan = make(chan filtersInitializerParams, 1)
|
||||
go d.filtersInitializer()
|
||||
|
||||
if d.Config.HTTPRegister != nil { // for tests
|
||||
d.registerSecurityHandlers()
|
||||
d.registerRewritesHandlers()
|
||||
d.registerBlockedServicesHandlers()
|
||||
}
|
||||
d.RegisterFilteringHandlers()
|
||||
|
||||
// Here we should start updating filters,
|
||||
// but currently we can't wake up the periodic task to do so.
|
||||
// So for now we just start this periodic task from here.
|
||||
go d.periodicallyRefreshFilters()
|
||||
}
|
||||
|
|
|
@ -26,10 +26,6 @@ const (
|
|||
pcBlocked = "pornhub.com"
|
||||
)
|
||||
|
||||
var setts = Settings{
|
||||
ProtectionEnabled: true,
|
||||
}
|
||||
|
||||
// Helpers.
|
||||
|
||||
func purgeCaches(d *DNSFilter) {
|
||||
|
@ -44,8 +40,8 @@ func purgeCaches(d *DNSFilter) {
|
|||
}
|
||||
}
|
||||
|
||||
func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
|
||||
setts = Settings{
|
||||
func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts *Settings) {
|
||||
setts = &Settings{
|
||||
ProtectionEnabled: true,
|
||||
FilteringEnabled: true,
|
||||
}
|
||||
|
@ -57,26 +53,31 @@ func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
|
|||
setts.SafeSearchEnabled = c.SafeSearchEnabled
|
||||
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
|
||||
setts.ParentalEnabled = c.ParentalEnabled
|
||||
} else {
|
||||
// It must not be nil.
|
||||
c = &Config{}
|
||||
}
|
||||
d := New(c, filters)
|
||||
purgeCaches(d)
|
||||
f, err := New(c, filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
return d
|
||||
purgeCaches(f)
|
||||
|
||||
return f, setts
|
||||
}
|
||||
|
||||
func (d *DNSFilter) checkMatch(t *testing.T, hostname string) {
|
||||
func (d *DNSFilter) checkMatch(t *testing.T, hostname string, setts *Settings) {
|
||||
t.Helper()
|
||||
|
||||
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(hostname, dns.TypeA, setts)
|
||||
require.NoErrorf(t, err, "host %q", hostname)
|
||||
|
||||
assert.Truef(t, res.IsFiltered, "host %q", hostname)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16) {
|
||||
func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16, setts *Settings) {
|
||||
t.Helper()
|
||||
|
||||
res, err := d.CheckHost(hostname, qtype, &setts)
|
||||
res, err := d.CheckHost(hostname, qtype, setts)
|
||||
require.NoErrorf(t, err, "host %q", hostname, err)
|
||||
require.NotEmpty(t, res.Rules, "host %q", hostname)
|
||||
|
||||
|
@ -88,10 +89,10 @@ func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16
|
|||
assert.Equalf(t, ip, r.IP.String(), "host %q", hostname)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) {
|
||||
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string, setts *Settings) {
|
||||
t.Helper()
|
||||
|
||||
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(hostname, dns.TypeA, setts)
|
||||
require.NoErrorf(t, err, "host %q", hostname)
|
||||
|
||||
assert.Falsef(t, res.IsFiltered, "host %q", hostname)
|
||||
|
@ -111,19 +112,19 @@ func TestEtcHostsMatching(t *testing.T) {
|
|||
filters := []Filter{{
|
||||
ID: 0, Data: []byte(text),
|
||||
}}
|
||||
d := newForTest(t, nil, filters)
|
||||
d, setts := newForTest(t, nil, filters)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.checkMatchIP(t, "google.com", addr, dns.TypeA)
|
||||
d.checkMatchIP(t, "www.google.com", addr, dns.TypeA)
|
||||
d.checkMatchEmpty(t, "subdomain.google.com")
|
||||
d.checkMatchEmpty(t, "example.org")
|
||||
d.checkMatchIP(t, "google.com", addr, dns.TypeA, setts)
|
||||
d.checkMatchIP(t, "www.google.com", addr, dns.TypeA, setts)
|
||||
d.checkMatchEmpty(t, "subdomain.google.com", setts)
|
||||
d.checkMatchEmpty(t, "example.org", setts)
|
||||
|
||||
// IPv4 match.
|
||||
d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA)
|
||||
d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA, setts)
|
||||
|
||||
// Empty IPv6.
|
||||
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts)
|
||||
res, err := d.CheckHost("block.com", dns.TypeAAAA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
@ -134,10 +135,10 @@ func TestEtcHostsMatching(t *testing.T) {
|
|||
assert.Empty(t, res.Rules[0].IP)
|
||||
|
||||
// IPv6 match.
|
||||
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA)
|
||||
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA, setts)
|
||||
|
||||
// Empty IPv4.
|
||||
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts)
|
||||
res, err = d.CheckHost("ipv6.com", dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
@ -148,7 +149,7 @@ func TestEtcHostsMatching(t *testing.T) {
|
|||
assert.Empty(t, res.Rules[0].IP)
|
||||
|
||||
// Two IPv4, both must be returned.
|
||||
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
||||
res, err = d.CheckHost("host2", dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
@ -159,7 +160,7 @@ func TestEtcHostsMatching(t *testing.T) {
|
|||
assert.Equal(t, res.Rules[1].IP, net.IP{0, 0, 0, 2})
|
||||
|
||||
// One IPv6 address.
|
||||
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts)
|
||||
res, err = d.CheckHost("host2", dns.TypeAAAA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
@ -176,27 +177,27 @@ func TestSafeBrowsing(t *testing.T) {
|
|||
aghtest.ReplaceLogWriter(t, logOutput)
|
||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
d.checkMatch(t, sbBlocked)
|
||||
d.checkMatch(t, sbBlocked, setts)
|
||||
|
||||
require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
|
||||
|
||||
d.checkMatch(t, "test."+sbBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
d.checkMatch(t, "test."+sbBlocked, setts)
|
||||
d.checkMatchEmpty(t, "yandex.ru", setts)
|
||||
d.checkMatchEmpty(t, pcBlocked, setts)
|
||||
|
||||
// Cached result.
|
||||
d.safeBrowsingServer = "127.0.0.1"
|
||||
d.checkMatch(t, sbBlocked)
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
d.checkMatch(t, sbBlocked, setts)
|
||||
d.checkMatchEmpty(t, pcBlocked, setts)
|
||||
d.safeBrowsingServer = defaultSafebrowsingServer
|
||||
}
|
||||
|
||||
func TestParallelSB(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
@ -205,10 +206,10 @@ func TestParallelSB(t *testing.T) {
|
|||
for i := 0; i < 100; i++ {
|
||||
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
d.checkMatch(t, sbBlocked)
|
||||
d.checkMatch(t, "test."+sbBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
d.checkMatch(t, sbBlocked, setts)
|
||||
d.checkMatch(t, "test."+sbBlocked, setts)
|
||||
d.checkMatchEmpty(t, "yandex.ru", setts)
|
||||
d.checkMatchEmpty(t, pcBlocked, setts)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
@ -217,7 +218,7 @@ func TestParallelSB(t *testing.T) {
|
|||
// Safe Search.
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
d, _ := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
require.True(t, ok)
|
||||
|
@ -226,7 +227,7 @@ func TestSafeSearch(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
d := newForTest(t, &Config{
|
||||
d, setts := newForTest(t, &Config{
|
||||
SafeSearchEnabled: true,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
@ -243,7 +244,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
|||
"www.yandex.com",
|
||||
} {
|
||||
t.Run(strings.ToLower(host), func(t *testing.T) {
|
||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(host, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
@ -258,7 +259,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
|||
|
||||
func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
d := newForTest(t, &Config{
|
||||
d, setts := newForTest(t, &Config{
|
||||
SafeSearchEnabled: true,
|
||||
CustomResolver: resolver,
|
||||
}, nil)
|
||||
|
@ -277,7 +278,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
|||
"www.google.je",
|
||||
} {
|
||||
t.Run(host, func(t *testing.T) {
|
||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(host, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
@ -291,12 +292,12 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
d := newForTest(t, nil, nil)
|
||||
d, setts := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const domain = "yandex.ru"
|
||||
|
||||
// Check host with disabled safesearch.
|
||||
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(domain, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
|
@ -305,10 +306,10 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
|||
|
||||
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||
|
||||
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
||||
res, err = d.CheckHost(domain, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// For yandex we already know valid IP.
|
||||
|
@ -325,20 +326,20 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
|||
|
||||
func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
d := newForTest(t, &Config{
|
||||
d, setts := newForTest(t, &Config{
|
||||
CustomResolver: resolver,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
const domain = "www.google.ru"
|
||||
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(domain, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
|
||||
require.Empty(t, res.Rules)
|
||||
|
||||
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
d.resolver = resolver
|
||||
|
||||
|
@ -358,7 +359,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
||||
res, err = d.CheckHost(domain, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
|
@ -379,22 +380,22 @@ func TestParentalControl(t *testing.T) {
|
|||
aghtest.ReplaceLogWriter(t, logOutput)
|
||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||
|
||||
d := newForTest(t, &Config{ParentalEnabled: true}, nil)
|
||||
d, setts := newForTest(t, &Config{ParentalEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
|
||||
d.checkMatch(t, pcBlocked)
|
||||
d.checkMatch(t, pcBlocked, setts)
|
||||
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
|
||||
|
||||
d.checkMatch(t, "www."+pcBlocked)
|
||||
d.checkMatchEmpty(t, "www.yandex.ru")
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, "api.jquery.com")
|
||||
d.checkMatch(t, "www."+pcBlocked, setts)
|
||||
d.checkMatchEmpty(t, "www.yandex.ru", setts)
|
||||
d.checkMatchEmpty(t, "yandex.ru", setts)
|
||||
d.checkMatchEmpty(t, "api.jquery.com", setts)
|
||||
|
||||
// Test cached result.
|
||||
d.parentalServer = "127.0.0.1"
|
||||
d.checkMatch(t, pcBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatch(t, pcBlocked, setts)
|
||||
d.checkMatchEmpty(t, "yandex.ru", setts)
|
||||
}
|
||||
|
||||
// Filtering.
|
||||
|
@ -679,10 +680,10 @@ func TestMatching(t *testing.T) {
|
|||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) {
|
||||
filters := []Filter{{ID: 0, Data: []byte(tc.rules)}}
|
||||
d := newForTest(t, nil, filters)
|
||||
d, setts := newForTest(t, nil, filters)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
|
||||
res, err := d.CheckHost(tc.host, tc.wantDNSType, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
|
||||
|
@ -705,7 +706,7 @@ func TestWhitelist(t *testing.T) {
|
|||
whiteFilters := []Filter{{
|
||||
ID: 0, Data: []byte(whiteRules),
|
||||
}}
|
||||
d := newForTest(t, nil, filters)
|
||||
d, setts := newForTest(t, nil, filters)
|
||||
|
||||
err := d.SetFilters(filters, whiteFilters, false)
|
||||
require.NoError(t, err)
|
||||
|
@ -713,7 +714,7 @@ func TestWhitelist(t *testing.T) {
|
|||
t.Cleanup(d.Close)
|
||||
|
||||
// Matched by white filter.
|
||||
res, err := d.CheckHost("host1", dns.TypeA, &setts)
|
||||
res, err := d.CheckHost("host1", dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
|
@ -724,7 +725,7 @@ func TestWhitelist(t *testing.T) {
|
|||
assert.Equal(t, "||host1^", res.Rules[0].Text)
|
||||
|
||||
// Not matched by white filter, but matched by block filter.
|
||||
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
||||
res, err = d.CheckHost("host2", dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
@ -750,7 +751,7 @@ func applyClientSettings(setts *Settings) {
|
|||
}
|
||||
|
||||
func TestClientSettings(t *testing.T) {
|
||||
d := newForTest(t,
|
||||
d, setts := newForTest(t,
|
||||
&Config{
|
||||
ParentalEnabled: true,
|
||||
SafeBrowsingEnabled: false,
|
||||
|
@ -796,7 +797,7 @@ func TestClientSettings(t *testing.T) {
|
|||
return func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
r, err := d.CheckHost(tc.host, dns.TypeA, &setts)
|
||||
r, err := d.CheckHost(tc.host, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
if before {
|
||||
|
@ -814,7 +815,7 @@ func TestClientSettings(t *testing.T) {
|
|||
t.Run(tc.name, makeTester(tc, tc.before))
|
||||
}
|
||||
|
||||
applyClientSettings(&setts)
|
||||
applyClientSettings(setts)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, makeTester(tc, !tc.before))
|
||||
|
@ -824,13 +825,13 @@ func TestClientSettings(t *testing.T) {
|
|||
// Benchmarks.
|
||||
|
||||
func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
|
||||
require.NoError(b, err)
|
||||
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
|
@ -838,14 +839,14 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
|
||||
require.NoError(b, err)
|
||||
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
|
@ -854,7 +855,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkSafeSearch(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
for n := 0; n < b.N; n++ {
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
|
@ -865,7 +866,7 @@ func BenchmarkSafeSearch(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkSafeSearchParallel(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
|
|
|
@ -133,34 +133,31 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
|
|||
// 1. A and AAAA > CNAME
|
||||
// 2. wildcard > exact
|
||||
// 3. lower level wildcard > higher level wildcard
|
||||
//
|
||||
// TODO(a.garipov): Replace with slices.Sort.
|
||||
type rewritesSorted []*LegacyRewrite
|
||||
|
||||
// Len implements the sort.Interface interface for legacyRewritesSorted.
|
||||
// Len implements the sort.Interface interface for rewritesSorted.
|
||||
func (a rewritesSorted) Len() (l int) { return len(a) }
|
||||
|
||||
// Swap implements the sort.Interface interface for legacyRewritesSorted.
|
||||
// Swap implements the sort.Interface interface for rewritesSorted.
|
||||
func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
|
||||
// Less implements the sort.Interface interface for legacyRewritesSorted.
|
||||
// Less implements the sort.Interface interface for rewritesSorted.
|
||||
func (a rewritesSorted) Less(i, j int) (less bool) {
|
||||
if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME {
|
||||
ith, jth := a[i], a[j]
|
||||
if ith.Type == dns.TypeCNAME && jth.Type != dns.TypeCNAME {
|
||||
return true
|
||||
} else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME {
|
||||
} else if ith.Type != dns.TypeCNAME && jth.Type == dns.TypeCNAME {
|
||||
return false
|
||||
}
|
||||
|
||||
if isWildcard(a[i].Domain) {
|
||||
if !isWildcard(a[j].Domain) {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if isWildcard(a[j].Domain) {
|
||||
return true
|
||||
}
|
||||
if iw, jw := isWildcard(ith.Domain), isWildcard(jth.Domain); iw != jw {
|
||||
return jw
|
||||
}
|
||||
|
||||
// Both are wildcards.
|
||||
return len(a[i].Domain) > len(a[j].Domain)
|
||||
// Both are either wildcards or not.
|
||||
return len(ith.Domain) > len(jth.Domain)
|
||||
}
|
||||
|
||||
// prepareRewrites normalizes and validates all legacy DNS rewrites.
|
||||
|
@ -313,9 +310,3 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
|
|||
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
|
||||
func (d *DNSFilter) registerRewritesHandlers() {
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
// TODO(e.burkov): All the tests in this file may and should me merged together.
|
||||
|
||||
func TestRewrites(t *testing.T) {
|
||||
d := newForTest(t, nil, nil)
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
|
@ -188,7 +188,7 @@ func TestRewrites(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRewritesLevels(t *testing.T) {
|
||||
d := newForTest(t, nil, nil)
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Exact host, wildcard L2, wildcard L3.
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
|
@ -235,7 +235,7 @@ func TestRewritesLevels(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
d := newForTest(t, nil, nil)
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Wildcard and exception for a sub-domain.
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
|
@ -286,7 +286,7 @@ func TestRewritesExceptionCNAME(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRewritesExceptionIP(t *testing.T) {
|
||||
d := newForTest(t, nil, nil)
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Exception for AAAA record.
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
|
|
|
@ -415,17 +415,3 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
|
|||
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DNSFilter) registerSecurityHandlers() {
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
|
||||
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/parental/enable", d.handleParentalEnable)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/parental/disable", d.handleParentalDisable)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/parental/status", d.handleParentalStatus)
|
||||
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
|
||||
}
|
||||
|
|
|
@ -107,7 +107,7 @@ func TestSafeBrowsingCache(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
ups := aghtest.NewErrorUpstream()
|
||||
|
@ -128,7 +128,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSBPC(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
const hostname = "example.org"
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/dnsproxy/fastip"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
|
@ -23,10 +22,9 @@ import (
|
|||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
dataDir = "data" // data storage
|
||||
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
|
||||
)
|
||||
// dataDir is the name of a directory under the working one to store some
|
||||
// persistent data.
|
||||
const dataDir = "data"
|
||||
|
||||
// logSettings are the logging settings part of the configuration file.
|
||||
//
|
||||
|
@ -108,8 +106,15 @@ type configuration struct {
|
|||
DNS dnsConfig `yaml:"dns"`
|
||||
TLS tlsConfigSettings `yaml:"tls"`
|
||||
|
||||
Filters []filter `yaml:"filters"`
|
||||
WhitelistFilters []filter `yaml:"whitelist_filters"`
|
||||
// Filters reflects the filters from [filtering.Config]. It's cloned to the
|
||||
// config used in the filtering module at the startup. Afterwards it's
|
||||
// cloned from the filtering module back here.
|
||||
//
|
||||
// TODO(e.burkov): Move all the filtering configuration fields into the
|
||||
// only configuration subsection covering the changes with a single
|
||||
// migration.
|
||||
Filters []filtering.FilterYAML `yaml:"filters"`
|
||||
WhitelistFilters []filtering.FilterYAML `yaml:"whitelist_filters"`
|
||||
UserRules []string `yaml:"user_rules"`
|
||||
|
||||
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
|
||||
|
@ -145,9 +150,7 @@ type dnsConfig struct {
|
|||
|
||||
dnsforward.FilteringConfig `yaml:",inline"`
|
||||
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
||||
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
||||
DnsfilterConf filtering.Config `yaml:",inline"`
|
||||
DnsfilterConf *filtering.Config `yaml:",inline"`
|
||||
|
||||
// UpstreamTimeout is the timeout for querying upstream servers.
|
||||
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
|
||||
|
@ -198,10 +201,15 @@ var config = &configuration{
|
|||
BindHost: net.IP{0, 0, 0, 0},
|
||||
AuthAttempts: 5,
|
||||
AuthBlockMin: 15,
|
||||
WebSessionTTLHours: 30 * 24,
|
||||
DNS: dnsConfig{
|
||||
BindHosts: []net.IP{{0, 0, 0, 0}},
|
||||
Port: defaultPortDNS,
|
||||
StatsInterval: 1,
|
||||
QueryLogEnabled: true,
|
||||
QueryLogFileEnabled: true,
|
||||
QueryLogInterval: timeutil.Duration{Duration: 90 * timeutil.Day},
|
||||
QueryLogMemSize: 1000,
|
||||
FilteringConfig: dnsforward.FilteringConfig{
|
||||
ProtectionEnabled: true, // whether or not use any of filtering features
|
||||
BlockingMode: dnsforward.BlockingModeDefault,
|
||||
|
@ -222,8 +230,14 @@ var config = &configuration{
|
|||
// was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257
|
||||
MaxGoroutines: 300,
|
||||
},
|
||||
FilteringEnabled: true, // whether or not use filter lists
|
||||
DnsfilterConf: &filtering.Config{
|
||||
SafeBrowsingCacheSize: 1 * 1024 * 1024,
|
||||
SafeSearchCacheSize: 1 * 1024 * 1024,
|
||||
ParentalCacheSize: 1 * 1024 * 1024,
|
||||
CacheTime: 30,
|
||||
FilteringEnabled: true,
|
||||
FiltersUpdateIntervalHours: 24,
|
||||
},
|
||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||
UsePrivateRDNS: true,
|
||||
},
|
||||
|
@ -232,8 +246,26 @@ var config = &configuration{
|
|||
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
|
||||
PortDNSOverQUIC: defaultPortQUIC,
|
||||
},
|
||||
Filters: []filtering.FilterYAML{{
|
||||
Filter: filtering.Filter{ID: 1},
|
||||
Enabled: true,
|
||||
URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt",
|
||||
Name: "AdGuard DNS filter",
|
||||
}, {
|
||||
Filter: filtering.Filter{ID: 2},
|
||||
Enabled: false,
|
||||
URL: "https://adaway.org/hosts.txt",
|
||||
Name: "AdAway Default Blocklist",
|
||||
}},
|
||||
DHCP: &dhcpd.ServerConfig{
|
||||
LocalDomainName: "lan",
|
||||
Conf4: dhcpd.V4ServerConf{
|
||||
LeaseDuration: dhcpd.DefaultDHCPLeaseTTL,
|
||||
ICMPTimeout: dhcpd.DefaultDHCPTimeoutICMP,
|
||||
},
|
||||
Conf6: dhcpd.V6ServerConf{
|
||||
LeaseDuration: dhcpd.DefaultDHCPLeaseTTL,
|
||||
},
|
||||
},
|
||||
Clients: &clientsConfig{
|
||||
Sources: &clientSourcesConf{
|
||||
|
@ -255,31 +287,6 @@ var config = &configuration{
|
|||
SchemaVersion: currentSchemaVersion,
|
||||
}
|
||||
|
||||
// initConfig initializes default configuration for the current OS&ARCH
|
||||
func initConfig() {
|
||||
config.WebSessionTTLHours = 30 * 24
|
||||
|
||||
config.DNS.QueryLogEnabled = true
|
||||
config.DNS.QueryLogFileEnabled = true
|
||||
config.DNS.QueryLogInterval = timeutil.Duration{Duration: 90 * timeutil.Day}
|
||||
config.DNS.QueryLogMemSize = 1000
|
||||
|
||||
config.DNS.CacheSize = 4 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.CacheTime = 30
|
||||
config.Filters = defaultFilters()
|
||||
|
||||
config.DHCP.Conf4.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
|
||||
config.DHCP.Conf4.ICMPTimeout = dhcpd.DefaultDHCPTimeoutICMP
|
||||
config.DHCP.Conf6.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
|
||||
|
||||
if ch := version.Channel(); ch == version.ChannelEdge || ch == version.ChannelDevelopment {
|
||||
config.BetaBindPort = 3001
|
||||
}
|
||||
}
|
||||
|
||||
// getConfigFilename returns path to the current config file
|
||||
func (c *configuration) getConfigFilename() string {
|
||||
configFile, err := filepath.EvalSymlinks(Context.configFilename)
|
||||
|
@ -348,8 +355,8 @@ func parseConfig() (err error) {
|
|||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
||||
config.DNS.FiltersUpdateIntervalHours = 24
|
||||
if !filtering.ValidateUpdateIvl(config.DNS.DnsfilterConf.FiltersUpdateIntervalHours) {
|
||||
config.DNS.DnsfilterConf.FiltersUpdateIntervalHours = 24
|
||||
}
|
||||
|
||||
if config.DNS.UpstreamTimeout.Duration == 0 {
|
||||
|
@ -418,10 +425,11 @@ func (c *configuration) write() (err error) {
|
|||
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
|
||||
}
|
||||
|
||||
if Context.dnsFilter != nil {
|
||||
c := filtering.Config{}
|
||||
Context.dnsFilter.WriteDiskConfig(&c)
|
||||
config.DNS.DnsfilterConf = c
|
||||
if Context.filters != nil {
|
||||
Context.filters.WriteDiskConfig(config.DNS.DnsfilterConf)
|
||||
config.Filters = config.DNS.DnsfilterConf.Filters
|
||||
config.WhitelistFilters = config.DNS.DnsfilterConf.WhitelistFilters
|
||||
config.UserRules = config.DNS.DnsfilterConf.UserRules
|
||||
}
|
||||
|
||||
if s := Context.dnsServer; s != nil {
|
||||
|
|
|
@ -291,7 +291,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
|||
}
|
||||
|
||||
httpsURL := &url.URL{
|
||||
Scheme: schemeHTTPS,
|
||||
Scheme: aghhttp.SchemeHTTPS,
|
||||
Host: hostPort,
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
|
@ -307,7 +307,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
|||
//
|
||||
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin.
|
||||
originURL := &url.URL{
|
||||
Scheme: schemeHTTP,
|
||||
Scheme: aghhttp.SchemeHTTP,
|
||||
Host: r.Host,
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", originURL.String())
|
||||
|
|
|
@ -31,7 +31,10 @@ const (
|
|||
|
||||
// Called by other modules when configuration is changed
|
||||
func onConfigModified() {
|
||||
_ = config.write()
|
||||
err := config.write()
|
||||
if err != nil {
|
||||
log.Error("writing config: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// initDNSServer creates an instance of the dnsforward.Server
|
||||
|
@ -71,11 +74,11 @@ func initDNSServer() (err error) {
|
|||
}
|
||||
Context.queryLog = querylog.New(conf)
|
||||
|
||||
filterConf := config.DNS.DnsfilterConf
|
||||
filterConf.EtcHosts = Context.etcHosts
|
||||
filterConf.ConfigModified = onConfigModified
|
||||
filterConf.HTTPRegister = httpRegister
|
||||
Context.dnsFilter = filtering.New(&filterConf, nil)
|
||||
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
var privateNets netutil.SubnetSet
|
||||
switch len(config.DNS.PrivateNets) {
|
||||
|
@ -83,13 +86,10 @@ func initDNSServer() (err error) {
|
|||
// Use an optimized locally-served matcher.
|
||||
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
case 1:
|
||||
var n *net.IPNet
|
||||
n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
||||
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
|
||||
privateNets = n
|
||||
default:
|
||||
var nets []*net.IPNet
|
||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||
|
@ -101,15 +101,13 @@ func initDNSServer() (err error) {
|
|||
}
|
||||
|
||||
p := dnsforward.DNSCreateParams{
|
||||
DNSFilter: Context.dnsFilter,
|
||||
DNSFilter: Context.filters,
|
||||
Stats: Context.stats,
|
||||
QueryLog: Context.queryLog,
|
||||
PrivateNets: privateNets,
|
||||
Anonymizer: anonymizer,
|
||||
LocalDomain: config.DHCP.LocalDomainName,
|
||||
}
|
||||
if Context.dhcpServer != nil {
|
||||
p.DHCPServer = Context.dhcpServer
|
||||
DHCPServer: Context.dhcpServer,
|
||||
}
|
||||
|
||||
Context.dnsServer, err = dnsforward.NewServer(p)
|
||||
|
@ -143,7 +141,6 @@ func initDNSServer() (err error) {
|
|||
Context.whois = initWHOIS(&Context.clients)
|
||||
}
|
||||
|
||||
Context.filters.Init()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -335,9 +332,12 @@ func getDNSEncryption() (de dnsEncryption) {
|
|||
// applyAdditionalFiltering adds additional client information and settings if
|
||||
// the client has them.
|
||||
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||
// pref is a prefix for logging messages around the scope.
|
||||
const pref = "applying filters"
|
||||
|
||||
log.Debug("looking up settings for client with ip %s and clientid %q", clientIP, clientID)
|
||||
Context.filters.ApplyBlockedServices(setts, nil)
|
||||
|
||||
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
|
||||
|
||||
if clientIP == nil {
|
||||
return
|
||||
|
@ -349,16 +349,16 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
|
|||
if !ok {
|
||||
c, ok = Context.clients.Find(clientIP.String())
|
||||
if !ok {
|
||||
log.Debug("client with ip %s and clientid %q not found", clientIP, clientID)
|
||||
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("using settings for client %q with ip %s and clientid %q", c.Name, clientIP, clientID)
|
||||
log.Debug("%s: using settings for client %q (%s; %q)", pref, c.Name, clientIP, clientID)
|
||||
|
||||
if c.UseOwnBlockedServices {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
|
||||
Context.filters.ApplyBlockedServices(setts, c.BlockedServices)
|
||||
}
|
||||
|
||||
setts.ClientName = c.Name
|
||||
|
@ -381,7 +381,7 @@ func startDNSServer() error {
|
|||
return fmt.Errorf("unable to start forwarding DNS server: Already running")
|
||||
}
|
||||
|
||||
enableFiltersLocked(false)
|
||||
Context.filters.EnableFilters(false)
|
||||
|
||||
Context.clients.Start()
|
||||
|
||||
|
@ -390,7 +390,6 @@ func startDNSServer() error {
|
|||
return fmt.Errorf("couldn't start forwarding DNS server: %w", err)
|
||||
}
|
||||
|
||||
Context.dnsFilter.Start()
|
||||
Context.filters.Start()
|
||||
Context.stats.Start()
|
||||
Context.queryLog.Start()
|
||||
|
@ -449,10 +448,7 @@ func closeDNSServer() {
|
|||
Context.dnsServer = nil
|
||||
}
|
||||
|
||||
if Context.dnsFilter != nil {
|
||||
Context.dnsFilter.Close()
|
||||
Context.dnsFilter = nil
|
||||
}
|
||||
Context.filters.Close()
|
||||
|
||||
if Context.stats != nil {
|
||||
err := Context.stats.Close()
|
||||
|
@ -469,7 +465,5 @@ func closeDNSServer() {
|
|||
Context.queryLog = nil
|
||||
}
|
||||
|
||||
Context.filters.Close()
|
||||
|
||||
log.Debug("Closed all DNS modules")
|
||||
log.Debug("all dns modules are closed")
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
|
@ -33,6 +34,7 @@ import (
|
|||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"golang.org/x/exp/slices"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
|
@ -52,10 +54,9 @@ type homeContext struct {
|
|||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
whois *WHOIS // WHOIS module
|
||||
dnsFilter *filtering.DNSFilter // DNS filtering module
|
||||
dhcpServer dhcpd.Interface // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
filters Filtering // DNS filtering module
|
||||
filters *filtering.DNSFilter // DNS filtering module
|
||||
web *Web // Web (HTTP, HTTPS) module
|
||||
tls *TLSMod // TLS module
|
||||
// etcHosts is an IP-hostname pairs set taken from system configuration
|
||||
|
@ -140,7 +141,12 @@ func setupContext(args options) {
|
|||
checkPermissions()
|
||||
}
|
||||
|
||||
initConfig()
|
||||
switch version.Channel() {
|
||||
case version.ChannelEdge, version.ChannelDevelopment:
|
||||
config.BetaBindPort = 3001
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
Context.tlsRoots = LoadSystemRootCAs()
|
||||
Context.transport = &http.Transport{
|
||||
|
@ -265,6 +271,14 @@ func setupHostsContainer() (err error) {
|
|||
}
|
||||
|
||||
func setupConfig(args options) (err error) {
|
||||
config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts
|
||||
config.DNS.DnsfilterConf.ConfigModified = onConfigModified
|
||||
config.DNS.DnsfilterConf.HTTPRegister = httpRegister
|
||||
config.DNS.DnsfilterConf.DataDir = Context.getDataDir()
|
||||
config.DNS.DnsfilterConf.Filters = slices.Clone(config.Filters)
|
||||
config.DNS.DnsfilterConf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
|
||||
config.DNS.DnsfilterConf.HTTPClient = Context.client
|
||||
|
||||
config.DHCP.WorkDir = Context.workDir
|
||||
config.DHCP.HTTPRegister = httpRegister
|
||||
config.DHCP.ConfigModified = onConfigModified
|
||||
|
@ -384,8 +398,6 @@ func fatalOnError(err error) {
|
|||
|
||||
// run configures and starts AdGuard Home.
|
||||
func run(args options, clientBuildFS fs.FS) {
|
||||
var err error
|
||||
|
||||
// configure config filename
|
||||
initConfigFilename(args)
|
||||
|
||||
|
@ -404,7 +416,7 @@ func run(args options, clientBuildFS fs.FS) {
|
|||
|
||||
setupContext(args)
|
||||
|
||||
err = configureOS(config)
|
||||
err := configureOS(config)
|
||||
fatalOnError(err)
|
||||
|
||||
// clients package uses filtering package's static data (filtering.BlockedSvcKnown()),
|
||||
|
@ -763,12 +775,12 @@ func printHTTPAddresses(proto string) {
|
|||
}
|
||||
|
||||
port := config.BindPort
|
||||
if proto == schemeHTTPS {
|
||||
if proto == aghhttp.SchemeHTTPS {
|
||||
port = tlsConf.PortHTTPS
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Inspect and perhaps merge with the previous condition.
|
||||
if proto == schemeHTTPS && tlsConf.ServerName != "" {
|
||||
if proto == aghhttp.SchemeHTTPS && tlsConf.ServerName != "" {
|
||||
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0)
|
||||
|
||||
return
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/url"
|
||||
"path"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
|
@ -82,7 +83,7 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) {
|
|||
case dnsProtoHTTPS:
|
||||
dspName = fmt.Sprintf("%s DoH", d.ServerName)
|
||||
u := &url.URL{
|
||||
Scheme: schemeHTTPS,
|
||||
Scheme: aghhttp.SchemeHTTPS,
|
||||
Host: d.ServerName,
|
||||
Path: path.Join("/dns-query", clientID),
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
|
@ -277,7 +278,7 @@ AdGuard Home is successfully installed and will automatically start on boot.
|
|||
There are a few more things that must be configured before you can use it.
|
||||
Click on the link below and follow the Installation Wizard steps to finish setup.
|
||||
AdGuard Home is now available at the following addresses:`)
|
||||
printHTTPAddresses(schemeHTTP)
|
||||
printHTTPAddresses(aghhttp.SchemeHTTP)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -160,7 +161,7 @@ func assertEqualExcept(t *testing.T, oldConf, newConf yobj, oldKeys, newKeys []s
|
|||
}
|
||||
|
||||
func testDiskConf(schemaVersion int) (diskConf yobj) {
|
||||
filters := []filter{{
|
||||
filters := []filtering.FilterYAML{{
|
||||
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
|
||||
Name: "Latvian filter",
|
||||
RulesCount: 100,
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
|
@ -19,12 +20,6 @@ import (
|
|||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
// HTTP scheme constants.
|
||||
const (
|
||||
schemeHTTP = "http"
|
||||
schemeHTTPS = "https"
|
||||
)
|
||||
|
||||
const (
|
||||
// readTimeout is the maximum duration for reading the entire request,
|
||||
// including the body.
|
||||
|
@ -166,7 +161,7 @@ func (web *Web) Start() {
|
|||
|
||||
// this loop is used as an ability to change listening host and/or port
|
||||
for !web.httpsServer.shutdown {
|
||||
printHTTPAddresses(schemeHTTP)
|
||||
printHTTPAddresses(aghhttp.SchemeHTTP)
|
||||
errs := make(chan error, 2)
|
||||
|
||||
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
|
||||
|
@ -286,7 +281,7 @@ func (web *Web) tlsServerLoop() {
|
|||
WriteTimeout: web.conf.WriteTimeout,
|
||||
}
|
||||
|
||||
printHTTPAddresses(schemeHTTPS)
|
||||
printHTTPAddresses(aghhttp.SchemeHTTPS)
|
||||
err := web.httpsServer.server.ListenAndServeTLS("", "")
|
||||
if err != http.ErrServerClosed {
|
||||
cleanupAlways()
|
||||
|
|
Loading…
Reference in New Issue