Pull request: 4871 imp filtering
Merge in DNS/adguard-home from 4871-imp-filtering to master
Closes #4871.
Squashed commit of the following:
commit 618e7c558447703c114332708c94ef1b34362cf9
Merge: 41ff8ab7 11e4f091
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Sep 22 19:27:08 2022 +0300
Merge branch 'master' into 4871-imp-filtering
commit 41ff8ab755a87170e7334dedcae00f01dcca238a
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Sep 22 19:26:11 2022 +0300
filtering: imp code, log
commit e4ae1d1788406ffd7ef0fcc6df896a22b0c2db37
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Sep 22 14:11:07 2022 +0300
filtering: move handlers into single func
commit f7a340b4c10980f512ae935a156f02b0133a1627
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Sep 21 19:21:09 2022 +0300
all: imp code
commit e064bf4d3de0283e4bda2aaf5b9822bb8a08f4a6
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Sep 20 20:12:16 2022 +0300
all: imp name
commit e7eda3905762f0821e1be1ac3cf77e0ecbedeff4
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Sep 20 17:51:23 2022 +0300
all: finally get rid of filtering
commit 188550d873e625cc2951583bb3a2eaad036745f5
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Sep 20 17:36:03 2022 +0300
filtering: merge refresh
commit e54ed9c7952b17e66b790c835269b28fbc26f9ca
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Sep 20 17:16:23 2022 +0300
filtering: merge filters
commit 32da31b754a319487d5f9d5e81e607d349b90180
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Sep 20 14:48:13 2022 +0300
filtering: imp docs
commit 43b0cafa7a27bb9b620c2ba50ccdddcf32cfcecc
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Sep 20 14:38:04 2022 +0300
all: imp code
commit 253a2ea6c92815d364546e34d631e406dd604644
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Mon Sep 19 20:43:15 2022 +0300
filtering: rm important flag
commit 1b87f08f946389d410f13412c7e486290d5e752d
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Mon Sep 19 17:05:40 2022 +0300
all: move filtering to the package
commit daa13499f1dd4fe475c4b75769e34f1eb0915bdf
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Mon Sep 19 15:13:55 2022 +0300
all: finish merging
commit d6db75eb2e1f23528e9200ea51507eb793eefa3c
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Fri Sep 16 18:18:14 2022 +0300
all: continue merging
commit 45b4c484deb7198a469aa18d719bb9dbe81e5d22
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Sep 14 15:44:22 2022 +0300
all: merge filtering types
This commit is contained in:
parent
11e4f09165
commit
47c9c946a3
|
@ -9,6 +9,12 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// HTTP scheme constants.
|
||||||
|
const (
|
||||||
|
SchemeHTTP = "http"
|
||||||
|
SchemeHTTPS = "https"
|
||||||
|
)
|
||||||
|
|
||||||
// RegisterFunc is the function that sets the handler to handle the URL for the
|
// RegisterFunc is the function that sets the handler to handle the URL for the
|
||||||
// method.
|
// method.
|
||||||
//
|
//
|
||||||
|
|
|
@ -67,10 +67,11 @@ func createTestServer(
|
||||||
ID: 0, Data: []byte(rules),
|
ID: 0, Data: []byte(rules),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
f := filtering.New(filterConf, filters)
|
f, err := filtering.New(filterConf, filters)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
f.SetEnabled(true)
|
f.SetEnabled(true)
|
||||||
|
|
||||||
var err error
|
|
||||||
s, err = NewServer(DNSCreateParams{
|
s, err = NewServer(DNSCreateParams{
|
||||||
DHCPServer: testDHCP,
|
DHCPServer: testDHCP,
|
||||||
DNSFilter: f,
|
DNSFilter: f,
|
||||||
|
@ -774,7 +775,9 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||||
Data: []byte(rules),
|
Data: []byte(rules),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
f := filtering.New(&filtering.Config{}, filters)
|
f, err := filtering.New(&filtering.Config{}, filters)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
s, err := NewServer(DNSCreateParams{
|
s, err := NewServer(DNSCreateParams{
|
||||||
DHCPServer: testDHCP,
|
DHCPServer: testDHCP,
|
||||||
DNSFilter: f,
|
DNSFilter: f,
|
||||||
|
@ -906,7 +909,9 @@ func TestRewrite(t *testing.T) {
|
||||||
Type: dns.TypeCNAME,
|
Type: dns.TypeCNAME,
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
f := filtering.New(c, nil)
|
f, err := filtering.New(c, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
f.SetEnabled(true)
|
f.SetEnabled(true)
|
||||||
|
|
||||||
s, err := NewServer(DNSCreateParams{
|
s, err := NewServer(DNSCreateParams{
|
||||||
|
@ -1021,19 +1026,14 @@ var testDHCP = &dhcpd.MockInterface{
|
||||||
OnWriteDiskConfig: func(c *dhcpd.ServerConfig) { panic("not implemented") },
|
OnWriteDiskConfig: func(c *dhcpd.ServerConfig) { panic("not implemented") },
|
||||||
}
|
}
|
||||||
|
|
||||||
// func (*testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
|
|
||||||
// return []*dhcpd.Lease{{
|
|
||||||
// IP: net.IP{192, 168, 12, 34},
|
|
||||||
// HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
|
||||||
// Hostname: "myhost",
|
|
||||||
// }}
|
|
||||||
// }
|
|
||||||
|
|
||||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||||
const localDomain = "lan"
|
const localDomain = "lan"
|
||||||
|
|
||||||
|
flt, err := filtering.New(&filtering.Config{}, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
s, err := NewServer(DNSCreateParams{
|
s, err := NewServer(DNSCreateParams{
|
||||||
DNSFilter: filtering.New(&filtering.Config{}, nil),
|
DNSFilter: flt,
|
||||||
DHCPServer: testDHCP,
|
DHCPServer: testDHCP,
|
||||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||||
LocalDomain: localDomain,
|
LocalDomain: localDomain,
|
||||||
|
@ -1100,9 +1100,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||||
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
|
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
|
||||||
})
|
})
|
||||||
|
|
||||||
flt := filtering.New(&filtering.Config{
|
flt, err := filtering.New(&filtering.Config{
|
||||||
EtcHosts: hc,
|
EtcHosts: hc,
|
||||||
}, nil)
|
}, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
flt.SetEnabled(true)
|
flt.SetEnabled(true)
|
||||||
|
|
||||||
var s *Server
|
var s *Server
|
||||||
|
|
|
@ -35,7 +35,8 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||||
ID: 0, Data: []byte(rules),
|
ID: 0, Data: []byte(rules),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
f := filtering.New(&filtering.Config{}, filters)
|
f, err := filtering.New(&filtering.Config{}, filters)
|
||||||
|
require.NoError(t, err)
|
||||||
f.SetEnabled(true)
|
f.SetEnabled(true)
|
||||||
|
|
||||||
s, err := NewServer(DNSCreateParams{
|
s, err := NewServer(DNSCreateParams{
|
||||||
|
|
|
@ -421,31 +421,34 @@ func initBlockedServices() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// BlockedSvcKnown - return TRUE if a blocked service name is known
|
// BlockedSvcKnown - return TRUE if a blocked service name is known
|
||||||
func BlockedSvcKnown(s string) bool {
|
func BlockedSvcKnown(s string) (ok bool) {
|
||||||
_, ok := serviceRules[s]
|
_, ok = serviceRules[s]
|
||||||
|
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyBlockedServices - set blocked services settings for this DNS request
|
// ApplyBlockedServices - set blocked services settings for this DNS request
|
||||||
func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string, global bool) {
|
func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string) {
|
||||||
setts.ServicesRules = []ServiceEntry{}
|
setts.ServicesRules = []ServiceEntry{}
|
||||||
if global {
|
if list == nil {
|
||||||
d.confLock.RLock()
|
d.confLock.RLock()
|
||||||
defer d.confLock.RUnlock()
|
defer d.confLock.RUnlock()
|
||||||
|
|
||||||
list = d.Config.BlockedServices
|
list = d.Config.BlockedServices
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, name := range list {
|
for _, name := range list {
|
||||||
rules, ok := serviceRules[name]
|
rules, ok := serviceRules[name]
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Error("unknown service name: %s", name)
|
log.Error("unknown service name: %s", name)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
s := ServiceEntry{}
|
setts.ServicesRules = append(setts.ServicesRules, ServiceEntry{
|
||||||
s.Name = name
|
Name: name,
|
||||||
s.Rules = rules
|
Rules: rules,
|
||||||
setts.ServicesRules = append(setts.ServicesRules, s)
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -490,10 +493,3 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
|
||||||
|
|
||||||
d.ConfigModified()
|
d.ConfigModified()
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerBlockedServicesHandlers - register HTTP handlers
|
|
||||||
func (d *DNSFilter) registerBlockedServicesHandlers() {
|
|
||||||
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices)
|
|
||||||
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
|
|
||||||
d.Config.HTTPRegister(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package home
|
package filtering
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -34,7 +34,7 @@ func validateFilterURL(urlStr string) (err error) {
|
||||||
return fmt.Errorf("checking filter url: %w", err)
|
return fmt.Errorf("checking filter url: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s := url.Scheme; s != schemeHTTP && s != schemeHTTPS {
|
if s := url.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS {
|
||||||
return fmt.Errorf("checking filter url: invalid scheme %q", s)
|
return fmt.Errorf("checking filter url: invalid scheme %q", s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ type filterAddJSON struct {
|
||||||
Whitelist bool `json:"whitelist"`
|
Whitelist bool `json:"whitelist"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||||
fj := filterAddJSON{}
|
fj := filterAddJSON{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&fj)
|
err := json.NewDecoder(r.Body).Decode(&fj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -65,14 +65,14 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for duplicates
|
// Check for duplicates
|
||||||
if filterExists(fj.URL) {
|
if d.filterExists(fj.URL) {
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
|
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set necessary properties
|
// Set necessary properties
|
||||||
filt := filter{
|
filt := FilterYAML{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
URL: fj.URL,
|
URL: fj.URL,
|
||||||
Name: fj.Name,
|
Name: fj.Name,
|
||||||
|
@ -81,7 +81,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||||
filt.ID = assignUniqueFilterID()
|
filt.ID = assignUniqueFilterID()
|
||||||
|
|
||||||
// Download the filter contents
|
// Download the filter contents
|
||||||
ok, err := f.update(&filt)
|
ok, err := d.update(&filt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.Error(
|
aghhttp.Error(
|
||||||
r,
|
r,
|
||||||
|
@ -109,14 +109,14 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||||
|
|
||||||
// URL is assumed valid so append it to filters, update config, write new
|
// URL is assumed valid so append it to filters, update config, write new
|
||||||
// file and reload it to engines.
|
// file and reload it to engines.
|
||||||
if !filterAdd(filt) {
|
if !d.filterAdd(filt) {
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
|
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
onConfigModified()
|
d.ConfigModified()
|
||||||
enableFilters(true)
|
d.EnableFilters(true)
|
||||||
|
|
||||||
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
|
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -124,7 +124,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
|
func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
|
||||||
type request struct {
|
type request struct {
|
||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
Whitelist bool `json:"whitelist"`
|
Whitelist bool `json:"whitelist"`
|
||||||
|
@ -138,23 +138,23 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Lock()
|
d.filtersMu.Lock()
|
||||||
filters := &config.Filters
|
filters := &d.Filters
|
||||||
if req.Whitelist {
|
if req.Whitelist {
|
||||||
filters = &config.WhitelistFilters
|
filters = &d.WhitelistFilters
|
||||||
}
|
}
|
||||||
|
|
||||||
var deleted filter
|
var deleted FilterYAML
|
||||||
var newFilters []filter
|
var newFilters []FilterYAML
|
||||||
for _, f := range *filters {
|
for _, flt := range *filters {
|
||||||
if f.URL != req.URL {
|
if flt.URL != req.URL {
|
||||||
newFilters = append(newFilters, f)
|
newFilters = append(newFilters, flt)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
deleted = f
|
deleted = flt
|
||||||
path := f.Path()
|
path := flt.Path(d.DataDir)
|
||||||
err = os.Rename(path, path+".old")
|
err = os.Rename(path, path+".old")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("deleting filter %q: %s", path, err)
|
log.Error("deleting filter %q: %s", path, err)
|
||||||
|
@ -162,10 +162,10 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
||||||
}
|
}
|
||||||
|
|
||||||
*filters = newFilters
|
*filters = newFilters
|
||||||
config.Unlock()
|
d.filtersMu.Unlock()
|
||||||
|
|
||||||
onConfigModified()
|
d.ConfigModified()
|
||||||
enableFilters(true)
|
d.EnableFilters(true)
|
||||||
|
|
||||||
// NOTE: The old files "filter.txt.old" aren't deleted. It's not really
|
// NOTE: The old files "filter.txt.old" aren't deleted. It's not really
|
||||||
// necessary, but will require the additional complicated code to run
|
// necessary, but will require the additional complicated code to run
|
||||||
|
@ -191,55 +191,51 @@ type filterURLReq struct {
|
||||||
Whitelist bool `json:"whitelist"`
|
Whitelist bool `json:"whitelist"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
|
func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
|
||||||
fj := filterURLReq{}
|
fj := filterURLReq{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&fj)
|
err := json.NewDecoder(r.Body).Decode(&fj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if fj.Data == nil {
|
if fj.Data == nil {
|
||||||
err = errors.Error("data cannot be null")
|
aghhttp.Error(r, w, http.StatusBadRequest, "%s", errors.Error("data is absent"))
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateFilterURL(fj.Data.URL)
|
err = validateFilterURL(fj.Data.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("invalid url: %s", err)
|
aghhttp.Error(r, w, http.StatusBadRequest, "invalid url: %s", err)
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
filt := filter{
|
filt := FilterYAML{
|
||||||
Enabled: fj.Data.Enabled,
|
Enabled: fj.Data.Enabled,
|
||||||
Name: fj.Data.Name,
|
Name: fj.Data.Name,
|
||||||
URL: fj.Data.URL,
|
URL: fj.Data.URL,
|
||||||
}
|
}
|
||||||
status := f.filterSetProperties(fj.URL, filt, fj.Whitelist)
|
status := d.filterSetProperties(fj.URL, filt, fj.Whitelist)
|
||||||
if (status & statusFound) == 0 {
|
if (status & statusFound) == 0 {
|
||||||
http.Error(w, "URL doesn't exist", http.StatusBadRequest)
|
aghhttp.Error(r, w, http.StatusBadRequest, "URL doesn't exist")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if (status & statusURLExists) != 0 {
|
if (status & statusURLExists) != 0 {
|
||||||
http.Error(w, "URL already exists", http.StatusBadRequest)
|
aghhttp.Error(r, w, http.StatusBadRequest, "URL already exists")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
onConfigModified()
|
d.ConfigModified()
|
||||||
|
|
||||||
restart := (status & statusEnabledChanged) != 0
|
restart := (status & statusEnabledChanged) != 0
|
||||||
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
|
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
|
||||||
// download new filter and apply its rules
|
// download new filter and apply its rules.
|
||||||
flags := filterRefreshBlocklists
|
nUpdated := d.refreshFilters(!fj.Whitelist, fj.Whitelist, false)
|
||||||
if fj.Whitelist {
|
|
||||||
flags = filterRefreshAllowlists
|
|
||||||
}
|
|
||||||
nUpdated, _ := f.refreshFilters(flags, true)
|
|
||||||
// if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
|
// if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
|
||||||
// if not - we restart the filtering ourselves
|
// if not - we restart the filtering ourselves
|
||||||
restart = false
|
restart = false
|
||||||
|
@ -249,11 +245,11 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
||||||
}
|
}
|
||||||
|
|
||||||
if restart {
|
if restart {
|
||||||
enableFilters(true)
|
d.EnableFilters(true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
|
func (d *DNSFilter) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
|
||||||
// This use of ReadAll is safe, because request's body is now limited.
|
// This use of ReadAll is safe, because request's body is now limited.
|
||||||
body, err := io.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -262,12 +258,12 @@ func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Reque
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
config.UserRules = strings.Split(string(body), "\n")
|
d.UserRules = strings.Split(string(body), "\n")
|
||||||
onConfigModified()
|
d.ConfigModified()
|
||||||
enableFilters(true)
|
d.EnableFilters(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
|
func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
|
||||||
type Req struct {
|
type Req struct {
|
||||||
White bool `json:"whitelist"`
|
White bool `json:"whitelist"`
|
||||||
}
|
}
|
||||||
|
@ -285,35 +281,27 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
flags := filterRefreshBlocklists
|
var ok bool
|
||||||
if req.White {
|
resp.Updated, _, ok = d.tryRefreshFilters(!req.White, req.White, true)
|
||||||
flags = filterRefreshAllowlists
|
if !ok {
|
||||||
}
|
aghhttp.Error(
|
||||||
func() {
|
r,
|
||||||
// Temporarily unlock the Context.controlLock because the
|
w,
|
||||||
// f.refreshFilters waits for it to be unlocked but it's
|
http.StatusInternalServerError,
|
||||||
// actually locked in ensure wrapper.
|
"filters update procedure is already running",
|
||||||
//
|
)
|
||||||
// TODO(e.burkov): Reconsider this messy syncing process.
|
|
||||||
Context.controlLock.Unlock()
|
|
||||||
defer Context.controlLock.Lock()
|
|
||||||
|
|
||||||
resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false)
|
|
||||||
}()
|
|
||||||
if err != nil {
|
|
||||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
js, err := json.Marshal(resp)
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
err = json.NewEncoder(w).Encode(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write(js)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type filterJSON struct {
|
type filterJSON struct {
|
||||||
|
@ -333,7 +321,7 @@ type filteringConfig struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterToJSON(f filter) filterJSON {
|
func filterToJSON(f FilterYAML) filterJSON {
|
||||||
fj := filterJSON{
|
fj := filterJSON{
|
||||||
ID: f.ID,
|
ID: f.ID,
|
||||||
Enabled: f.Enabled,
|
Enabled: f.Enabled,
|
||||||
|
@ -350,21 +338,21 @@ func filterToJSON(f filter) filterJSON {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get filtering configuration
|
// Get filtering configuration
|
||||||
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
|
func (d *DNSFilter) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
resp := filteringConfig{}
|
resp := filteringConfig{}
|
||||||
config.RLock()
|
d.filtersMu.RLock()
|
||||||
resp.Enabled = config.DNS.FilteringEnabled
|
resp.Enabled = d.FilteringEnabled
|
||||||
resp.Interval = config.DNS.FiltersUpdateIntervalHours
|
resp.Interval = d.FiltersUpdateIntervalHours
|
||||||
for _, f := range config.Filters {
|
for _, f := range d.Filters {
|
||||||
fj := filterToJSON(f)
|
fj := filterToJSON(f)
|
||||||
resp.Filters = append(resp.Filters, fj)
|
resp.Filters = append(resp.Filters, fj)
|
||||||
}
|
}
|
||||||
for _, f := range config.WhitelistFilters {
|
for _, f := range d.WhitelistFilters {
|
||||||
fj := filterToJSON(f)
|
fj := filterToJSON(f)
|
||||||
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
|
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
|
||||||
}
|
}
|
||||||
resp.UserRules = config.UserRules
|
resp.UserRules = d.UserRules
|
||||||
config.RUnlock()
|
d.filtersMu.RUnlock()
|
||||||
|
|
||||||
jsonVal, err := json.Marshal(resp)
|
jsonVal, err := json.Marshal(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -380,7 +368,7 @@ func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set filtering configuration
|
// Set filtering configuration
|
||||||
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
|
func (d *DNSFilter) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
req := filteringConfig{}
|
req := filteringConfig{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&req)
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -389,22 +377,22 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !checkFiltersUpdateIntervalHours(req.Interval) {
|
if !ValidateUpdateIvl(req.Interval) {
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
|
aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
config.Lock()
|
d.filtersMu.Lock()
|
||||||
defer config.Unlock()
|
defer d.filtersMu.Unlock()
|
||||||
|
|
||||||
config.DNS.FilteringEnabled = req.Enabled
|
d.FilteringEnabled = req.Enabled
|
||||||
config.DNS.FiltersUpdateIntervalHours = req.Interval
|
d.FiltersUpdateIntervalHours = req.Interval
|
||||||
}()
|
}()
|
||||||
|
|
||||||
onConfigModified()
|
d.ConfigModified()
|
||||||
enableFilters(true)
|
d.EnableFilters(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
type checkHostRespRule struct {
|
type checkHostRespRule struct {
|
||||||
|
@ -435,15 +423,15 @@ type checkHostResp struct {
|
||||||
FilterID int64 `json:"filter_id"`
|
FilterID int64 `json:"filter_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||||
q := r.URL.Query()
|
host := r.URL.Query().Get("name")
|
||||||
host := q.Get("name")
|
|
||||||
|
|
||||||
setts := Context.dnsFilter.GetConfig()
|
setts := d.GetConfig()
|
||||||
setts.FilteringEnabled = true
|
setts.FilteringEnabled = true
|
||||||
setts.ProtectionEnabled = true
|
setts.ProtectionEnabled = true
|
||||||
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
|
|
||||||
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
|
d.ApplyBlockedServices(&setts, nil)
|
||||||
|
result, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.Error(
|
aghhttp.Error(
|
||||||
r,
|
r,
|
||||||
|
@ -457,18 +445,20 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := checkHostResp{}
|
rulesLen := len(result.Rules)
|
||||||
resp.Reason = result.Reason.String()
|
resp := checkHostResp{
|
||||||
resp.SvcName = result.ServiceName
|
Reason: result.Reason.String(),
|
||||||
resp.CanonName = result.CanonName
|
SvcName: result.ServiceName,
|
||||||
resp.IPList = result.IPList
|
CanonName: result.CanonName,
|
||||||
|
IPList: result.IPList,
|
||||||
|
Rules: make([]*checkHostRespRule, len(result.Rules)),
|
||||||
|
}
|
||||||
|
|
||||||
if len(result.Rules) > 0 {
|
if rulesLen > 0 {
|
||||||
resp.FilterID = result.Rules[0].FilterListID
|
resp.FilterID = result.Rules[0].FilterListID
|
||||||
resp.Rule = result.Rules[0].Text
|
resp.Rule = result.Rules[0].Text
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Rules = make([]*checkHostRespRule, len(result.Rules))
|
|
||||||
for i, r := range result.Rules {
|
for i, r := range result.Rules {
|
||||||
resp.Rules[i] = &checkHostRespRule{
|
resp.Rules[i] = &checkHostRespRule{
|
||||||
FilterListID: r.FilterListID,
|
FilterListID: r.FilterListID,
|
||||||
|
@ -476,28 +466,51 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
js, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
_, _ = w.Write(js)
|
err = json.NewEncoder(w).Encode(resp)
|
||||||
|
if err != nil {
|
||||||
|
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding response: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterFilteringHandlers - register handlers
|
// RegisterFilteringHandlers - register handlers
|
||||||
func (f *Filtering) RegisterFilteringHandlers() {
|
func (d *DNSFilter) RegisterFilteringHandlers() {
|
||||||
httpRegister(http.MethodGet, "/control/filtering/status", f.handleFilteringStatus)
|
registerHTTP := d.HTTPRegister
|
||||||
httpRegister(http.MethodPost, "/control/filtering/config", f.handleFilteringConfig)
|
if registerHTTP == nil {
|
||||||
httpRegister(http.MethodPost, "/control/filtering/add_url", f.handleFilteringAddURL)
|
return
|
||||||
httpRegister(http.MethodPost, "/control/filtering/remove_url", f.handleFilteringRemoveURL)
|
}
|
||||||
httpRegister(http.MethodPost, "/control/filtering/set_url", f.handleFilteringSetURL)
|
|
||||||
httpRegister(http.MethodPost, "/control/filtering/refresh", f.handleFilteringRefresh)
|
registerHTTP(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
|
||||||
httpRegister(http.MethodPost, "/control/filtering/set_rules", f.handleFilteringSetRules)
|
registerHTTP(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
|
||||||
httpRegister(http.MethodGet, "/control/filtering/check_host", f.handleCheckHost)
|
registerHTTP(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
|
||||||
|
|
||||||
|
registerHTTP(http.MethodPost, "/control/parental/enable", d.handleParentalEnable)
|
||||||
|
registerHTTP(http.MethodPost, "/control/parental/disable", d.handleParentalDisable)
|
||||||
|
registerHTTP(http.MethodGet, "/control/parental/status", d.handleParentalStatus)
|
||||||
|
|
||||||
|
registerHTTP(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
|
||||||
|
registerHTTP(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
|
||||||
|
registerHTTP(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
|
||||||
|
|
||||||
|
registerHTTP(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
|
||||||
|
registerHTTP(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
|
||||||
|
registerHTTP(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
|
||||||
|
|
||||||
|
registerHTTP(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices)
|
||||||
|
registerHTTP(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
|
||||||
|
registerHTTP(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
|
||||||
|
|
||||||
|
registerHTTP(http.MethodGet, "/control/filtering/status", d.handleFilteringStatus)
|
||||||
|
registerHTTP(http.MethodPost, "/control/filtering/config", d.handleFilteringConfig)
|
||||||
|
registerHTTP(http.MethodPost, "/control/filtering/add_url", d.handleFilteringAddURL)
|
||||||
|
registerHTTP(http.MethodPost, "/control/filtering/remove_url", d.handleFilteringRemoveURL)
|
||||||
|
registerHTTP(http.MethodPost, "/control/filtering/set_url", d.handleFilteringSetURL)
|
||||||
|
registerHTTP(http.MethodPost, "/control/filtering/refresh", d.handleFilteringRefresh)
|
||||||
|
registerHTTP(http.MethodPost, "/control/filtering/set_rules", d.handleFilteringSetRules)
|
||||||
|
registerHTTP(http.MethodGet, "/control/filtering/check_host", d.handleCheckHost)
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkFiltersUpdateIntervalHours(i uint32) bool {
|
// ValidateUpdateIvl returns false if i is not a valid filters update interval.
|
||||||
|
func ValidateUpdateIvl(i uint32) bool {
|
||||||
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
|
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
|
||||||
}
|
}
|
|
@ -49,7 +49,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||||
|1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot.
|
|1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot.
|
||||||
`
|
`
|
||||||
|
|
||||||
f := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
|
f, _ := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
|
||||||
setts := &Settings{
|
setts := &Settings{
|
||||||
FilteringEnabled: true,
|
FilteringEnabled: true,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package home
|
package filtering
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
@ -8,63 +8,29 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
var nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
|
// filterDir is the subdirectory of a data directory to store downloaded
|
||||||
|
// filters.
|
||||||
|
const filterDir = "filters"
|
||||||
|
|
||||||
// Filtering - module object
|
// nextFilterID is a way to seed a unique ID generation.
|
||||||
type Filtering struct {
|
//
|
||||||
// conf FilteringConf
|
// TODO(e.burkov): Use more deterministic approach.
|
||||||
refreshStatus uint32 // 0:none; 1:in progress
|
var nextFilterID = time.Now().Unix()
|
||||||
refreshLock sync.Mutex
|
|
||||||
filterTitleRegexp *regexp.Regexp
|
|
||||||
}
|
|
||||||
|
|
||||||
// Init - initialize the module
|
// FilterYAML respresents a filter list in the configuration file.
|
||||||
func (f *Filtering) Init() {
|
//
|
||||||
f.filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
|
// TODO(e.burkov): Investigate if the field oredering is important.
|
||||||
_ = os.MkdirAll(filepath.Join(Context.getDataDir(), filterDir), 0o755)
|
type FilterYAML struct {
|
||||||
f.loadFilters(config.Filters)
|
|
||||||
f.loadFilters(config.WhitelistFilters)
|
|
||||||
deduplicateFilters()
|
|
||||||
updateUniqueFilterID(config.Filters)
|
|
||||||
updateUniqueFilterID(config.WhitelistFilters)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start - start the module
|
|
||||||
func (f *Filtering) Start() {
|
|
||||||
f.RegisterFilteringHandlers()
|
|
||||||
|
|
||||||
// Here we should start updating filters,
|
|
||||||
// but currently we can't wake up the periodic task to do so.
|
|
||||||
// So for now we just start this periodic task from here.
|
|
||||||
go f.periodicallyRefreshFilters()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close - close the module
|
|
||||||
func (f *Filtering) Close() {
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultFilters() []filter {
|
|
||||||
return []filter{
|
|
||||||
{Filter: filtering.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard DNS filter"},
|
|
||||||
{Filter: filtering.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway Default Blocklist"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// field ordering is important -- yaml fields will mirror ordering from here
|
|
||||||
type filter struct {
|
|
||||||
Enabled bool
|
Enabled bool
|
||||||
URL string // URL or a file path
|
URL string // URL or a file path
|
||||||
Name string `yaml:"name"`
|
Name string `yaml:"name"`
|
||||||
|
@ -73,44 +39,58 @@ type filter struct {
|
||||||
checksum uint32 // checksum of the file data
|
checksum uint32 // checksum of the file data
|
||||||
white bool
|
white bool
|
||||||
|
|
||||||
filtering.Filter `yaml:",inline"`
|
Filter `yaml:",inline"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear filter rules
|
||||||
|
func (filter *FilterYAML) unload() {
|
||||||
|
filter.RulesCount = 0
|
||||||
|
filter.checksum = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Path to the filter contents
|
||||||
|
func (filter *FilterYAML) Path(dataDir string) string {
|
||||||
|
return filepath.Join(dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
statusFound = 1
|
statusFound = 1 << iota
|
||||||
statusEnabledChanged = 2
|
statusEnabledChanged
|
||||||
statusURLChanged = 4
|
statusURLChanged
|
||||||
statusURLExists = 8
|
statusURLExists
|
||||||
statusUpdateRequired = 0x10
|
statusUpdateRequired
|
||||||
)
|
)
|
||||||
|
|
||||||
// Update properties for a filter specified by its URL
|
// Update properties for a filter specified by its URL
|
||||||
// Return status* flags.
|
// Return status* flags.
|
||||||
func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool) int {
|
func (d *DNSFilter) filterSetProperties(url string, newf FilterYAML, whitelist bool) int {
|
||||||
r := 0
|
r := 0
|
||||||
config.Lock()
|
d.filtersMu.Lock()
|
||||||
defer config.Unlock()
|
defer d.filtersMu.Unlock()
|
||||||
|
|
||||||
filters := &config.Filters
|
filters := d.Filters
|
||||||
if whitelist {
|
if whitelist {
|
||||||
filters = &config.WhitelistFilters
|
filters = d.WhitelistFilters
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range *filters {
|
i := slices.IndexFunc(filters, func(filt FilterYAML) bool {
|
||||||
filt := &(*filters)[i]
|
return filt.URL == url
|
||||||
if filt.URL != url {
|
})
|
||||||
continue
|
if i == -1 {
|
||||||
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("filter: set properties: %s: {%s %s %v}",
|
filt := &filters[i]
|
||||||
filt.URL, newf.Name, newf.URL, newf.Enabled)
|
|
||||||
|
log.Debug("filter: set properties: %s: {%s %s %v}", filt.URL, newf.Name, newf.URL, newf.Enabled)
|
||||||
filt.Name = newf.Name
|
filt.Name = newf.Name
|
||||||
|
|
||||||
if filt.URL != newf.URL {
|
if filt.URL != newf.URL {
|
||||||
r |= statusURLChanged | statusUpdateRequired
|
r |= statusURLChanged | statusUpdateRequired
|
||||||
if filterExistsNoLock(newf.URL) {
|
if d.filterExistsNoLock(newf.URL) {
|
||||||
return statusURLExists
|
return statusURLExists
|
||||||
}
|
}
|
||||||
|
|
||||||
filt.URL = newf.URL
|
filt.URL = newf.URL
|
||||||
filt.unload()
|
filt.unload()
|
||||||
filt.LastUpdated = time.Time{}
|
filt.LastUpdated = time.Time{}
|
||||||
|
@ -123,10 +103,13 @@ func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool)
|
||||||
filt.Enabled = newf.Enabled
|
filt.Enabled = newf.Enabled
|
||||||
if filt.Enabled {
|
if filt.Enabled {
|
||||||
if (r & statusURLChanged) == 0 {
|
if (r & statusURLChanged) == 0 {
|
||||||
e := f.load(filt)
|
err := d.load(filt)
|
||||||
if e != nil {
|
if err != nil {
|
||||||
// This isn't a fatal error,
|
// TODO(e.burkov): It seems the error is only returned when
|
||||||
// because it may occur when someone removes the file from disk.
|
// the file exists and couldn't be open. Investigate and
|
||||||
|
// improve.
|
||||||
|
log.Error("loading filter %d: %s", filt.ID, err)
|
||||||
|
|
||||||
filt.LastUpdated = time.Time{}
|
filt.LastUpdated = time.Time{}
|
||||||
filt.checksum = 0
|
filt.checksum = 0
|
||||||
filt.RulesCount = 0
|
filt.RulesCount = 0
|
||||||
|
@ -139,25 +122,25 @@ func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r | statusFound
|
return r | statusFound
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return TRUE if a filter with this URL exists
|
// Return TRUE if a filter with this URL exists
|
||||||
func filterExists(url string) bool {
|
func (d *DNSFilter) filterExists(url string) bool {
|
||||||
config.RLock()
|
d.filtersMu.RLock()
|
||||||
r := filterExistsNoLock(url)
|
defer d.filtersMu.RUnlock()
|
||||||
config.RUnlock()
|
|
||||||
|
r := d.filterExistsNoLock(url)
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterExistsNoLock(url string) bool {
|
func (d *DNSFilter) filterExistsNoLock(url string) bool {
|
||||||
for _, f := range config.Filters {
|
for _, f := range d.Filters {
|
||||||
if f.URL == url {
|
if f.URL == url {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, f := range config.WhitelistFilters {
|
for _, f := range d.WhitelistFilters {
|
||||||
if f.URL == url {
|
if f.URL == url {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -167,26 +150,26 @@ func filterExistsNoLock(url string) bool {
|
||||||
|
|
||||||
// Add a filter
|
// Add a filter
|
||||||
// Return FALSE if a filter with this URL exists
|
// Return FALSE if a filter with this URL exists
|
||||||
func filterAdd(f filter) bool {
|
func (d *DNSFilter) filterAdd(flt FilterYAML) bool {
|
||||||
config.Lock()
|
d.filtersMu.Lock()
|
||||||
defer config.Unlock()
|
defer d.filtersMu.Unlock()
|
||||||
|
|
||||||
// Check for duplicates
|
// Check for duplicates
|
||||||
if filterExistsNoLock(f.URL) {
|
if d.filterExistsNoLock(flt.URL) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.white {
|
if flt.white {
|
||||||
config.WhitelistFilters = append(config.WhitelistFilters, f)
|
d.WhitelistFilters = append(d.WhitelistFilters, flt)
|
||||||
} else {
|
} else {
|
||||||
config.Filters = append(config.Filters, f)
|
d.Filters = append(d.Filters, flt)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load filters from the disk
|
// Load filters from the disk
|
||||||
// And if any filter has zero ID, assign a new one
|
// And if any filter has zero ID, assign a new one
|
||||||
func (f *Filtering) loadFilters(array []filter) {
|
func (d *DNSFilter) loadFilters(array []FilterYAML) {
|
||||||
for i := range array {
|
for i := range array {
|
||||||
filter := &array[i] // otherwise we're operating on a copy
|
filter := &array[i] // otherwise we're operating on a copy
|
||||||
if filter.ID == 0 {
|
if filter.ID == 0 {
|
||||||
|
@ -198,32 +181,30 @@ func (f *Filtering) loadFilters(array []filter) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := f.load(filter)
|
err := d.load(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Couldn't load filter %d contents due to %s", filter.ID, err)
|
log.Error("Couldn't load filter %d contents due to %s", filter.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func deduplicateFilters() {
|
func deduplicateFilters(filters []FilterYAML) (deduplicated []FilterYAML) {
|
||||||
// Deduplicate filters
|
urls := stringutil.NewSet()
|
||||||
i := 0 // output index, used for deletion later
|
lastIdx := 0
|
||||||
urls := map[string]bool{}
|
|
||||||
for _, filter := range config.Filters {
|
for _, filter := range filters {
|
||||||
if _, ok := urls[filter.URL]; !ok {
|
if !urls.Has(filter.URL) {
|
||||||
// we didn't see it before, keep it
|
urls.Add(filter.URL)
|
||||||
urls[filter.URL] = true // remember the URL
|
filters[lastIdx] = filter
|
||||||
config.Filters[i] = filter
|
lastIdx++
|
||||||
i++
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// all entries we want to keep are at front, delete the rest
|
return filters[:lastIdx]
|
||||||
config.Filters = config.Filters[:i]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the next filter ID to max(filter.ID) + 1
|
// Set the next filter ID to max(filter.ID) + 1
|
||||||
func updateUniqueFilterID(filters []filter) {
|
func updateUniqueFilterID(filters []FilterYAML) {
|
||||||
for _, filter := range filters {
|
for _, filter := range filters {
|
||||||
if nextFilterID < filter.ID {
|
if nextFilterID < filter.ID {
|
||||||
nextFilterID = filter.ID + 1
|
nextFilterID = filter.ID + 1
|
||||||
|
@ -238,22 +219,19 @@ func assignUniqueFilterID() int64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sets up a timer that will be checking for filters updates periodically
|
// Sets up a timer that will be checking for filters updates periodically
|
||||||
func (f *Filtering) periodicallyRefreshFilters() {
|
func (d *DNSFilter) periodicallyRefreshFilters() {
|
||||||
const maxInterval = 1 * 60 * 60
|
const maxInterval = 1 * 60 * 60
|
||||||
intval := 5 // use a dynamically increasing time interval
|
intval := 5 // use a dynamically increasing time interval
|
||||||
for {
|
for {
|
||||||
isNetworkErr := false
|
isNetErr, ok := false, false
|
||||||
if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) {
|
if d.FiltersUpdateIntervalHours != 0 {
|
||||||
f.refreshLock.Lock()
|
_, isNetErr, ok = d.tryRefreshFilters(true, true, false)
|
||||||
_, isNetworkErr = f.refreshFiltersIfNecessary(filterRefreshBlocklists | filterRefreshAllowlists)
|
if ok && !isNetErr {
|
||||||
f.refreshLock.Unlock()
|
|
||||||
f.refreshStatus = 0
|
|
||||||
if !isNetworkErr {
|
|
||||||
intval = maxInterval
|
intval = maxInterval
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isNetworkErr {
|
if isNetErr {
|
||||||
intval *= 2
|
intval *= 2
|
||||||
if intval > maxInterval {
|
if intval > maxInterval {
|
||||||
intval = maxInterval
|
intval = maxInterval
|
||||||
|
@ -264,51 +242,73 @@ func (f *Filtering) periodicallyRefreshFilters() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh filters
|
// tryRefreshFilters is like [refreshFilters], but backs down if the update is
|
||||||
// flags: filterRefresh*
|
// already going on.
|
||||||
// important:
|
|
||||||
//
|
//
|
||||||
// TRUE: ignore the fact that we're currently updating the filters
|
// TODO(e.burkov): Get rid of the concurrency pattern which requires the
|
||||||
func (f *Filtering) refreshFilters(flags int, important bool) (int, error) {
|
// sync.Mutex.TryLock.
|
||||||
set := atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1)
|
func (d *DNSFilter) tryRefreshFilters(block, allow, force bool) (updated int, isNetworkErr, ok bool) {
|
||||||
if !important && !set {
|
if ok = d.refreshLock.TryLock(); !ok {
|
||||||
return 0, fmt.Errorf("filters update procedure is already running")
|
return 0, false, ok
|
||||||
}
|
}
|
||||||
|
defer d.refreshLock.Unlock()
|
||||||
|
|
||||||
f.refreshLock.Lock()
|
updated, isNetworkErr = d.refreshFiltersIntl(block, allow, force)
|
||||||
nUpdated, _ := f.refreshFiltersIfNecessary(flags)
|
|
||||||
f.refreshLock.Unlock()
|
return updated, isNetworkErr, ok
|
||||||
f.refreshStatus = 0
|
|
||||||
return nUpdated, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) {
|
// refreshFilters updates the lists and returns the number of updated ones.
|
||||||
var updateFilters []filter
|
// It's safe for concurrent use, but blocks at least until the previous
|
||||||
|
// refreshing is finished.
|
||||||
|
func (d *DNSFilter) refreshFilters(block, allow, force bool) (updated int) {
|
||||||
|
d.refreshLock.Lock()
|
||||||
|
defer d.refreshLock.Unlock()
|
||||||
|
|
||||||
|
updated, _ = d.refreshFiltersIntl(block, allow, force)
|
||||||
|
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
|
// listsToUpdate returns the slice of filter lists that could be updated.
|
||||||
|
func (d *DNSFilter) listsToUpdate(filters *[]FilterYAML, force bool) (toUpd []FilterYAML) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
d.filtersMu.RLock()
|
||||||
|
defer d.filtersMu.RUnlock()
|
||||||
|
|
||||||
|
for i := range *filters {
|
||||||
|
flt := &(*filters)[i] // otherwise we will be operating on a copy
|
||||||
|
log.Debug("checking list at index %d: %v", i, flt)
|
||||||
|
|
||||||
|
if !flt.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !force {
|
||||||
|
exp := flt.LastUpdated.Add(time.Duration(d.FiltersUpdateIntervalHours) * time.Hour)
|
||||||
|
if now.Before(exp) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
toUpd = append(toUpd, FilterYAML{
|
||||||
|
Filter: Filter{
|
||||||
|
ID: flt.ID,
|
||||||
|
},
|
||||||
|
URL: flt.URL,
|
||||||
|
Name: flt.Name,
|
||||||
|
checksum: flt.checksum,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return toUpd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DNSFilter) refreshFiltersArray(filters *[]FilterYAML, force bool) (int, []FilterYAML, []bool, bool) {
|
||||||
var updateFlags []bool // 'true' if filter data has changed
|
var updateFlags []bool // 'true' if filter data has changed
|
||||||
|
|
||||||
now := time.Now()
|
updateFilters := d.listsToUpdate(filters, force)
|
||||||
config.RLock()
|
|
||||||
for i := range *filters {
|
|
||||||
f := &(*filters)[i] // otherwise we will be operating on a copy
|
|
||||||
|
|
||||||
if !f.Enabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
expireTime := f.LastUpdated.Unix() + int64(config.DNS.FiltersUpdateIntervalHours)*60*60
|
|
||||||
if !force && expireTime > now.Unix() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var uf filter
|
|
||||||
uf.ID = f.ID
|
|
||||||
uf.URL = f.URL
|
|
||||||
uf.Name = f.Name
|
|
||||||
uf.checksum = f.checksum
|
|
||||||
updateFilters = append(updateFilters, uf)
|
|
||||||
}
|
|
||||||
config.RUnlock()
|
|
||||||
|
|
||||||
if len(updateFilters) == 0 {
|
if len(updateFilters) == 0 {
|
||||||
return 0, nil, nil, false
|
return 0, nil, nil, false
|
||||||
}
|
}
|
||||||
|
@ -316,7 +316,7 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
|
||||||
nfail := 0
|
nfail := 0
|
||||||
for i := range updateFilters {
|
for i := range updateFilters {
|
||||||
uf := &updateFilters[i]
|
uf := &updateFilters[i]
|
||||||
updated, err := f.update(uf)
|
updated, err := d.update(uf)
|
||||||
updateFlags = append(updateFlags, updated)
|
updateFlags = append(updateFlags, updated)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
nfail++
|
nfail++
|
||||||
|
@ -334,7 +334,7 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
|
||||||
uf := &updateFilters[i]
|
uf := &updateFilters[i]
|
||||||
updated := updateFlags[i]
|
updated := updateFlags[i]
|
||||||
|
|
||||||
config.Lock()
|
d.filtersMu.Lock()
|
||||||
for k := range *filters {
|
for k := range *filters {
|
||||||
f := &(*filters)[k]
|
f := &(*filters)[k]
|
||||||
if f.ID != uf.ID || f.URL != uf.URL {
|
if f.ID != uf.ID || f.URL != uf.URL {
|
||||||
|
@ -352,20 +352,14 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
|
||||||
f.checksum = uf.checksum
|
f.checksum = uf.checksum
|
||||||
updateCount++
|
updateCount++
|
||||||
}
|
}
|
||||||
config.Unlock()
|
d.filtersMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return updateCount, updateFilters, updateFlags, false
|
return updateCount, updateFilters, updateFlags, false
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
// refreshFiltersIntl checks filters and updates them if necessary. If force is
|
||||||
filterRefreshForce = 1 // ignore last file modification date
|
// true, it ignores the filter.LastUpdated field value.
|
||||||
filterRefreshAllowlists = 2 // update allow-lists
|
|
||||||
filterRefreshBlocklists = 4 // update block-lists
|
|
||||||
)
|
|
||||||
|
|
||||||
// refreshFiltersIfNecessary checks filters and updates them if necessary. If
|
|
||||||
// force is true, it ignores the filter.LastUpdated field value.
|
|
||||||
//
|
//
|
||||||
// Algorithm:
|
// Algorithm:
|
||||||
//
|
//
|
||||||
|
@ -378,53 +372,49 @@ const (
|
||||||
// that this method works only on Unix systems. On Windows, don't pass
|
// that this method works only on Unix systems. On Windows, don't pass
|
||||||
// files to filtering, pass the whole data.
|
// files to filtering, pass the whole data.
|
||||||
//
|
//
|
||||||
// refreshFiltersIfNecessary returns the number of updated filters. It also
|
// refreshFiltersIntl returns the number of updated filters. It also returns
|
||||||
// returns true if there was a network error and nothing could be updated.
|
// true if there was a network error and nothing could be updated.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov, e.burkov): What the hell?
|
// TODO(a.garipov, e.burkov): What the hell?
|
||||||
func (f *Filtering) refreshFiltersIfNecessary(flags int) (int, bool) {
|
func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) {
|
||||||
log.Debug("Filters: updating...")
|
log.Debug("filtering: updating...")
|
||||||
|
|
||||||
updateCount := 0
|
updNum := 0
|
||||||
var updateFilters []filter
|
var lists []FilterYAML
|
||||||
var updateFlags []bool
|
var toUpd []bool
|
||||||
netError := false
|
isNetErr := false
|
||||||
netErrorW := false
|
|
||||||
force := false
|
if block {
|
||||||
if (flags & filterRefreshForce) != 0 {
|
updNum, lists, toUpd, isNetErr = d.refreshFiltersArray(&d.Filters, force)
|
||||||
force = true
|
|
||||||
}
|
}
|
||||||
if (flags & filterRefreshBlocklists) != 0 {
|
if allow {
|
||||||
updateCount, updateFilters, updateFlags, netError = f.refreshFiltersArray(&config.Filters, force)
|
updNumAl, listsAl, toUpdAl, isNetErrAl := d.refreshFiltersArray(&d.WhitelistFilters, force)
|
||||||
|
|
||||||
|
updNum += updNumAl
|
||||||
|
lists = append(lists, listsAl...)
|
||||||
|
toUpd = append(toUpd, toUpdAl...)
|
||||||
|
isNetErr = isNetErr || isNetErrAl
|
||||||
}
|
}
|
||||||
if (flags & filterRefreshAllowlists) != 0 {
|
if isNetErr {
|
||||||
updateCountW := 0
|
|
||||||
var updateFiltersW []filter
|
|
||||||
var updateFlagsW []bool
|
|
||||||
updateCountW, updateFiltersW, updateFlagsW, netErrorW = f.refreshFiltersArray(&config.WhitelistFilters, force)
|
|
||||||
updateCount += updateCountW
|
|
||||||
updateFilters = append(updateFilters, updateFiltersW...)
|
|
||||||
updateFlags = append(updateFlags, updateFlagsW...)
|
|
||||||
}
|
|
||||||
if netError && netErrorW {
|
|
||||||
return 0, true
|
return 0, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if updateCount != 0 {
|
if updNum != 0 {
|
||||||
enableFilters(false)
|
d.EnableFilters(false)
|
||||||
|
|
||||||
for i := range updateFilters {
|
for i := range lists {
|
||||||
uf := &updateFilters[i]
|
uf := &lists[i]
|
||||||
updated := updateFlags[i]
|
updated := toUpd[i]
|
||||||
if !updated {
|
if !updated {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
_ = os.Remove(uf.Path() + ".old")
|
_ = os.Remove(uf.Path(d.DataDir) + ".old")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("Filters: update finished")
|
log.Debug("filtering: update finished")
|
||||||
return updateCount, false
|
|
||||||
|
return updNum, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allows printable UTF-8 text with CR, LF, TAB characters
|
// Allows printable UTF-8 text with CR, LF, TAB characters
|
||||||
|
@ -440,7 +430,7 @@ func isPrintableText(data []byte, len int) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
|
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
|
||||||
func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
|
func (d *DNSFilter) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||||
rulesCount := 0
|
rulesCount := 0
|
||||||
name := ""
|
name := ""
|
||||||
seenTitle := false
|
seenTitle := false
|
||||||
|
@ -455,7 +445,7 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||||
if len(line) == 0 {
|
if len(line) == 0 {
|
||||||
//
|
//
|
||||||
} else if line[0] == '!' {
|
} else if line[0] == '!' {
|
||||||
m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1)
|
m := d.filterTitleRegexp.FindAllStringSubmatch(line, -1)
|
||||||
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
|
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
|
||||||
name = m[0][1]
|
name = m[0][1]
|
||||||
seenTitle = true
|
seenTitle = true
|
||||||
|
@ -476,11 +466,11 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform upgrade on a filter and update LastUpdated value
|
// Perform upgrade on a filter and update LastUpdated value
|
||||||
func (f *Filtering) update(filter *filter) (bool, error) {
|
func (d *DNSFilter) update(filter *FilterYAML) (bool, error) {
|
||||||
b, err := f.updateIntl(filter)
|
b, err := d.updateIntl(filter)
|
||||||
filter.LastUpdated = time.Now()
|
filter.LastUpdated = time.Now()
|
||||||
if !b {
|
if !b {
|
||||||
e := os.Chtimes(filter.Path(), filter.LastUpdated, filter.LastUpdated)
|
e := os.Chtimes(filter.Path(d.DataDir), filter.LastUpdated, filter.LastUpdated)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
log.Error("os.Chtimes(): %v", e)
|
log.Error("os.Chtimes(): %v", e)
|
||||||
}
|
}
|
||||||
|
@ -488,7 +478,7 @@ func (f *Filtering) update(filter *filter) (bool, error) {
|
||||||
return b, err
|
return b, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (int, error) {
|
func (d *DNSFilter) read(reader io.Reader, tmpFile *os.File, filter *FilterYAML) (int, error) {
|
||||||
htmlTest := true
|
htmlTest := true
|
||||||
firstChunk := make([]byte, 4*1024)
|
firstChunk := make([]byte, 4*1024)
|
||||||
firstChunkLen := 0
|
firstChunkLen := 0
|
||||||
|
@ -539,20 +529,20 @@ func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (in
|
||||||
// finalizeUpdate closes and gets rid of temporary file f with filter's content
|
// finalizeUpdate closes and gets rid of temporary file f with filter's content
|
||||||
// according to updated. It also saves new values of flt's name, rules number
|
// according to updated. It also saves new values of flt's name, rules number
|
||||||
// and checksum if sucсeeded.
|
// and checksum if sucсeeded.
|
||||||
func finalizeUpdate(
|
func (d *DNSFilter) finalizeUpdate(
|
||||||
f *os.File,
|
file *os.File,
|
||||||
flt *filter,
|
flt *FilterYAML,
|
||||||
updated bool,
|
updated bool,
|
||||||
name string,
|
name string,
|
||||||
rnum int,
|
rnum int,
|
||||||
cs uint32,
|
cs uint32,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
tmpFileName := f.Name()
|
tmpFileName := file.Name()
|
||||||
|
|
||||||
// Close the file before renaming it because it's required on Windows.
|
// Close the file before renaming it because it's required on Windows.
|
||||||
//
|
//
|
||||||
// See https://github.com/adguardTeam/adGuardHome/issues/1553.
|
// See https://github.com/adguardTeam/adGuardHome/issues/1553.
|
||||||
if err = f.Close(); err != nil {
|
if err = file.Close(); err != nil {
|
||||||
return fmt.Errorf("closing temporary file: %w", err)
|
return fmt.Errorf("closing temporary file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -562,9 +552,9 @@ func finalizeUpdate(
|
||||||
return os.Remove(tmpFileName)
|
return os.Remove(tmpFileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path())
|
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path(d.DataDir))
|
||||||
|
|
||||||
if err = os.Rename(tmpFileName, flt.Path()); err != nil {
|
if err = os.Rename(tmpFileName, flt.Path(d.DataDir)); err != nil {
|
||||||
return errors.WithDeferred(err, os.Remove(tmpFileName))
|
return errors.WithDeferred(err, os.Remove(tmpFileName))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -578,12 +568,12 @@ func finalizeUpdate(
|
||||||
// processUpdate copies filter's content from src to dst and returns the name,
|
// processUpdate copies filter's content from src to dst and returns the name,
|
||||||
// rules number, and checksum for it. It also returns the number of bytes read
|
// rules number, and checksum for it. It also returns the number of bytes read
|
||||||
// from src.
|
// from src.
|
||||||
func (f *Filtering) processUpdate(
|
func (d *DNSFilter) processUpdate(
|
||||||
src io.Reader,
|
src io.Reader,
|
||||||
dst *os.File,
|
dst *os.File,
|
||||||
flt *filter,
|
flt *FilterYAML,
|
||||||
) (name string, rnum int, cs uint32, n int, err error) {
|
) (name string, rnum int, cs uint32, n int, err error) {
|
||||||
if n, err = f.read(src, dst, flt); err != nil {
|
if n, err = d.read(src, dst, flt); err != nil {
|
||||||
return "", 0, 0, 0, err
|
return "", 0, 0, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -591,14 +581,14 @@ func (f *Filtering) processUpdate(
|
||||||
return "", 0, 0, 0, err
|
return "", 0, 0, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rnum, cs, name = f.parseFilterContents(dst)
|
rnum, cs, name = d.parseFilterContents(dst)
|
||||||
|
|
||||||
return name, rnum, cs, n, nil
|
return name, rnum, cs, n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateIntl updates the flt rewriting it's actual file. It returns true if
|
// updateIntl updates the flt rewriting it's actual file. It returns true if
|
||||||
// the actual update has been performed.
|
// the actual update has been performed.
|
||||||
func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
|
func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||||
log.Tracef("downloading update for filter %d from %s", flt.ID, flt.URL)
|
log.Tracef("downloading update for filter %d from %s", flt.ID, flt.URL)
|
||||||
|
|
||||||
var name string
|
var name string
|
||||||
|
@ -606,12 +596,12 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
|
||||||
var cs uint32
|
var cs uint32
|
||||||
|
|
||||||
var tmpFile *os.File
|
var tmpFile *os.File
|
||||||
tmpFile, err = os.CreateTemp(filepath.Join(Context.getDataDir(), filterDir), "")
|
tmpFile, err = os.CreateTemp(filepath.Join(d.DataDir, filterDir), "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
err = errors.WithDeferred(err, finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
|
err = errors.WithDeferred(err, d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
|
||||||
ok = ok && err == nil
|
ok = ok && err == nil
|
||||||
if ok {
|
if ok {
|
||||||
log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum)
|
log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum)
|
||||||
|
@ -638,7 +628,7 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
|
||||||
r = file
|
r = file
|
||||||
} else {
|
} else {
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
resp, err = Context.client.Get(flt.URL)
|
resp, err = d.HTTPClient.Get(flt.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("requesting filter from %s, skip: %s", flt.URL, err)
|
log.Printf("requesting filter from %s, skip: %s", flt.URL, err)
|
||||||
|
|
||||||
|
@ -655,16 +645,16 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
|
||||||
r = resp.Body
|
r = resp.Body
|
||||||
}
|
}
|
||||||
|
|
||||||
name, rnum, cs, n, err = f.processUpdate(r, tmpFile, flt)
|
name, rnum, cs, n, err = d.processUpdate(r, tmpFile, flt)
|
||||||
|
|
||||||
return cs != flt.checksum, err
|
return cs != flt.checksum, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// loads filter contents from the file in dataDir
|
// loads filter contents from the file in dataDir
|
||||||
func (f *Filtering) load(filter *filter) (err error) {
|
func (d *DNSFilter) load(filter *FilterYAML) (err error) {
|
||||||
filterFilePath := filter.Path()
|
filterFilePath := filter.Path(d.DataDir)
|
||||||
|
|
||||||
log.Tracef("filtering: loading filter %d contents to: %s", filter.ID, filterFilePath)
|
log.Tracef("filtering: loading filter %d from %s", filter.ID, filterFilePath)
|
||||||
|
|
||||||
file, err := os.Open(filterFilePath)
|
file, err := os.Open(filterFilePath)
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
@ -682,7 +672,7 @@ func (f *Filtering) load(filter *filter) (err error) {
|
||||||
|
|
||||||
log.Tracef("filtering: File %s, id %d, length %d", filterFilePath, filter.ID, st.Size())
|
log.Tracef("filtering: File %s, id %d, length %d", filterFilePath, filter.ID, st.Size())
|
||||||
|
|
||||||
rulesCount, checksum, _ := f.parseFilterContents(file)
|
rulesCount, checksum, _ := d.parseFilterContents(file)
|
||||||
|
|
||||||
filter.RulesCount = rulesCount
|
filter.RulesCount = rulesCount
|
||||||
filter.checksum = checksum
|
filter.checksum = checksum
|
||||||
|
@ -691,56 +681,45 @@ func (f *Filtering) load(filter *filter) (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear filter rules
|
func (d *DNSFilter) EnableFilters(async bool) {
|
||||||
func (filter *filter) unload() {
|
d.filtersMu.RLock()
|
||||||
filter.RulesCount = 0
|
defer d.filtersMu.RUnlock()
|
||||||
filter.checksum = 0
|
|
||||||
|
d.enableFiltersLocked(async)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Path to the filter contents
|
func (d *DNSFilter) enableFiltersLocked(async bool) {
|
||||||
func (filter *filter) Path() string {
|
filters := []Filter{{
|
||||||
return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
|
ID: CustomListID,
|
||||||
}
|
Data: []byte(strings.Join(d.UserRules, "\n")),
|
||||||
|
|
||||||
func enableFilters(async bool) {
|
|
||||||
config.RLock()
|
|
||||||
defer config.RUnlock()
|
|
||||||
|
|
||||||
enableFiltersLocked(async)
|
|
||||||
}
|
|
||||||
|
|
||||||
func enableFiltersLocked(async bool) {
|
|
||||||
filters := []filtering.Filter{{
|
|
||||||
ID: filtering.CustomListID,
|
|
||||||
Data: []byte(strings.Join(config.UserRules, "\n")),
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, filter := range config.Filters {
|
for _, filter := range d.Filters {
|
||||||
if !filter.Enabled {
|
if !filter.Enabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
filters = append(filters, filtering.Filter{
|
filters = append(filters, Filter{
|
||||||
ID: filter.ID,
|
ID: filter.ID,
|
||||||
FilePath: filter.Path(),
|
FilePath: filter.Path(d.DataDir),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var allowFilters []filtering.Filter
|
var allowFilters []Filter
|
||||||
for _, filter := range config.WhitelistFilters {
|
for _, filter := range d.WhitelistFilters {
|
||||||
if !filter.Enabled {
|
if !filter.Enabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
allowFilters = append(allowFilters, filtering.Filter{
|
allowFilters = append(allowFilters, Filter{
|
||||||
ID: filter.ID,
|
ID: filter.ID,
|
||||||
FilePath: filter.Path(),
|
FilePath: filter.Path(d.DataDir),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil {
|
if err := d.SetFilters(filters, allowFilters, async); err != nil {
|
||||||
log.Debug("enabling filters: %s", err)
|
log.Debug("enabling filters: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.dnsFilter.SetEnabled(config.DNS.FilteringEnabled)
|
d.SetEnabled(d.FilteringEnabled)
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package home
|
package filtering
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
@ -51,15 +51,17 @@ func TestFilters(t *testing.T) {
|
||||||
|
|
||||||
l := testStartFilterListener(t, &fltContent)
|
l := testStartFilterListener(t, &fltContent)
|
||||||
|
|
||||||
Context = homeContext{
|
tempDir := t.TempDir()
|
||||||
workDir: t.TempDir(),
|
|
||||||
client: &http.Client{
|
filters, err := New(&Config{
|
||||||
|
DataDir: tempDir,
|
||||||
|
HTTPClient: &http.Client{
|
||||||
Timeout: 5 * time.Second,
|
Timeout: 5 * time.Second,
|
||||||
},
|
},
|
||||||
}
|
}, nil)
|
||||||
Context.filters.Init()
|
require.NoError(t, err)
|
||||||
|
|
||||||
f := &filter{
|
f := &FilterYAML{
|
||||||
URL: (&url.URL{
|
URL: (&url.URL{
|
||||||
Scheme: "http",
|
Scheme: "http",
|
||||||
Host: (&netutil.IPPort{
|
Host: (&netutil.IPPort{
|
||||||
|
@ -71,21 +73,22 @@ func TestFilters(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
updateAndAssert := func(t *testing.T, want require.BoolAssertionFunc, wantRulesCount int) {
|
updateAndAssert := func(t *testing.T, want require.BoolAssertionFunc, wantRulesCount int) {
|
||||||
ok, err := Context.filters.update(f)
|
var ok bool
|
||||||
|
ok, err = filters.update(f)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
want(t, ok)
|
want(t, ok)
|
||||||
|
|
||||||
assert.Equal(t, wantRulesCount, f.RulesCount)
|
assert.Equal(t, wantRulesCount, f.RulesCount)
|
||||||
|
|
||||||
var dir []fs.DirEntry
|
var dir []fs.DirEntry
|
||||||
dir, err = os.ReadDir(filepath.Join(Context.getDataDir(), filterDir))
|
dir, err = os.ReadDir(filepath.Join(tempDir, filterDir))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, dir, 1)
|
assert.Len(t, dir, 1)
|
||||||
|
|
||||||
require.FileExists(t, f.Path())
|
require.FileExists(t, f.Path(tempDir))
|
||||||
|
|
||||||
err = Context.filters.load(f)
|
err = filters.load(f)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,11 +108,9 @@ func TestFilters(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("load_unload", func(t *testing.T) {
|
t.Run("load_unload", func(t *testing.T) {
|
||||||
err := Context.filters.load(f)
|
err = filters.load(f)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
f.unload()
|
f.unload()
|
||||||
})
|
})
|
||||||
|
|
||||||
require.NoError(t, os.Remove(f.Path()))
|
|
||||||
}
|
}
|
|
@ -6,7 +6,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -24,6 +27,7 @@ import (
|
||||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
// The IDs of built-in filter lists.
|
// The IDs of built-in filter lists.
|
||||||
|
@ -69,8 +73,13 @@ type Config struct {
|
||||||
// enabled is used to be returned within Settings.
|
// enabled is used to be returned within Settings.
|
||||||
//
|
//
|
||||||
// It is of type uint32 to be accessed by atomic.
|
// It is of type uint32 to be accessed by atomic.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Use atomic.Bool in Go 1.19.
|
||||||
enabled uint32
|
enabled uint32
|
||||||
|
|
||||||
|
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
||||||
|
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
||||||
|
|
||||||
ParentalEnabled bool `yaml:"parental_enabled"`
|
ParentalEnabled bool `yaml:"parental_enabled"`
|
||||||
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
||||||
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
||||||
|
@ -98,6 +107,24 @@ type Config struct {
|
||||||
|
|
||||||
// CustomResolver is the resolver used by DNSFilter.
|
// CustomResolver is the resolver used by DNSFilter.
|
||||||
CustomResolver Resolver `yaml:"-"`
|
CustomResolver Resolver `yaml:"-"`
|
||||||
|
|
||||||
|
// HTTPClient is the client to use for updating the remote filters.
|
||||||
|
HTTPClient *http.Client `yaml:"-"`
|
||||||
|
|
||||||
|
// DataDir is used to store filters' contents.
|
||||||
|
DataDir string `yaml:"-"`
|
||||||
|
|
||||||
|
// filtersMu protects filter lists.
|
||||||
|
filtersMu *sync.RWMutex
|
||||||
|
|
||||||
|
// Filters are the blocking filter lists.
|
||||||
|
Filters []FilterYAML `yaml:"-"`
|
||||||
|
|
||||||
|
// WhitelistFilters are the allowing filter lists.
|
||||||
|
WhitelistFilters []FilterYAML `yaml:"-"`
|
||||||
|
|
||||||
|
// UserRules is the global list of custom rules.
|
||||||
|
UserRules []string `yaml:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupStats store stats collected during safebrowsing or parental checks
|
// LookupStats store stats collected during safebrowsing or parental checks
|
||||||
|
@ -130,8 +157,10 @@ type hostChecker struct {
|
||||||
type DNSFilter struct {
|
type DNSFilter struct {
|
||||||
rulesStorage *filterlist.RuleStorage
|
rulesStorage *filterlist.RuleStorage
|
||||||
filteringEngine *urlfilter.DNSEngine
|
filteringEngine *urlfilter.DNSEngine
|
||||||
|
|
||||||
rulesStorageAllow *filterlist.RuleStorage
|
rulesStorageAllow *filterlist.RuleStorage
|
||||||
filteringEngineAllow *urlfilter.DNSEngine
|
filteringEngineAllow *urlfilter.DNSEngine
|
||||||
|
|
||||||
engineLock sync.RWMutex
|
engineLock sync.RWMutex
|
||||||
|
|
||||||
parentalServer string // access via methods
|
parentalServer string // access via methods
|
||||||
|
@ -156,6 +185,12 @@ type DNSFilter struct {
|
||||||
// TODO(e.burkov): Use upstream that configured in dnsforward instead.
|
// TODO(e.burkov): Use upstream that configured in dnsforward instead.
|
||||||
resolver Resolver
|
resolver Resolver
|
||||||
|
|
||||||
|
refreshLock *sync.Mutex
|
||||||
|
|
||||||
|
// filterTitleRegexp is the regular expression to retrieve a name of a
|
||||||
|
// filter list.
|
||||||
|
filterTitleRegexp *regexp.Regexp
|
||||||
|
|
||||||
hostCheckers []hostChecker
|
hostCheckers []hostChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,7 +203,7 @@ type Filter struct {
|
||||||
Data []byte `yaml:"-"`
|
Data []byte `yaml:"-"`
|
||||||
|
|
||||||
// ID is automatically assigned when filter is added using nextFilterID.
|
// ID is automatically assigned when filter is added using nextFilterID.
|
||||||
ID int64
|
ID int64 `yaml:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reason holds an enum detailing why it was filtered or not filtered
|
// Reason holds an enum detailing why it was filtered or not filtered
|
||||||
|
@ -245,15 +280,7 @@ func (r Reason) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// In returns true if reasons include r.
|
// In returns true if reasons include r.
|
||||||
func (r Reason) In(reasons ...Reason) (ok bool) {
|
func (r Reason) In(reasons ...Reason) (ok bool) { return slices.Contains(reasons, r) }
|
||||||
for _, reason := range reasons {
|
|
||||||
if r == reason {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetEnabled sets the status of the *DNSFilter.
|
// SetEnabled sets the status of the *DNSFilter.
|
||||||
func (d *DNSFilter) SetEnabled(enabled bool) {
|
func (d *DNSFilter) SetEnabled(enabled bool) {
|
||||||
|
@ -261,6 +288,7 @@ func (d *DNSFilter) SetEnabled(enabled bool) {
|
||||||
if enabled {
|
if enabled {
|
||||||
i = 1
|
i = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
atomic.StoreUint32(&d.enabled, uint32(i))
|
atomic.StoreUint32(&d.enabled, uint32(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -279,11 +307,20 @@ func (d *DNSFilter) GetConfig() (s Settings) {
|
||||||
|
|
||||||
// WriteDiskConfig - write configuration
|
// WriteDiskConfig - write configuration
|
||||||
func (d *DNSFilter) WriteDiskConfig(c *Config) {
|
func (d *DNSFilter) WriteDiskConfig(c *Config) {
|
||||||
|
func() {
|
||||||
d.confLock.Lock()
|
d.confLock.Lock()
|
||||||
defer d.confLock.Unlock()
|
defer d.confLock.Unlock()
|
||||||
|
|
||||||
*c = d.Config
|
*c = d.Config
|
||||||
c.Rewrites = cloneRewrites(c.Rewrites)
|
c.Rewrites = cloneRewrites(c.Rewrites)
|
||||||
|
}()
|
||||||
|
|
||||||
|
d.filtersMu.RLock()
|
||||||
|
defer d.filtersMu.RUnlock()
|
||||||
|
|
||||||
|
c.Filters = slices.Clone(d.Filters)
|
||||||
|
c.WhitelistFilters = slices.Clone(d.WhitelistFilters)
|
||||||
|
c.UserRules = slices.Clone(d.UserRules)
|
||||||
}
|
}
|
||||||
|
|
||||||
// cloneRewrites returns a deep copy of entries.
|
// cloneRewrites returns a deep copy of entries.
|
||||||
|
@ -309,6 +346,8 @@ func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
d.filtersInitializerLock.Lock() // prevent multiple writers from adding more than 1 task
|
d.filtersInitializerLock.Lock() // prevent multiple writers from adding more than 1 task
|
||||||
|
defer d.filtersInitializerLock.Unlock()
|
||||||
|
|
||||||
// remove all pending tasks
|
// remove all pending tasks
|
||||||
stop := false
|
stop := false
|
||||||
for !stop {
|
for !stop {
|
||||||
|
@ -321,7 +360,6 @@ func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
d.filtersInitializerChan <- params
|
d.filtersInitializerChan <- params
|
||||||
d.filtersInitializerLock.Unlock()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -350,22 +388,19 @@ func (d *DNSFilter) filtersInitializer() {
|
||||||
func (d *DNSFilter) Close() {
|
func (d *DNSFilter) Close() {
|
||||||
d.engineLock.Lock()
|
d.engineLock.Lock()
|
||||||
defer d.engineLock.Unlock()
|
defer d.engineLock.Unlock()
|
||||||
|
|
||||||
d.reset()
|
d.reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) reset() {
|
func (d *DNSFilter) reset() {
|
||||||
var err error
|
|
||||||
|
|
||||||
if d.rulesStorage != nil {
|
if d.rulesStorage != nil {
|
||||||
err = d.rulesStorage.Close()
|
if err := d.rulesStorage.Close(); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Error("filtering: rulesStorage.Close: %s", err)
|
log.Error("filtering: rulesStorage.Close: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.rulesStorageAllow != nil {
|
if d.rulesStorageAllow != nil {
|
||||||
err = d.rulesStorageAllow.Close()
|
if err := d.rulesStorageAllow.Close(); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Error("filtering: rulesStorageAllow.Close: %s", err)
|
log.Error("filtering: rulesStorageAllow.Close: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -885,12 +920,14 @@ func InitModule() {
|
||||||
initBlockedServices()
|
initBlockedServices()
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates properly initialized DNS Filter that is ready to be used.
|
// New creates properly initialized DNS Filter that is ready to be used. c must
|
||||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
// be non-nil.
|
||||||
|
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||||
d = &DNSFilter{
|
d = &DNSFilter{
|
||||||
resolver: net.DefaultResolver,
|
resolver: net.DefaultResolver,
|
||||||
|
refreshLock: &sync.Mutex{},
|
||||||
|
filterTitleRegexp: regexp.MustCompile(`^! Title: +(.*)$`),
|
||||||
}
|
}
|
||||||
if c != nil {
|
|
||||||
|
|
||||||
d.safebrowsingCache = cache.New(cache.Config{
|
d.safebrowsingCache = cache.New(cache.Config{
|
||||||
EnableLRU: true,
|
EnableLRU: true,
|
||||||
|
@ -905,9 +942,8 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
||||||
MaxSize: c.ParentalCacheSize,
|
MaxSize: c.ParentalCacheSize,
|
||||||
})
|
})
|
||||||
|
|
||||||
if c.CustomResolver != nil {
|
if r := c.CustomResolver; r != nil {
|
||||||
d.resolver = c.CustomResolver
|
d.resolver = r
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
d.hostCheckers = []hostChecker{{
|
d.hostCheckers = []hostChecker{{
|
||||||
|
@ -930,27 +966,26 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
||||||
name: "safe search",
|
name: "safe search",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
err := d.initSecurityServices()
|
defer func() { err = errors.Annotate(err, "filtering: %w") }()
|
||||||
if err != nil {
|
|
||||||
log.Error("filtering: initialize services: %s", err)
|
|
||||||
|
|
||||||
return nil
|
err = d.initSecurityServices()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("initializing services: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if c != nil {
|
|
||||||
d.Config = *c
|
d.Config = *c
|
||||||
|
d.filtersMu = &sync.RWMutex{}
|
||||||
|
|
||||||
err = d.prepareRewrites()
|
err = d.prepareRewrites()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("rewrites: preparing: %s", err)
|
return nil, fmt.Errorf("rewrites: preparing: %s", err)
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bsvcs := []string{}
|
bsvcs := []string{}
|
||||||
for _, s := range d.BlockedServices {
|
for _, s := range d.BlockedServices {
|
||||||
if !BlockedSvcKnown(s) {
|
if !BlockedSvcKnown(s) {
|
||||||
log.Debug("skipping unknown blocked-service %q", s)
|
log.Debug("skipping unknown blocked-service %q", s)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
bsvcs = append(bsvcs, s)
|
bsvcs = append(bsvcs, s)
|
||||||
|
@ -960,13 +995,24 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
||||||
if blockFilters != nil {
|
if blockFilters != nil {
|
||||||
err = d.initFiltering(nil, blockFilters)
|
err = d.initFiltering(nil, blockFilters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Can't initialize filtering subsystem: %s", err)
|
|
||||||
d.Close()
|
d.Close()
|
||||||
return nil
|
|
||||||
|
return nil, fmt.Errorf("initializing filtering subsystem: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return d
|
_ = os.MkdirAll(filepath.Join(d.DataDir, filterDir), 0o755)
|
||||||
|
|
||||||
|
d.loadFilters(d.Filters)
|
||||||
|
d.loadFilters(d.WhitelistFilters)
|
||||||
|
|
||||||
|
d.Filters = deduplicateFilters(d.Filters)
|
||||||
|
d.WhitelistFilters = deduplicateFilters(d.WhitelistFilters)
|
||||||
|
|
||||||
|
updateUniqueFilterID(d.Filters)
|
||||||
|
updateUniqueFilterID(d.WhitelistFilters)
|
||||||
|
|
||||||
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start - start the module:
|
// Start - start the module:
|
||||||
|
@ -976,9 +1022,10 @@ func (d *DNSFilter) Start() {
|
||||||
d.filtersInitializerChan = make(chan filtersInitializerParams, 1)
|
d.filtersInitializerChan = make(chan filtersInitializerParams, 1)
|
||||||
go d.filtersInitializer()
|
go d.filtersInitializer()
|
||||||
|
|
||||||
if d.Config.HTTPRegister != nil { // for tests
|
d.RegisterFilteringHandlers()
|
||||||
d.registerSecurityHandlers()
|
|
||||||
d.registerRewritesHandlers()
|
// Here we should start updating filters,
|
||||||
d.registerBlockedServicesHandlers()
|
// but currently we can't wake up the periodic task to do so.
|
||||||
}
|
// So for now we just start this periodic task from here.
|
||||||
|
go d.periodicallyRefreshFilters()
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,10 +26,6 @@ const (
|
||||||
pcBlocked = "pornhub.com"
|
pcBlocked = "pornhub.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
var setts = Settings{
|
|
||||||
ProtectionEnabled: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helpers.
|
// Helpers.
|
||||||
|
|
||||||
func purgeCaches(d *DNSFilter) {
|
func purgeCaches(d *DNSFilter) {
|
||||||
|
@ -44,8 +40,8 @@ func purgeCaches(d *DNSFilter) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
|
func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts *Settings) {
|
||||||
setts = Settings{
|
setts = &Settings{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
FilteringEnabled: true,
|
FilteringEnabled: true,
|
||||||
}
|
}
|
||||||
|
@ -57,26 +53,31 @@ func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
|
||||||
setts.SafeSearchEnabled = c.SafeSearchEnabled
|
setts.SafeSearchEnabled = c.SafeSearchEnabled
|
||||||
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
|
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
|
||||||
setts.ParentalEnabled = c.ParentalEnabled
|
setts.ParentalEnabled = c.ParentalEnabled
|
||||||
|
} else {
|
||||||
|
// It must not be nil.
|
||||||
|
c = &Config{}
|
||||||
}
|
}
|
||||||
d := New(c, filters)
|
f, err := New(c, filters)
|
||||||
purgeCaches(d)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return d
|
purgeCaches(f)
|
||||||
|
|
||||||
|
return f, setts
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) checkMatch(t *testing.T, hostname string) {
|
func (d *DNSFilter) checkMatch(t *testing.T, hostname string, setts *Settings) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
res, err := d.CheckHost(hostname, dns.TypeA, setts)
|
||||||
require.NoErrorf(t, err, "host %q", hostname)
|
require.NoErrorf(t, err, "host %q", hostname)
|
||||||
|
|
||||||
assert.Truef(t, res.IsFiltered, "host %q", hostname)
|
assert.Truef(t, res.IsFiltered, "host %q", hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16) {
|
func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16, setts *Settings) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
res, err := d.CheckHost(hostname, qtype, &setts)
|
res, err := d.CheckHost(hostname, qtype, setts)
|
||||||
require.NoErrorf(t, err, "host %q", hostname, err)
|
require.NoErrorf(t, err, "host %q", hostname, err)
|
||||||
require.NotEmpty(t, res.Rules, "host %q", hostname)
|
require.NotEmpty(t, res.Rules, "host %q", hostname)
|
||||||
|
|
||||||
|
@ -88,10 +89,10 @@ func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16
|
||||||
assert.Equalf(t, ip, r.IP.String(), "host %q", hostname)
|
assert.Equalf(t, ip, r.IP.String(), "host %q", hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) {
|
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string, setts *Settings) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
res, err := d.CheckHost(hostname, dns.TypeA, setts)
|
||||||
require.NoErrorf(t, err, "host %q", hostname)
|
require.NoErrorf(t, err, "host %q", hostname)
|
||||||
|
|
||||||
assert.Falsef(t, res.IsFiltered, "host %q", hostname)
|
assert.Falsef(t, res.IsFiltered, "host %q", hostname)
|
||||||
|
@ -111,19 +112,19 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||||
filters := []Filter{{
|
filters := []Filter{{
|
||||||
ID: 0, Data: []byte(text),
|
ID: 0, Data: []byte(text),
|
||||||
}}
|
}}
|
||||||
d := newForTest(t, nil, filters)
|
d, setts := newForTest(t, nil, filters)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
d.checkMatchIP(t, "google.com", addr, dns.TypeA)
|
d.checkMatchIP(t, "google.com", addr, dns.TypeA, setts)
|
||||||
d.checkMatchIP(t, "www.google.com", addr, dns.TypeA)
|
d.checkMatchIP(t, "www.google.com", addr, dns.TypeA, setts)
|
||||||
d.checkMatchEmpty(t, "subdomain.google.com")
|
d.checkMatchEmpty(t, "subdomain.google.com", setts)
|
||||||
d.checkMatchEmpty(t, "example.org")
|
d.checkMatchEmpty(t, "example.org", setts)
|
||||||
|
|
||||||
// IPv4 match.
|
// IPv4 match.
|
||||||
d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA)
|
d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA, setts)
|
||||||
|
|
||||||
// Empty IPv6.
|
// Empty IPv6.
|
||||||
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts)
|
res, err := d.CheckHost("block.com", dns.TypeAAAA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
@ -134,10 +135,10 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||||
assert.Empty(t, res.Rules[0].IP)
|
assert.Empty(t, res.Rules[0].IP)
|
||||||
|
|
||||||
// IPv6 match.
|
// IPv6 match.
|
||||||
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA)
|
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA, setts)
|
||||||
|
|
||||||
// Empty IPv4.
|
// Empty IPv4.
|
||||||
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts)
|
res, err = d.CheckHost("ipv6.com", dns.TypeA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
@ -148,7 +149,7 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||||
assert.Empty(t, res.Rules[0].IP)
|
assert.Empty(t, res.Rules[0].IP)
|
||||||
|
|
||||||
// Two IPv4, both must be returned.
|
// Two IPv4, both must be returned.
|
||||||
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
res, err = d.CheckHost("host2", dns.TypeA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
@ -159,7 +160,7 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||||
assert.Equal(t, res.Rules[1].IP, net.IP{0, 0, 0, 2})
|
assert.Equal(t, res.Rules[1].IP, net.IP{0, 0, 0, 2})
|
||||||
|
|
||||||
// One IPv6 address.
|
// One IPv6 address.
|
||||||
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts)
|
res, err = d.CheckHost("host2", dns.TypeAAAA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
@ -176,27 +177,27 @@ func TestSafeBrowsing(t *testing.T) {
|
||||||
aghtest.ReplaceLogWriter(t, logOutput)
|
aghtest.ReplaceLogWriter(t, logOutput)
|
||||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||||
|
|
||||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||||
d.checkMatch(t, sbBlocked)
|
d.checkMatch(t, sbBlocked, setts)
|
||||||
|
|
||||||
require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
|
require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
|
||||||
|
|
||||||
d.checkMatch(t, "test."+sbBlocked)
|
d.checkMatch(t, "test."+sbBlocked, setts)
|
||||||
d.checkMatchEmpty(t, "yandex.ru")
|
d.checkMatchEmpty(t, "yandex.ru", setts)
|
||||||
d.checkMatchEmpty(t, pcBlocked)
|
d.checkMatchEmpty(t, pcBlocked, setts)
|
||||||
|
|
||||||
// Cached result.
|
// Cached result.
|
||||||
d.safeBrowsingServer = "127.0.0.1"
|
d.safeBrowsingServer = "127.0.0.1"
|
||||||
d.checkMatch(t, sbBlocked)
|
d.checkMatch(t, sbBlocked, setts)
|
||||||
d.checkMatchEmpty(t, pcBlocked)
|
d.checkMatchEmpty(t, pcBlocked, setts)
|
||||||
d.safeBrowsingServer = defaultSafebrowsingServer
|
d.safeBrowsingServer = defaultSafebrowsingServer
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParallelSB(t *testing.T) {
|
func TestParallelSB(t *testing.T) {
|
||||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||||
|
@ -205,10 +206,10 @@ func TestParallelSB(t *testing.T) {
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
|
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
d.checkMatch(t, sbBlocked)
|
d.checkMatch(t, sbBlocked, setts)
|
||||||
d.checkMatch(t, "test."+sbBlocked)
|
d.checkMatch(t, "test."+sbBlocked, setts)
|
||||||
d.checkMatchEmpty(t, "yandex.ru")
|
d.checkMatchEmpty(t, "yandex.ru", setts)
|
||||||
d.checkMatchEmpty(t, pcBlocked)
|
d.checkMatchEmpty(t, pcBlocked, setts)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -217,7 +218,7 @@ func TestParallelSB(t *testing.T) {
|
||||||
// Safe Search.
|
// Safe Search.
|
||||||
|
|
||||||
func TestSafeSearch(t *testing.T) {
|
func TestSafeSearch(t *testing.T) {
|
||||||
d := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
d, _ := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
val, ok := d.SafeSearchDomain("www.google.com")
|
val, ok := d.SafeSearchDomain("www.google.com")
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
@ -226,7 +227,7 @@ func TestSafeSearch(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckHostSafeSearchYandex(t *testing.T) {
|
func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||||
d := newForTest(t, &Config{
|
d, setts := newForTest(t, &Config{
|
||||||
SafeSearchEnabled: true,
|
SafeSearchEnabled: true,
|
||||||
}, nil)
|
}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
@ -243,7 +244,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||||
"www.yandex.com",
|
"www.yandex.com",
|
||||||
} {
|
} {
|
||||||
t.Run(strings.ToLower(host), func(t *testing.T) {
|
t.Run(strings.ToLower(host), func(t *testing.T) {
|
||||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
res, err := d.CheckHost(host, dns.TypeA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
@ -258,7 +259,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||||
|
|
||||||
func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||||
resolver := &aghtest.TestResolver{}
|
resolver := &aghtest.TestResolver{}
|
||||||
d := newForTest(t, &Config{
|
d, setts := newForTest(t, &Config{
|
||||||
SafeSearchEnabled: true,
|
SafeSearchEnabled: true,
|
||||||
CustomResolver: resolver,
|
CustomResolver: resolver,
|
||||||
}, nil)
|
}, nil)
|
||||||
|
@ -277,7 +278,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||||
"www.google.je",
|
"www.google.je",
|
||||||
} {
|
} {
|
||||||
t.Run(host, func(t *testing.T) {
|
t.Run(host, func(t *testing.T) {
|
||||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
res, err := d.CheckHost(host, dns.TypeA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
@ -291,12 +292,12 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSafeSearchCacheYandex(t *testing.T) {
|
func TestSafeSearchCacheYandex(t *testing.T) {
|
||||||
d := newForTest(t, nil, nil)
|
d, setts := newForTest(t, nil, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
const domain = "yandex.ru"
|
const domain = "yandex.ru"
|
||||||
|
|
||||||
// Check host with disabled safesearch.
|
// Check host with disabled safesearch.
|
||||||
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
res, err := d.CheckHost(domain, dns.TypeA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
|
@ -305,10 +306,10 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||||
|
|
||||||
yandexIP := net.IPv4(213, 180, 193, 56)
|
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||||
|
|
||||||
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
res, err = d.CheckHost(domain, dns.TypeA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// For yandex we already know valid IP.
|
// For yandex we already know valid IP.
|
||||||
|
@ -325,20 +326,20 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||||
|
|
||||||
func TestSafeSearchCacheGoogle(t *testing.T) {
|
func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
resolver := &aghtest.TestResolver{}
|
resolver := &aghtest.TestResolver{}
|
||||||
d := newForTest(t, &Config{
|
d, setts := newForTest(t, &Config{
|
||||||
CustomResolver: resolver,
|
CustomResolver: resolver,
|
||||||
}, nil)
|
}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
const domain = "www.google.ru"
|
const domain = "www.google.ru"
|
||||||
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
res, err := d.CheckHost(domain, dns.TypeA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
|
|
||||||
require.Empty(t, res.Rules)
|
require.Empty(t, res.Rules)
|
||||||
|
|
||||||
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
d.resolver = resolver
|
d.resolver = resolver
|
||||||
|
|
||||||
|
@ -358,7 +359,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
res, err = d.CheckHost(domain, dns.TypeA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
|
||||||
|
@ -379,22 +380,22 @@ func TestParentalControl(t *testing.T) {
|
||||||
aghtest.ReplaceLogWriter(t, logOutput)
|
aghtest.ReplaceLogWriter(t, logOutput)
|
||||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||||
|
|
||||||
d := newForTest(t, &Config{ParentalEnabled: true}, nil)
|
d, setts := newForTest(t, &Config{ParentalEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
|
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
|
||||||
d.checkMatch(t, pcBlocked)
|
d.checkMatch(t, pcBlocked, setts)
|
||||||
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
|
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
|
||||||
|
|
||||||
d.checkMatch(t, "www."+pcBlocked)
|
d.checkMatch(t, "www."+pcBlocked, setts)
|
||||||
d.checkMatchEmpty(t, "www.yandex.ru")
|
d.checkMatchEmpty(t, "www.yandex.ru", setts)
|
||||||
d.checkMatchEmpty(t, "yandex.ru")
|
d.checkMatchEmpty(t, "yandex.ru", setts)
|
||||||
d.checkMatchEmpty(t, "api.jquery.com")
|
d.checkMatchEmpty(t, "api.jquery.com", setts)
|
||||||
|
|
||||||
// Test cached result.
|
// Test cached result.
|
||||||
d.parentalServer = "127.0.0.1"
|
d.parentalServer = "127.0.0.1"
|
||||||
d.checkMatch(t, pcBlocked)
|
d.checkMatch(t, pcBlocked, setts)
|
||||||
d.checkMatchEmpty(t, "yandex.ru")
|
d.checkMatchEmpty(t, "yandex.ru", setts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filtering.
|
// Filtering.
|
||||||
|
@ -679,10 +680,10 @@ func TestMatching(t *testing.T) {
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) {
|
t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) {
|
||||||
filters := []Filter{{ID: 0, Data: []byte(tc.rules)}}
|
filters := []Filter{{ID: 0, Data: []byte(tc.rules)}}
|
||||||
d := newForTest(t, nil, filters)
|
d, setts := newForTest(t, nil, filters)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
|
res, err := d.CheckHost(tc.host, tc.wantDNSType, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
|
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
|
||||||
|
@ -705,7 +706,7 @@ func TestWhitelist(t *testing.T) {
|
||||||
whiteFilters := []Filter{{
|
whiteFilters := []Filter{{
|
||||||
ID: 0, Data: []byte(whiteRules),
|
ID: 0, Data: []byte(whiteRules),
|
||||||
}}
|
}}
|
||||||
d := newForTest(t, nil, filters)
|
d, setts := newForTest(t, nil, filters)
|
||||||
|
|
||||||
err := d.SetFilters(filters, whiteFilters, false)
|
err := d.SetFilters(filters, whiteFilters, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -713,7 +714,7 @@ func TestWhitelist(t *testing.T) {
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
// Matched by white filter.
|
// Matched by white filter.
|
||||||
res, err := d.CheckHost("host1", dns.TypeA, &setts)
|
res, err := d.CheckHost("host1", dns.TypeA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
|
@ -724,7 +725,7 @@ func TestWhitelist(t *testing.T) {
|
||||||
assert.Equal(t, "||host1^", res.Rules[0].Text)
|
assert.Equal(t, "||host1^", res.Rules[0].Text)
|
||||||
|
|
||||||
// Not matched by white filter, but matched by block filter.
|
// Not matched by white filter, but matched by block filter.
|
||||||
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
res, err = d.CheckHost("host2", dns.TypeA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
|
@ -750,7 +751,7 @@ func applyClientSettings(setts *Settings) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientSettings(t *testing.T) {
|
func TestClientSettings(t *testing.T) {
|
||||||
d := newForTest(t,
|
d, setts := newForTest(t,
|
||||||
&Config{
|
&Config{
|
||||||
ParentalEnabled: true,
|
ParentalEnabled: true,
|
||||||
SafeBrowsingEnabled: false,
|
SafeBrowsingEnabled: false,
|
||||||
|
@ -796,7 +797,7 @@ func TestClientSettings(t *testing.T) {
|
||||||
return func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
r, err := d.CheckHost(tc.host, dns.TypeA, &setts)
|
r, err := d.CheckHost(tc.host, dns.TypeA, setts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if before {
|
if before {
|
||||||
|
@ -814,7 +815,7 @@ func TestClientSettings(t *testing.T) {
|
||||||
t.Run(tc.name, makeTester(tc, tc.before))
|
t.Run(tc.name, makeTester(tc, tc.before))
|
||||||
}
|
}
|
||||||
|
|
||||||
applyClientSettings(&setts)
|
applyClientSettings(setts)
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, makeTester(tc, !tc.before))
|
t.Run(tc.name, makeTester(tc, !tc.before))
|
||||||
|
@ -824,13 +825,13 @@ func TestClientSettings(t *testing.T) {
|
||||||
// Benchmarks.
|
// Benchmarks.
|
||||||
|
|
||||||
func BenchmarkSafeBrowsing(b *testing.B) {
|
func BenchmarkSafeBrowsing(b *testing.B) {
|
||||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||||
b.Cleanup(d.Close)
|
b.Cleanup(d.Close)
|
||||||
|
|
||||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
|
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
|
|
||||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||||
|
@ -838,14 +839,14 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||||
b.Cleanup(d.Close)
|
b.Cleanup(d.Close)
|
||||||
|
|
||||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||||
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
|
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
|
|
||||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||||
|
@ -854,7 +855,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkSafeSearch(b *testing.B) {
|
func BenchmarkSafeSearch(b *testing.B) {
|
||||||
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||||
b.Cleanup(d.Close)
|
b.Cleanup(d.Close)
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
val, ok := d.SafeSearchDomain("www.google.com")
|
val, ok := d.SafeSearchDomain("www.google.com")
|
||||||
|
@ -865,7 +866,7 @@ func BenchmarkSafeSearch(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkSafeSearchParallel(b *testing.B) {
|
func BenchmarkSafeSearchParallel(b *testing.B) {
|
||||||
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||||
b.Cleanup(d.Close)
|
b.Cleanup(d.Close)
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
|
|
|
@ -133,34 +133,31 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
|
||||||
// 1. A and AAAA > CNAME
|
// 1. A and AAAA > CNAME
|
||||||
// 2. wildcard > exact
|
// 2. wildcard > exact
|
||||||
// 3. lower level wildcard > higher level wildcard
|
// 3. lower level wildcard > higher level wildcard
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Replace with slices.Sort.
|
||||||
type rewritesSorted []*LegacyRewrite
|
type rewritesSorted []*LegacyRewrite
|
||||||
|
|
||||||
// Len implements the sort.Interface interface for legacyRewritesSorted.
|
// Len implements the sort.Interface interface for rewritesSorted.
|
||||||
func (a rewritesSorted) Len() (l int) { return len(a) }
|
func (a rewritesSorted) Len() (l int) { return len(a) }
|
||||||
|
|
||||||
// Swap implements the sort.Interface interface for legacyRewritesSorted.
|
// Swap implements the sort.Interface interface for rewritesSorted.
|
||||||
func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||||
|
|
||||||
// Less implements the sort.Interface interface for legacyRewritesSorted.
|
// Less implements the sort.Interface interface for rewritesSorted.
|
||||||
func (a rewritesSorted) Less(i, j int) (less bool) {
|
func (a rewritesSorted) Less(i, j int) (less bool) {
|
||||||
if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME {
|
ith, jth := a[i], a[j]
|
||||||
|
if ith.Type == dns.TypeCNAME && jth.Type != dns.TypeCNAME {
|
||||||
return true
|
return true
|
||||||
} else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME {
|
} else if ith.Type != dns.TypeCNAME && jth.Type == dns.TypeCNAME {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if isWildcard(a[i].Domain) {
|
if iw, jw := isWildcard(ith.Domain), isWildcard(jth.Domain); iw != jw {
|
||||||
if !isWildcard(a[j].Domain) {
|
return jw
|
||||||
return false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if isWildcard(a[j].Domain) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Both are wildcards.
|
// Both are either wildcards or not.
|
||||||
return len(a[i].Domain) > len(a[j].Domain)
|
return len(ith.Domain) > len(jth.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareRewrites normalizes and validates all legacy DNS rewrites.
|
// prepareRewrites normalizes and validates all legacy DNS rewrites.
|
||||||
|
@ -313,9 +310,3 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
|
||||||
|
|
||||||
d.Config.ConfigModified()
|
d.Config.ConfigModified()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) registerRewritesHandlers() {
|
|
||||||
d.Config.HTTPRegister(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
|
|
||||||
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
|
|
||||||
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
|
|
||||||
}
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
// TODO(e.burkov): All the tests in this file may and should me merged together.
|
// TODO(e.burkov): All the tests in this file may and should me merged together.
|
||||||
|
|
||||||
func TestRewrites(t *testing.T) {
|
func TestRewrites(t *testing.T) {
|
||||||
d := newForTest(t, nil, nil)
|
d, _ := newForTest(t, nil, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
d.Rewrites = []*LegacyRewrite{{
|
d.Rewrites = []*LegacyRewrite{{
|
||||||
|
@ -188,7 +188,7 @@ func TestRewrites(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRewritesLevels(t *testing.T) {
|
func TestRewritesLevels(t *testing.T) {
|
||||||
d := newForTest(t, nil, nil)
|
d, _ := newForTest(t, nil, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
// Exact host, wildcard L2, wildcard L3.
|
// Exact host, wildcard L2, wildcard L3.
|
||||||
d.Rewrites = []*LegacyRewrite{{
|
d.Rewrites = []*LegacyRewrite{{
|
||||||
|
@ -235,7 +235,7 @@ func TestRewritesLevels(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRewritesExceptionCNAME(t *testing.T) {
|
func TestRewritesExceptionCNAME(t *testing.T) {
|
||||||
d := newForTest(t, nil, nil)
|
d, _ := newForTest(t, nil, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
// Wildcard and exception for a sub-domain.
|
// Wildcard and exception for a sub-domain.
|
||||||
d.Rewrites = []*LegacyRewrite{{
|
d.Rewrites = []*LegacyRewrite{{
|
||||||
|
@ -286,7 +286,7 @@ func TestRewritesExceptionCNAME(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRewritesExceptionIP(t *testing.T) {
|
func TestRewritesExceptionIP(t *testing.T) {
|
||||||
d := newForTest(t, nil, nil)
|
d, _ := newForTest(t, nil, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
// Exception for AAAA record.
|
// Exception for AAAA record.
|
||||||
d.Rewrites = []*LegacyRewrite{{
|
d.Rewrites = []*LegacyRewrite{{
|
||||||
|
|
|
@ -415,17 +415,3 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
|
||||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) registerSecurityHandlers() {
|
|
||||||
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
|
|
||||||
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
|
|
||||||
d.Config.HTTPRegister(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
|
|
||||||
|
|
||||||
d.Config.HTTPRegister(http.MethodPost, "/control/parental/enable", d.handleParentalEnable)
|
|
||||||
d.Config.HTTPRegister(http.MethodPost, "/control/parental/disable", d.handleParentalDisable)
|
|
||||||
d.Config.HTTPRegister(http.MethodGet, "/control/parental/status", d.handleParentalStatus)
|
|
||||||
|
|
||||||
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
|
|
||||||
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
|
|
||||||
d.Config.HTTPRegister(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
|
|
||||||
}
|
|
||||||
|
|
|
@ -107,7 +107,7 @@ func TestSafeBrowsingCache(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSBPC_checkErrorUpstream(t *testing.T) {
|
func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
ups := aghtest.NewErrorUpstream()
|
ups := aghtest.NewErrorUpstream()
|
||||||
|
@ -128,7 +128,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSBPC(t *testing.T) {
|
func TestSBPC(t *testing.T) {
|
||||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||||
t.Cleanup(d.Close)
|
t.Cleanup(d.Close)
|
||||||
|
|
||||||
const hostname = "example.org"
|
const hostname = "example.org"
|
||||||
|
|
|
@ -14,7 +14,6 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
|
||||||
"github.com/AdguardTeam/dnsproxy/fastip"
|
"github.com/AdguardTeam/dnsproxy/fastip"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -23,10 +22,9 @@ import (
|
||||||
yaml "gopkg.in/yaml.v3"
|
yaml "gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// dataDir is the name of a directory under the working one to store some
|
||||||
dataDir = "data" // data storage
|
// persistent data.
|
||||||
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
|
const dataDir = "data"
|
||||||
)
|
|
||||||
|
|
||||||
// logSettings are the logging settings part of the configuration file.
|
// logSettings are the logging settings part of the configuration file.
|
||||||
//
|
//
|
||||||
|
@ -108,8 +106,15 @@ type configuration struct {
|
||||||
DNS dnsConfig `yaml:"dns"`
|
DNS dnsConfig `yaml:"dns"`
|
||||||
TLS tlsConfigSettings `yaml:"tls"`
|
TLS tlsConfigSettings `yaml:"tls"`
|
||||||
|
|
||||||
Filters []filter `yaml:"filters"`
|
// Filters reflects the filters from [filtering.Config]. It's cloned to the
|
||||||
WhitelistFilters []filter `yaml:"whitelist_filters"`
|
// config used in the filtering module at the startup. Afterwards it's
|
||||||
|
// cloned from the filtering module back here.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Move all the filtering configuration fields into the
|
||||||
|
// only configuration subsection covering the changes with a single
|
||||||
|
// migration.
|
||||||
|
Filters []filtering.FilterYAML `yaml:"filters"`
|
||||||
|
WhitelistFilters []filtering.FilterYAML `yaml:"whitelist_filters"`
|
||||||
UserRules []string `yaml:"user_rules"`
|
UserRules []string `yaml:"user_rules"`
|
||||||
|
|
||||||
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
|
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
|
||||||
|
@ -145,9 +150,7 @@ type dnsConfig struct {
|
||||||
|
|
||||||
dnsforward.FilteringConfig `yaml:",inline"`
|
dnsforward.FilteringConfig `yaml:",inline"`
|
||||||
|
|
||||||
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
DnsfilterConf *filtering.Config `yaml:",inline"`
|
||||||
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
|
||||||
DnsfilterConf filtering.Config `yaml:",inline"`
|
|
||||||
|
|
||||||
// UpstreamTimeout is the timeout for querying upstream servers.
|
// UpstreamTimeout is the timeout for querying upstream servers.
|
||||||
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
|
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
|
||||||
|
@ -198,10 +201,15 @@ var config = &configuration{
|
||||||
BindHost: net.IP{0, 0, 0, 0},
|
BindHost: net.IP{0, 0, 0, 0},
|
||||||
AuthAttempts: 5,
|
AuthAttempts: 5,
|
||||||
AuthBlockMin: 15,
|
AuthBlockMin: 15,
|
||||||
|
WebSessionTTLHours: 30 * 24,
|
||||||
DNS: dnsConfig{
|
DNS: dnsConfig{
|
||||||
BindHosts: []net.IP{{0, 0, 0, 0}},
|
BindHosts: []net.IP{{0, 0, 0, 0}},
|
||||||
Port: defaultPortDNS,
|
Port: defaultPortDNS,
|
||||||
StatsInterval: 1,
|
StatsInterval: 1,
|
||||||
|
QueryLogEnabled: true,
|
||||||
|
QueryLogFileEnabled: true,
|
||||||
|
QueryLogInterval: timeutil.Duration{Duration: 90 * timeutil.Day},
|
||||||
|
QueryLogMemSize: 1000,
|
||||||
FilteringConfig: dnsforward.FilteringConfig{
|
FilteringConfig: dnsforward.FilteringConfig{
|
||||||
ProtectionEnabled: true, // whether or not use any of filtering features
|
ProtectionEnabled: true, // whether or not use any of filtering features
|
||||||
BlockingMode: dnsforward.BlockingModeDefault,
|
BlockingMode: dnsforward.BlockingModeDefault,
|
||||||
|
@ -222,8 +230,14 @@ var config = &configuration{
|
||||||
// was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257
|
// was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257
|
||||||
MaxGoroutines: 300,
|
MaxGoroutines: 300,
|
||||||
},
|
},
|
||||||
FilteringEnabled: true, // whether or not use filter lists
|
DnsfilterConf: &filtering.Config{
|
||||||
|
SafeBrowsingCacheSize: 1 * 1024 * 1024,
|
||||||
|
SafeSearchCacheSize: 1 * 1024 * 1024,
|
||||||
|
ParentalCacheSize: 1 * 1024 * 1024,
|
||||||
|
CacheTime: 30,
|
||||||
|
FilteringEnabled: true,
|
||||||
FiltersUpdateIntervalHours: 24,
|
FiltersUpdateIntervalHours: 24,
|
||||||
|
},
|
||||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||||
UsePrivateRDNS: true,
|
UsePrivateRDNS: true,
|
||||||
},
|
},
|
||||||
|
@ -232,8 +246,26 @@ var config = &configuration{
|
||||||
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
|
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
|
||||||
PortDNSOverQUIC: defaultPortQUIC,
|
PortDNSOverQUIC: defaultPortQUIC,
|
||||||
},
|
},
|
||||||
|
Filters: []filtering.FilterYAML{{
|
||||||
|
Filter: filtering.Filter{ID: 1},
|
||||||
|
Enabled: true,
|
||||||
|
URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt",
|
||||||
|
Name: "AdGuard DNS filter",
|
||||||
|
}, {
|
||||||
|
Filter: filtering.Filter{ID: 2},
|
||||||
|
Enabled: false,
|
||||||
|
URL: "https://adaway.org/hosts.txt",
|
||||||
|
Name: "AdAway Default Blocklist",
|
||||||
|
}},
|
||||||
DHCP: &dhcpd.ServerConfig{
|
DHCP: &dhcpd.ServerConfig{
|
||||||
LocalDomainName: "lan",
|
LocalDomainName: "lan",
|
||||||
|
Conf4: dhcpd.V4ServerConf{
|
||||||
|
LeaseDuration: dhcpd.DefaultDHCPLeaseTTL,
|
||||||
|
ICMPTimeout: dhcpd.DefaultDHCPTimeoutICMP,
|
||||||
|
},
|
||||||
|
Conf6: dhcpd.V6ServerConf{
|
||||||
|
LeaseDuration: dhcpd.DefaultDHCPLeaseTTL,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Clients: &clientsConfig{
|
Clients: &clientsConfig{
|
||||||
Sources: &clientSourcesConf{
|
Sources: &clientSourcesConf{
|
||||||
|
@ -255,31 +287,6 @@ var config = &configuration{
|
||||||
SchemaVersion: currentSchemaVersion,
|
SchemaVersion: currentSchemaVersion,
|
||||||
}
|
}
|
||||||
|
|
||||||
// initConfig initializes default configuration for the current OS&ARCH
|
|
||||||
func initConfig() {
|
|
||||||
config.WebSessionTTLHours = 30 * 24
|
|
||||||
|
|
||||||
config.DNS.QueryLogEnabled = true
|
|
||||||
config.DNS.QueryLogFileEnabled = true
|
|
||||||
config.DNS.QueryLogInterval = timeutil.Duration{Duration: 90 * timeutil.Day}
|
|
||||||
config.DNS.QueryLogMemSize = 1000
|
|
||||||
|
|
||||||
config.DNS.CacheSize = 4 * 1024 * 1024
|
|
||||||
config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024
|
|
||||||
config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024
|
|
||||||
config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024
|
|
||||||
config.DNS.DnsfilterConf.CacheTime = 30
|
|
||||||
config.Filters = defaultFilters()
|
|
||||||
|
|
||||||
config.DHCP.Conf4.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
|
|
||||||
config.DHCP.Conf4.ICMPTimeout = dhcpd.DefaultDHCPTimeoutICMP
|
|
||||||
config.DHCP.Conf6.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
|
|
||||||
|
|
||||||
if ch := version.Channel(); ch == version.ChannelEdge || ch == version.ChannelDevelopment {
|
|
||||||
config.BetaBindPort = 3001
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getConfigFilename returns path to the current config file
|
// getConfigFilename returns path to the current config file
|
||||||
func (c *configuration) getConfigFilename() string {
|
func (c *configuration) getConfigFilename() string {
|
||||||
configFile, err := filepath.EvalSymlinks(Context.configFilename)
|
configFile, err := filepath.EvalSymlinks(Context.configFilename)
|
||||||
|
@ -348,8 +355,8 @@ func parseConfig() (err error) {
|
||||||
return fmt.Errorf("validating udp ports: %w", err)
|
return fmt.Errorf("validating udp ports: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
if !filtering.ValidateUpdateIvl(config.DNS.DnsfilterConf.FiltersUpdateIntervalHours) {
|
||||||
config.DNS.FiltersUpdateIntervalHours = 24
|
config.DNS.DnsfilterConf.FiltersUpdateIntervalHours = 24
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.DNS.UpstreamTimeout.Duration == 0 {
|
if config.DNS.UpstreamTimeout.Duration == 0 {
|
||||||
|
@ -418,10 +425,11 @@ func (c *configuration) write() (err error) {
|
||||||
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
|
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.dnsFilter != nil {
|
if Context.filters != nil {
|
||||||
c := filtering.Config{}
|
Context.filters.WriteDiskConfig(config.DNS.DnsfilterConf)
|
||||||
Context.dnsFilter.WriteDiskConfig(&c)
|
config.Filters = config.DNS.DnsfilterConf.Filters
|
||||||
config.DNS.DnsfilterConf = c
|
config.WhitelistFilters = config.DNS.DnsfilterConf.WhitelistFilters
|
||||||
|
config.UserRules = config.DNS.DnsfilterConf.UserRules
|
||||||
}
|
}
|
||||||
|
|
||||||
if s := Context.dnsServer; s != nil {
|
if s := Context.dnsServer; s != nil {
|
||||||
|
|
|
@ -291,7 +291,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
httpsURL := &url.URL{
|
httpsURL := &url.URL{
|
||||||
Scheme: schemeHTTPS,
|
Scheme: aghhttp.SchemeHTTPS,
|
||||||
Host: hostPort,
|
Host: hostPort,
|
||||||
Path: r.URL.Path,
|
Path: r.URL.Path,
|
||||||
RawQuery: r.URL.RawQuery,
|
RawQuery: r.URL.RawQuery,
|
||||||
|
@ -307,7 +307,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||||
//
|
//
|
||||||
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin.
|
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin.
|
||||||
originURL := &url.URL{
|
originURL := &url.URL{
|
||||||
Scheme: schemeHTTP,
|
Scheme: aghhttp.SchemeHTTP,
|
||||||
Host: r.Host,
|
Host: r.Host,
|
||||||
}
|
}
|
||||||
w.Header().Set("Access-Control-Allow-Origin", originURL.String())
|
w.Header().Set("Access-Control-Allow-Origin", originURL.String())
|
||||||
|
|
|
@ -31,7 +31,10 @@ const (
|
||||||
|
|
||||||
// Called by other modules when configuration is changed
|
// Called by other modules when configuration is changed
|
||||||
func onConfigModified() {
|
func onConfigModified() {
|
||||||
_ = config.write()
|
err := config.write()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("writing config: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// initDNSServer creates an instance of the dnsforward.Server
|
// initDNSServer creates an instance of the dnsforward.Server
|
||||||
|
@ -71,11 +74,11 @@ func initDNSServer() (err error) {
|
||||||
}
|
}
|
||||||
Context.queryLog = querylog.New(conf)
|
Context.queryLog = querylog.New(conf)
|
||||||
|
|
||||||
filterConf := config.DNS.DnsfilterConf
|
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
|
||||||
filterConf.EtcHosts = Context.etcHosts
|
if err != nil {
|
||||||
filterConf.ConfigModified = onConfigModified
|
// Don't wrap the error, since it's informative enough as is.
|
||||||
filterConf.HTTPRegister = httpRegister
|
return err
|
||||||
Context.dnsFilter = filtering.New(&filterConf, nil)
|
}
|
||||||
|
|
||||||
var privateNets netutil.SubnetSet
|
var privateNets netutil.SubnetSet
|
||||||
switch len(config.DNS.PrivateNets) {
|
switch len(config.DNS.PrivateNets) {
|
||||||
|
@ -83,13 +86,10 @@ func initDNSServer() (err error) {
|
||||||
// Use an optimized locally-served matcher.
|
// Use an optimized locally-served matcher.
|
||||||
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||||
case 1:
|
case 1:
|
||||||
var n *net.IPNet
|
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
||||||
n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
privateNets = n
|
|
||||||
default:
|
default:
|
||||||
var nets []*net.IPNet
|
var nets []*net.IPNet
|
||||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||||
|
@ -101,15 +101,13 @@ func initDNSServer() (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
p := dnsforward.DNSCreateParams{
|
p := dnsforward.DNSCreateParams{
|
||||||
DNSFilter: Context.dnsFilter,
|
DNSFilter: Context.filters,
|
||||||
Stats: Context.stats,
|
Stats: Context.stats,
|
||||||
QueryLog: Context.queryLog,
|
QueryLog: Context.queryLog,
|
||||||
PrivateNets: privateNets,
|
PrivateNets: privateNets,
|
||||||
Anonymizer: anonymizer,
|
Anonymizer: anonymizer,
|
||||||
LocalDomain: config.DHCP.LocalDomainName,
|
LocalDomain: config.DHCP.LocalDomainName,
|
||||||
}
|
DHCPServer: Context.dhcpServer,
|
||||||
if Context.dhcpServer != nil {
|
|
||||||
p.DHCPServer = Context.dhcpServer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.dnsServer, err = dnsforward.NewServer(p)
|
Context.dnsServer, err = dnsforward.NewServer(p)
|
||||||
|
@ -143,7 +141,6 @@ func initDNSServer() (err error) {
|
||||||
Context.whois = initWHOIS(&Context.clients)
|
Context.whois = initWHOIS(&Context.clients)
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.filters.Init()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -335,9 +332,12 @@ func getDNSEncryption() (de dnsEncryption) {
|
||||||
// applyAdditionalFiltering adds additional client information and settings if
|
// applyAdditionalFiltering adds additional client information and settings if
|
||||||
// the client has them.
|
// the client has them.
|
||||||
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
|
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
|
||||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
// pref is a prefix for logging messages around the scope.
|
||||||
|
const pref = "applying filters"
|
||||||
|
|
||||||
log.Debug("looking up settings for client with ip %s and clientid %q", clientIP, clientID)
|
Context.filters.ApplyBlockedServices(setts, nil)
|
||||||
|
|
||||||
|
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
|
||||||
|
|
||||||
if clientIP == nil {
|
if clientIP == nil {
|
||||||
return
|
return
|
||||||
|
@ -349,16 +349,16 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
|
||||||
if !ok {
|
if !ok {
|
||||||
c, ok = Context.clients.Find(clientIP.String())
|
c, ok = Context.clients.Find(clientIP.String())
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Debug("client with ip %s and clientid %q not found", clientIP, clientID)
|
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("using settings for client %q with ip %s and clientid %q", c.Name, clientIP, clientID)
|
log.Debug("%s: using settings for client %q (%s; %q)", pref, c.Name, clientIP, clientID)
|
||||||
|
|
||||||
if c.UseOwnBlockedServices {
|
if c.UseOwnBlockedServices {
|
||||||
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
|
Context.filters.ApplyBlockedServices(setts, c.BlockedServices)
|
||||||
}
|
}
|
||||||
|
|
||||||
setts.ClientName = c.Name
|
setts.ClientName = c.Name
|
||||||
|
@ -381,7 +381,7 @@ func startDNSServer() error {
|
||||||
return fmt.Errorf("unable to start forwarding DNS server: Already running")
|
return fmt.Errorf("unable to start forwarding DNS server: Already running")
|
||||||
}
|
}
|
||||||
|
|
||||||
enableFiltersLocked(false)
|
Context.filters.EnableFilters(false)
|
||||||
|
|
||||||
Context.clients.Start()
|
Context.clients.Start()
|
||||||
|
|
||||||
|
@ -390,7 +390,6 @@ func startDNSServer() error {
|
||||||
return fmt.Errorf("couldn't start forwarding DNS server: %w", err)
|
return fmt.Errorf("couldn't start forwarding DNS server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.dnsFilter.Start()
|
|
||||||
Context.filters.Start()
|
Context.filters.Start()
|
||||||
Context.stats.Start()
|
Context.stats.Start()
|
||||||
Context.queryLog.Start()
|
Context.queryLog.Start()
|
||||||
|
@ -449,10 +448,7 @@ func closeDNSServer() {
|
||||||
Context.dnsServer = nil
|
Context.dnsServer = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.dnsFilter != nil {
|
Context.filters.Close()
|
||||||
Context.dnsFilter.Close()
|
|
||||||
Context.dnsFilter = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if Context.stats != nil {
|
if Context.stats != nil {
|
||||||
err := Context.stats.Close()
|
err := Context.stats.Close()
|
||||||
|
@ -469,7 +465,5 @@ func closeDNSServer() {
|
||||||
Context.queryLog = nil
|
Context.queryLog = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.filters.Close()
|
log.Debug("all dns modules are closed")
|
||||||
|
|
||||||
log.Debug("Closed all DNS modules")
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||||
|
@ -33,6 +34,7 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
"gopkg.in/natefinch/lumberjack.v2"
|
"gopkg.in/natefinch/lumberjack.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -52,10 +54,9 @@ type homeContext struct {
|
||||||
dnsServer *dnsforward.Server // DNS module
|
dnsServer *dnsforward.Server // DNS module
|
||||||
rdns *RDNS // rDNS module
|
rdns *RDNS // rDNS module
|
||||||
whois *WHOIS // WHOIS module
|
whois *WHOIS // WHOIS module
|
||||||
dnsFilter *filtering.DNSFilter // DNS filtering module
|
|
||||||
dhcpServer dhcpd.Interface // DHCP module
|
dhcpServer dhcpd.Interface // DHCP module
|
||||||
auth *Auth // HTTP authentication module
|
auth *Auth // HTTP authentication module
|
||||||
filters Filtering // DNS filtering module
|
filters *filtering.DNSFilter // DNS filtering module
|
||||||
web *Web // Web (HTTP, HTTPS) module
|
web *Web // Web (HTTP, HTTPS) module
|
||||||
tls *TLSMod // TLS module
|
tls *TLSMod // TLS module
|
||||||
// etcHosts is an IP-hostname pairs set taken from system configuration
|
// etcHosts is an IP-hostname pairs set taken from system configuration
|
||||||
|
@ -140,7 +141,12 @@ func setupContext(args options) {
|
||||||
checkPermissions()
|
checkPermissions()
|
||||||
}
|
}
|
||||||
|
|
||||||
initConfig()
|
switch version.Channel() {
|
||||||
|
case version.ChannelEdge, version.ChannelDevelopment:
|
||||||
|
config.BetaBindPort = 3001
|
||||||
|
default:
|
||||||
|
// Go on.
|
||||||
|
}
|
||||||
|
|
||||||
Context.tlsRoots = LoadSystemRootCAs()
|
Context.tlsRoots = LoadSystemRootCAs()
|
||||||
Context.transport = &http.Transport{
|
Context.transport = &http.Transport{
|
||||||
|
@ -265,6 +271,14 @@ func setupHostsContainer() (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupConfig(args options) (err error) {
|
func setupConfig(args options) (err error) {
|
||||||
|
config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts
|
||||||
|
config.DNS.DnsfilterConf.ConfigModified = onConfigModified
|
||||||
|
config.DNS.DnsfilterConf.HTTPRegister = httpRegister
|
||||||
|
config.DNS.DnsfilterConf.DataDir = Context.getDataDir()
|
||||||
|
config.DNS.DnsfilterConf.Filters = slices.Clone(config.Filters)
|
||||||
|
config.DNS.DnsfilterConf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
|
||||||
|
config.DNS.DnsfilterConf.HTTPClient = Context.client
|
||||||
|
|
||||||
config.DHCP.WorkDir = Context.workDir
|
config.DHCP.WorkDir = Context.workDir
|
||||||
config.DHCP.HTTPRegister = httpRegister
|
config.DHCP.HTTPRegister = httpRegister
|
||||||
config.DHCP.ConfigModified = onConfigModified
|
config.DHCP.ConfigModified = onConfigModified
|
||||||
|
@ -384,8 +398,6 @@ func fatalOnError(err error) {
|
||||||
|
|
||||||
// run configures and starts AdGuard Home.
|
// run configures and starts AdGuard Home.
|
||||||
func run(args options, clientBuildFS fs.FS) {
|
func run(args options, clientBuildFS fs.FS) {
|
||||||
var err error
|
|
||||||
|
|
||||||
// configure config filename
|
// configure config filename
|
||||||
initConfigFilename(args)
|
initConfigFilename(args)
|
||||||
|
|
||||||
|
@ -404,7 +416,7 @@ func run(args options, clientBuildFS fs.FS) {
|
||||||
|
|
||||||
setupContext(args)
|
setupContext(args)
|
||||||
|
|
||||||
err = configureOS(config)
|
err := configureOS(config)
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
// clients package uses filtering package's static data (filtering.BlockedSvcKnown()),
|
// clients package uses filtering package's static data (filtering.BlockedSvcKnown()),
|
||||||
|
@ -763,12 +775,12 @@ func printHTTPAddresses(proto string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
port := config.BindPort
|
port := config.BindPort
|
||||||
if proto == schemeHTTPS {
|
if proto == aghhttp.SchemeHTTPS {
|
||||||
port = tlsConf.PortHTTPS
|
port = tlsConf.PortHTTPS
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(e.burkov): Inspect and perhaps merge with the previous condition.
|
// TODO(e.burkov): Inspect and perhaps merge with the previous condition.
|
||||||
if proto == schemeHTTPS && tlsConf.ServerName != "" {
|
if proto == aghhttp.SchemeHTTPS && tlsConf.ServerName != "" {
|
||||||
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0)
|
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -82,7 +83,7 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) {
|
||||||
case dnsProtoHTTPS:
|
case dnsProtoHTTPS:
|
||||||
dspName = fmt.Sprintf("%s DoH", d.ServerName)
|
dspName = fmt.Sprintf("%s DoH", d.ServerName)
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Scheme: schemeHTTPS,
|
Scheme: aghhttp.SchemeHTTPS,
|
||||||
Host: d.ServerName,
|
Host: d.ServerName,
|
||||||
Path: path.Join("/dns-query", clientID),
|
Path: path.Join("/dns-query", clientID),
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
@ -277,7 +278,7 @@ AdGuard Home is successfully installed and will automatically start on boot.
|
||||||
There are a few more things that must be configured before you can use it.
|
There are a few more things that must be configured before you can use it.
|
||||||
Click on the link below and follow the Installation Wizard steps to finish setup.
|
Click on the link below and follow the Installation Wizard steps to finish setup.
|
||||||
AdGuard Home is now available at the following addresses:`)
|
AdGuard Home is now available at the following addresses:`)
|
||||||
printHTTPAddresses(schemeHTTP)
|
printHTTPAddresses(aghhttp.SchemeHTTP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/AdguardTeam/golibs/timeutil"
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -160,7 +161,7 @@ func assertEqualExcept(t *testing.T, oldConf, newConf yobj, oldKeys, newKeys []s
|
||||||
}
|
}
|
||||||
|
|
||||||
func testDiskConf(schemaVersion int) (diskConf yobj) {
|
func testDiskConf(schemaVersion int) (diskConf yobj) {
|
||||||
filters := []filter{{
|
filters := []filtering.FilterYAML{{
|
||||||
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
|
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
|
||||||
Name: "Latvian filter",
|
Name: "Latvian filter",
|
||||||
RulesCount: 100,
|
RulesCount: 100,
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
@ -19,12 +20,6 @@ import (
|
||||||
"golang.org/x/net/http2/h2c"
|
"golang.org/x/net/http2/h2c"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTP scheme constants.
|
|
||||||
const (
|
|
||||||
schemeHTTP = "http"
|
|
||||||
schemeHTTPS = "https"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// readTimeout is the maximum duration for reading the entire request,
|
// readTimeout is the maximum duration for reading the entire request,
|
||||||
// including the body.
|
// including the body.
|
||||||
|
@ -166,7 +161,7 @@ func (web *Web) Start() {
|
||||||
|
|
||||||
// this loop is used as an ability to change listening host and/or port
|
// this loop is used as an ability to change listening host and/or port
|
||||||
for !web.httpsServer.shutdown {
|
for !web.httpsServer.shutdown {
|
||||||
printHTTPAddresses(schemeHTTP)
|
printHTTPAddresses(aghhttp.SchemeHTTP)
|
||||||
errs := make(chan error, 2)
|
errs := make(chan error, 2)
|
||||||
|
|
||||||
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
|
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
|
||||||
|
@ -286,7 +281,7 @@ func (web *Web) tlsServerLoop() {
|
||||||
WriteTimeout: web.conf.WriteTimeout,
|
WriteTimeout: web.conf.WriteTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
printHTTPAddresses(schemeHTTPS)
|
printHTTPAddresses(aghhttp.SchemeHTTPS)
|
||||||
err := web.httpsServer.server.ListenAndServeTLS("", "")
|
err := web.httpsServer.server.ListenAndServeTLS("", "")
|
||||||
if err != http.ErrServerClosed {
|
if err != http.ErrServerClosed {
|
||||||
cleanupAlways()
|
cleanupAlways()
|
||||||
|
|
Loading…
Reference in New Issue