diff --git a/internal/filtering/blocked.go b/internal/filtering/blocked.go index 3e0764d2..aec5855e 100644 --- a/internal/filtering/blocked.go +++ b/internal/filtering/blocked.go @@ -8,6 +8,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" + "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" @@ -28,7 +29,7 @@ func initBlockedServices() { for i, s := range blockedServices { netRules := make([]*rules.NetworkRule, 0, len(s.Rules)) for _, text := range s.Rules { - rule, err := rules.NewNetworkRule(text, BlockedSvcsListID) + rule, err := rules.NewNetworkRule(text, rulelist.URLFilterIDBlockedService) if err != nil { log.Error("parsing blocked service %q rule %q: %s", s.ID, text, err) diff --git a/internal/filtering/filter.go b/internal/filtering/filter.go index 65ef1b7a..ce5a7e13 100644 --- a/internal/filtering/filter.go +++ b/internal/filtering/filter.go @@ -608,7 +608,7 @@ func (d *DNSFilter) EnableFilters(async bool) { func (d *DNSFilter) enableFiltersLocked(async bool) { filters := make([]Filter, 1, len(d.conf.Filters)+len(d.conf.WhitelistFilters)+1) filters[0] = Filter{ - ID: CustomListID, + ID: rulelist.URLFilterIDCustom, Data: []byte(strings.Join(d.conf.UserRules, "\n")), } diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index b145ed3b..2203fe85 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -32,19 +32,6 @@ import ( "github.com/miekg/dns" ) -// The IDs of built-in filter lists. -// -// Keep in sync with client/src/helpers/constants.js. -// TODO(d.kolyshev): Add RewritesListID and don't forget to keep in sync. -const ( - CustomListID = -iota - SysHostsListID - BlockedSvcsListID - ParentalListID - SafeBrowsingListID - SafeSearchListID -) - // ServiceEntry - blocked service array element type ServiceEntry struct { Name string @@ -1139,7 +1126,7 @@ func (d *DNSFilter) checkSafeBrowsing( res = Result{ Rules: []*ResultRule{{ Text: "adguard-malware-shavar", - FilterListID: SafeBrowsingListID, + FilterListID: rulelist.URLFilterIDSafeBrowsing, }}, Reason: FilteredSafeBrowsing, IsFiltered: true, @@ -1171,7 +1158,7 @@ func (d *DNSFilter) checkParental( res = Result{ Rules: []*ResultRule{{ Text: "parental CATEGORY_BLACKLISTED", - FilterListID: ParentalListID, + FilterListID: rulelist.URLFilterIDParentalControl, }}, Reason: FilteredParental, IsFiltered: true, diff --git a/internal/filtering/hosts.go b/internal/filtering/hosts.go index 79cb69ac..4943b1af 100644 --- a/internal/filtering/hosts.go +++ b/internal/filtering/hosts.go @@ -4,6 +4,7 @@ import ( "fmt" "net/netip" + "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" @@ -66,7 +67,7 @@ func hostsRewrites( vals = append(vals, name) rls = append(rls, &ResultRule{ Text: fmt.Sprintf("%s %s", addr, name), - FilterListID: SysHostsListID, + FilterListID: rulelist.URLFilterIDEtcHosts, }) } @@ -84,7 +85,7 @@ func hostsRewrites( } rls = append(rls, &ResultRule{ Text: fmt.Sprintf("%s %s", addr, host), - FilterListID: SysHostsListID, + FilterListID: rulelist.URLFilterIDEtcHosts, }) } diff --git a/internal/filtering/hosts_test.go b/internal/filtering/hosts_test.go index 5ea7eff3..e94603a0 100644 --- a/internal/filtering/hosts_test.go +++ b/internal/filtering/hosts_test.go @@ -8,6 +8,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" @@ -71,7 +72,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { dtyp: dns.TypeA, wantRules: []*ResultRule{{ Text: "1.2.3.4 v4.host.example", - FilterListID: SysHostsListID, + FilterListID: rulelist.URLFilterIDEtcHosts, }}, wantResps: []rules.RRValue{addrv4}, }, { @@ -80,7 +81,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { dtyp: dns.TypeAAAA, wantRules: []*ResultRule{{ Text: "::1 v6.host.example", - FilterListID: SysHostsListID, + FilterListID: rulelist.URLFilterIDEtcHosts, }}, wantResps: []rules.RRValue{addrv6}, }, { @@ -89,7 +90,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { dtyp: dns.TypeAAAA, wantRules: []*ResultRule{{ Text: "::ffff:1.2.3.4 mapped.host.example", - FilterListID: SysHostsListID, + FilterListID: rulelist.URLFilterIDEtcHosts, }}, wantResps: []rules.RRValue{addrMapped}, }, { @@ -98,7 +99,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { dtyp: dns.TypePTR, wantRules: []*ResultRule{{ Text: "1.2.3.4 v4.host.example", - FilterListID: SysHostsListID, + FilterListID: rulelist.URLFilterIDEtcHosts, }}, wantResps: []rules.RRValue{"v4.host.example"}, }, { @@ -107,7 +108,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { dtyp: dns.TypePTR, wantRules: []*ResultRule{{ Text: "::ffff:1.2.3.4 mapped.host.example", - FilterListID: SysHostsListID, + FilterListID: rulelist.URLFilterIDEtcHosts, }}, wantResps: []rules.RRValue{"mapped.host.example"}, }, { @@ -134,7 +135,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { dtyp: dns.TypeAAAA, wantRules: []*ResultRule{{ Text: fmt.Sprintf("%s v4.host.example", addrv4), - FilterListID: SysHostsListID, + FilterListID: rulelist.URLFilterIDEtcHosts, }}, wantResps: nil, }, { @@ -143,7 +144,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { dtyp: dns.TypeA, wantRules: []*ResultRule{{ Text: fmt.Sprintf("%s v6.host.example", addrv6), - FilterListID: SysHostsListID, + FilterListID: rulelist.URLFilterIDEtcHosts, }}, wantResps: nil, }, { @@ -164,7 +165,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { dtyp: dns.TypeA, wantRules: []*ResultRule{{ Text: "4.3.2.1 v4.host.with-dup", - FilterListID: SysHostsListID, + FilterListID: rulelist.URLFilterIDEtcHosts, }}, wantResps: []rules.RRValue{addrv4Dup}, }} diff --git a/internal/filtering/rulelist/engine.go b/internal/filtering/rulelist/engine.go new file mode 100644 index 00000000..65e488ce --- /dev/null +++ b/internal/filtering/rulelist/engine.go @@ -0,0 +1,254 @@ +package rulelist + +import ( + "context" + "fmt" + "net/http" + "sync" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/urlfilter" + "github.com/AdguardTeam/urlfilter/filterlist" + "github.com/c2h5oh/datasize" +) + +// Engine is a single DNS filter based on one or more rule lists. This +// structure contains the filtering engine combining several rule lists. +// +// TODO(a.garipov): Merge with [TextEngine] in some way? +type Engine struct { + // mu protects engine and storage. + // + // TODO(a.garipov): See if anything else should be protected. + mu *sync.RWMutex + + // engine is the filtering engine. + engine *urlfilter.DNSEngine + + // storage is the filtering-rule storage. It is saved here to close it. + storage *filterlist.RuleStorage + + // name is the human-readable name of the engine, like "allowed", "blocked", + // or "custom". + name string + + // filters is the data about rule filters in this engine. + filters []*Filter +} + +// EngineConfig is the configuration for rule-list filtering engines created by +// combining refreshable filters. +type EngineConfig struct { + // Name is the human-readable name of this engine, like "allowed", + // "blocked", or "custom". + Name string + + // Filters is the data about rule lists in this engine. There must be no + // other references to the elements of this slice. + Filters []*Filter +} + +// NewEngine returns a new rule-list filtering engine. The engine is not +// refreshed, so a refresh should be performed before use. +func NewEngine(c *EngineConfig) (e *Engine) { + return &Engine{ + mu: &sync.RWMutex{}, + name: c.Name, + filters: c.Filters, + } +} + +// Close closes the underlying rule-list engine as well as the rule lists. +func (e *Engine) Close() (err error) { + e.mu.Lock() + defer e.mu.Unlock() + + if e.storage == nil { + return nil + } + + err = e.storage.Close() + if err != nil { + return fmt.Errorf("closing engine %q: %w", e.name, err) + } + + return nil +} + +// FilterRequest returns the result of filtering req using the DNS filtering +// engine. +func (e *Engine) FilterRequest( + req *urlfilter.DNSRequest, +) (res *urlfilter.DNSResult, hasMatched bool) { + return e.currentEngine().MatchRequest(req) +} + +// currentEngine returns the current filtering engine. +func (e *Engine) currentEngine() (enging *urlfilter.DNSEngine) { + e.mu.RLock() + defer e.mu.RUnlock() + + return e.engine +} + +// Refresh updates all rule lists in e. ctx is used for cancellation. +// parseBuf, cli, cacheDir, and maxSize are used for updates of rule-list +// filters; see [Filter.Refresh]. +// +// TODO(a.garipov): Unexport and test in an internal test or through enigne +// tests. +func (e *Engine) Refresh( + ctx context.Context, + parseBuf []byte, + cli *http.Client, + cacheDir string, + maxSize datasize.ByteSize, +) (err error) { + defer func() { err = errors.Annotate(err, "updating engine %q: %w", e.name) }() + + var filtersToRefresh []*Filter + for _, f := range e.filters { + if f.enabled { + filtersToRefresh = append(filtersToRefresh, f) + } + } + + if len(filtersToRefresh) == 0 { + log.Info("filtering: updating engine %q: no rule-list filters", e.name) + + return nil + } + + engRefr := &engineRefresh{ + httpCli: cli, + cacheDir: cacheDir, + engineName: e.name, + parseBuf: parseBuf, + maxSize: maxSize, + } + + ruleLists, errs := engRefr.process(ctx, e.filters) + if isOneTimeoutError(errs) { + // Don't wrap the error since it's informative enough as is. + return err + } + + storage, err := filterlist.NewRuleStorage(ruleLists) + if err != nil { + errs = append(errs, fmt.Errorf("creating rule storage: %w", err)) + + return errors.Join(errs...) + } + + e.resetStorage(storage) + + return errors.Join(errs...) +} + +// resetStorage sets e.storage and e.engine and closes the previous storage. +// Errors from closing the previous storage are logged. +func (e *Engine) resetStorage(storage *filterlist.RuleStorage) { + e.mu.Lock() + defer e.mu.Unlock() + + prevStorage := e.storage + e.storage, e.engine = storage, urlfilter.NewDNSEngine(storage) + + if prevStorage == nil { + return + } + + err := prevStorage.Close() + if err != nil { + log.Error("filtering: engine %q: closing old storage: %s", e.name, err) + } +} + +// isOneTimeoutError returns true if the sole error in errs is either +// [context.Canceled] or [context.DeadlineExceeded]. +func isOneTimeoutError(errs []error) (ok bool) { + if len(errs) != 1 { + return false + } + + err := errs[0] + + return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) +} + +// engineRefresh represents a single ongoing engine refresh. +type engineRefresh struct { + httpCli *http.Client + cacheDir string + engineName string + parseBuf []byte + maxSize datasize.ByteSize +} + +// process runs updates of all given rule-list filters. All errors are logged +// as they appear, since the update can take a significant amount of time. +// errs contains all errors that happened during the update, unless the context +// is canceled or its deadline is reached, in which case errs will only contain +// a single timeout error. +// +// TODO(a.garipov): Think of a better way to communicate the timeout condition? +func (r *engineRefresh) process( + ctx context.Context, + filters []*Filter, +) (ruleLists []filterlist.RuleList, errs []error) { + ruleLists = make([]filterlist.RuleList, 0, len(filters)) + for i, f := range filters { + select { + case <-ctx.Done(): + return nil, []error{fmt.Errorf("timeout after updating %d filters: %w", i, ctx.Err())} + default: + // Go on. + } + + err := r.processFilter(ctx, f) + if err == nil { + ruleLists = append(ruleLists, f.ruleList) + + continue + } + + errs = append(errs, err) + + // Also log immediately, since the update can take a lot of time. + log.Error( + "filtering: updating engine %q: rule list %s from url %q: %s\n", + r.engineName, + f.uid, + f.url, + err, + ) + } + + return ruleLists, errs +} + +// processFilter runs an update of a single rule-list filter. +func (r *engineRefresh) processFilter(ctx context.Context, f *Filter) (err error) { + prevChecksum := f.checksum + parseRes, err := f.Refresh(ctx, r.parseBuf, r.httpCli, r.cacheDir, r.maxSize) + if err != nil { + return fmt.Errorf("updating %s: %w", f.uid, err) + } + + if prevChecksum == parseRes.Checksum { + log.Info("filtering: engine %q: filter %q: no change", r.engineName, f.uid) + + return nil + } + + log.Info( + "filtering: updated engine %q: filter %q: %d bytes, %d rules", + r.engineName, + f.uid, + parseRes.BytesWritten, + parseRes.RulesCount, + ) + + return nil +} diff --git a/internal/filtering/rulelist/engine_test.go b/internal/filtering/rulelist/engine_test.go new file mode 100644 index 00000000..81ab8bf8 --- /dev/null +++ b/internal/filtering/rulelist/engine_test.go @@ -0,0 +1,63 @@ +package rulelist_test + +import ( + "context" + "net/http" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" + "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/urlfilter" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEngine_Refresh(t *testing.T) { + cacheDir := t.TempDir() + + fileURL, srvURL := newFilterLocations(t, cacheDir, testRuleTextBlocked, testRuleTextBlocked2) + + fileFlt := newFilter(t, fileURL, "File Filter") + httpFlt := newFilter(t, srvURL, "HTTP Filter") + + eng := rulelist.NewEngine(&rulelist.EngineConfig{ + Name: "Engine", + Filters: []*rulelist.Filter{fileFlt, httpFlt}, + }) + require.NotNil(t, eng) + testutil.CleanupAndRequireSuccess(t, eng.Close) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + buf := make([]byte, rulelist.DefaultRuleBufSize) + cli := &http.Client{ + Timeout: testTimeout, + } + + err := eng.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize) + require.NoError(t, err) + + fltReq := &urlfilter.DNSRequest{ + Hostname: "blocked.example", + Answer: false, + DNSType: dns.TypeA, + } + + fltRes, hasMatched := eng.FilterRequest(fltReq) + assert.True(t, hasMatched) + + require.NotNil(t, fltRes) + + fltReq = &urlfilter.DNSRequest{ + Hostname: "blocked-2.example", + Answer: false, + DNSType: dns.TypeA, + } + + fltRes, hasMatched = eng.FilterRequest(fltReq) + assert.True(t, hasMatched) + + require.NotNil(t, fltRes) +} diff --git a/internal/filtering/rulelist/filter.go b/internal/filtering/rulelist/filter.go index 278eef5c..5f3fa6be 100644 --- a/internal/filtering/rulelist/filter.go +++ b/internal/filtering/rulelist/filter.go @@ -14,7 +14,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghrenameio" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/ioutil" - "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/filterlist" "github.com/c2h5oh/datasize" ) @@ -52,8 +51,6 @@ type Filter struct { checksum uint32 // enabled, if true, means that this rule-list filter is used for filtering. - // - // TODO(a.garipov): Take into account. enabled bool } @@ -106,6 +103,11 @@ func NewFilter(c *FilterConfig) (f *Filter, err error) { // Refresh updates the data in the rule-list filter. parseBuf is the initial // buffer used to parse information from the data. cli and maxSize are only // used when f is a URL-based list. +// +// TODO(a.garipov): Unexport and test in an internal test or through enigne +// tests. +// +// TODO(a.garipov): Consider not returning parseRes. func (f *Filter) Refresh( ctx context.Context, parseBuf []byte, @@ -300,39 +302,3 @@ func (f *Filter) Close() (err error) { return f.ruleList.Close() } - -// filterUpdate represents a single ongoing rule-list filter update. -// -//lint:ignore U1000 TODO(a.garipov): Use. -type filterUpdate struct { - httpCli *http.Client - cacheDir string - name string - parseBuf []byte - maxSize datasize.ByteSize -} - -// process runs an update of a single rule-list. -func (u *filterUpdate) process(ctx context.Context, f *Filter) (err error) { - prevChecksum := f.checksum - parseRes, err := f.Refresh(ctx, u.parseBuf, u.httpCli, u.cacheDir, u.maxSize) - if err != nil { - return fmt.Errorf("updating %s: %w", f.uid, err) - } - - if prevChecksum == parseRes.Checksum { - log.Info("filtering: filter %q: filter %q: no change", u.name, f.uid) - - return nil - } - - log.Info( - "filtering: updated filter %q: filter %q: %d bytes, %d rules", - u.name, - f.uid, - parseRes.BytesWritten, - parseRes.RulesCount, - ) - - return nil -} diff --git a/internal/filtering/rulelist/filter_test.go b/internal/filtering/rulelist/filter_test.go index 93cd6e9c..05c1274c 100644 --- a/internal/filtering/rulelist/filter_test.go +++ b/internal/filtering/rulelist/filter_test.go @@ -2,9 +2,7 @@ package rulelist_test import ( "context" - "io" "net/http" - "net/http/httptest" "net/url" "os" "path/filepath" @@ -20,23 +18,8 @@ func TestFilter_Refresh(t *testing.T) { cacheDir := t.TempDir() uid := rulelist.MustNewUID() - initialFile := filepath.Join(cacheDir, "initial.txt") - initialData := []byte( - testRuleTextTitle + - testRuleTextBlocked, - ) - writeErr := os.WriteFile(initialFile, initialData, 0o644) - require.NoError(t, writeErr) - - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - pt := testutil.PanicT{} - - _, err := io.WriteString(w, testRuleTextTitle+testRuleTextBlocked) - require.NoError(pt, err) - })) - - srvURL, urlErr := url.Parse(srv.URL) - require.NoError(t, urlErr) + const fltData = testRuleTextTitle + testRuleTextBlocked + fileURL, srvURL := newFilterLocations(t, cacheDir, fltData, fltData) testCases := []struct { url *url.URL @@ -56,7 +39,7 @@ func TestFilter_Refresh(t *testing.T) { name: "file", url: &url.URL{ Scheme: "file", - Path: initialFile, + Path: fileURL.Path, }, wantNewErrMsg: "", }, { diff --git a/internal/filtering/rulelist/rulelist.go b/internal/filtering/rulelist/rulelist.go index e0fd61b4..b9505e91 100644 --- a/internal/filtering/rulelist/rulelist.go +++ b/internal/filtering/rulelist/rulelist.go @@ -23,8 +23,28 @@ const DefaultMaxRuleListSize = 64 * datasize.MB // URLFilterID is a semantic type-alias for IDs used for working with package // urlfilter. +// +// TODO(a.garipov): Use everywhere in package filtering. type URLFilterID = int +// The IDs of built-in filter lists. +// +// NOTE: Do not change without the need for it and keep in sync with +// client/src/helpers/constants.js. +// +// TODO(a.garipov): Add type [URLFilterID] once it is used consistently in +// package filtering. +// +// TODO(d.kolyshev): Add URLFilterIDLegacyRewrite here and to the UI. +const ( + URLFilterIDCustom = 0 + URLFilterIDEtcHosts = -1 + URLFilterIDBlockedService = -2 + URLFilterIDParentalControl = -3 + URLFilterIDSafeBrowsing = -4 + URLFilterIDSafeSearch = -5 +) + // UID is the type for the unique IDs of filtering-rule lists. type UID uuid.UUID diff --git a/internal/filtering/rulelist/rulelist_test.go b/internal/filtering/rulelist/rulelist_test.go index dc79d503..78731f33 100644 --- a/internal/filtering/rulelist/rulelist_test.go +++ b/internal/filtering/rulelist/rulelist_test.go @@ -1,11 +1,19 @@ package rulelist_test import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "sync/atomic" "testing" "time" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { @@ -35,3 +43,70 @@ const ( // See https://github.com/AdguardTeam/AdGuardHome/issues/6003. testRuleTextCosmetic = "||cosmetic.example## :has-text(/\u200c/i)\n" ) + +// urlFilterIDCounter is the atomic integer used to create unique filter IDs. +var urlFilterIDCounter = &atomic.Int32{} + +// newURLFilterID returns a new unique URLFilterID. +func newURLFilterID() (id rulelist.URLFilterID) { + return rulelist.URLFilterID(urlFilterIDCounter.Add(1)) +} + +// newFilter is a helper for creating new filters in tests. It does not +// register the closing of the filter using t.Cleanup; callers must do that +// either directly or by using the filter in an engine. +func newFilter(t testing.TB, u *url.URL, name string) (f *rulelist.Filter) { + t.Helper() + + f, err := rulelist.NewFilter(&rulelist.FilterConfig{ + URL: u, + Name: name, + UID: rulelist.MustNewUID(), + URLFilterID: newURLFilterID(), + Enabled: true, + }) + require.NoError(t, err) + + return f +} + +// newFilterLocations is a test helper that sets up both the filtering-rule list +// file and the HTTP-server. It also registers file removal and server stopping +// using t.Cleanup. +func newFilterLocations( + t testing.TB, + cacheDir string, + fileData string, + httpData string, +) (fileURL, srvURL *url.URL) { + filePath := filepath.Join(cacheDir, "initial.txt") + err := os.WriteFile(filePath, []byte(fileData), 0o644) + require.NoError(t, err) + + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return os.Remove(filePath) + }) + + fileURL = &url.URL{ + Scheme: "file", + Path: filePath, + } + + srv := newStringHTTPServer(httpData) + t.Cleanup(srv.Close) + + srvURL, err = url.Parse(srv.URL) + require.NoError(t, err) + + return fileURL, srvURL +} + +// newStringHTTPServer returns a new HTTP server that serves s. +func newStringHTTPServer(s string) (srv *httptest.Server) { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + pt := testutil.PanicT{} + + _, err := io.WriteString(w, s) + require.NoError(pt, err) + })) +} diff --git a/internal/filtering/rulelist/textengine.go b/internal/filtering/rulelist/textengine.go new file mode 100644 index 00000000..4b5e8ce8 --- /dev/null +++ b/internal/filtering/rulelist/textengine.go @@ -0,0 +1,98 @@ +package rulelist + +import ( + "fmt" + "strings" + "sync" + + "github.com/AdguardTeam/urlfilter" + "github.com/AdguardTeam/urlfilter/filterlist" +) + +// TextEngine is a single DNS filter based on a list of rules in text form. +type TextEngine struct { + // mu protects engine and storage. + mu *sync.RWMutex + + // engine is the filtering engine. + engine *urlfilter.DNSEngine + + // storage is the filtering-rule storage. It is saved here to close it. + storage *filterlist.RuleStorage + + // name is the human-readable name of the engine, like "custom". + name string +} + +// TextEngineConfig is the configuration for a rule-list filtering engine +// created from a filtering rule text. +type TextEngineConfig struct { + // Name is the human-readable name of this engine, like "allowed", + // "blocked", or "custom". + Name string + + // Rules is the text of the filtering rules for this engine. + Rules []string + + // ID is the ID to use inside a URL-filter engine. + ID URLFilterID +} + +// NewTextEngine returns a new rule-list filtering engine that uses rules +// directly. The engine is ready to use and should not be refreshed. +func NewTextEngine(c *TextEngineConfig) (e *TextEngine, err error) { + text := strings.Join(c.Rules, "\n") + storage, err := filterlist.NewRuleStorage([]filterlist.RuleList{ + &filterlist.StringRuleList{ + RulesText: text, + ID: c.ID, + IgnoreCosmetic: true, + }, + }) + if err != nil { + return nil, fmt.Errorf("creating rule storage: %w", err) + } + + engine := urlfilter.NewDNSEngine(storage) + + return &TextEngine{ + mu: &sync.RWMutex{}, + engine: engine, + storage: storage, + name: c.Name, + }, nil +} + +// FilterRequest returns the result of filtering req using the DNS filtering +// engine. +func (e *TextEngine) FilterRequest( + req *urlfilter.DNSRequest, +) (res *urlfilter.DNSResult, hasMatched bool) { + var engine *urlfilter.DNSEngine + + func() { + e.mu.RLock() + defer e.mu.RUnlock() + + engine = e.engine + }() + + return engine.MatchRequest(req) +} + +// Close closes the underlying rule list engine as well as the rule lists. +func (e *TextEngine) Close() (err error) { + e.mu.Lock() + defer e.mu.Unlock() + + if e.storage == nil { + return nil + } + + err = e.storage.Close() + if err != nil { + return fmt.Errorf("closing text engine %q: %w", e.name, err) + } + + return nil +} diff --git a/internal/filtering/rulelist/textengine_test.go b/internal/filtering/rulelist/textengine_test.go new file mode 100644 index 00000000..129d01c7 --- /dev/null +++ b/internal/filtering/rulelist/textengine_test.go @@ -0,0 +1,40 @@ +package rulelist_test + +import ( + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" + "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/urlfilter" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewTextEngine(t *testing.T) { + eng, err := rulelist.NewTextEngine(&rulelist.TextEngineConfig{ + Name: "RulesEngine", + Rules: []string{ + testRuleTextTitle, + testRuleTextBlocked, + }, + ID: testURLFilterID, + }) + require.NoError(t, err) + require.NotNil(t, eng) + testutil.CleanupAndRequireSuccess(t, eng.Close) + + fltReq := &urlfilter.DNSRequest{ + Hostname: "blocked.example", + Answer: false, + DNSType: dns.TypeA, + } + + fltRes, hasMatched := eng.FilterRequest(fltReq) + assert.True(t, hasMatched) + + require.NotNil(t, fltRes) + require.NotNil(t, fltRes.NetworkRule) + + assert.Equal(t, fltRes.NetworkRule.FilterListID, testURLFilterID) +} diff --git a/internal/filtering/safesearch/safesearch.go b/internal/filtering/safesearch/safesearch.go index 47e66ac6..50c0d187 100644 --- a/internal/filtering/safesearch/safesearch.go +++ b/internal/filtering/safesearch/safesearch.go @@ -14,6 +14,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter" @@ -98,7 +99,7 @@ func NewDefault( cacheTTL: cacheTTL, } - err = ss.resetEngine(filtering.SafeSearchListID, conf) + err = ss.resetEngine(rulelist.URLFilterIDSafeSearch, conf) if err != nil { // Don't wrap the error, because it's informative enough as is. return nil, err @@ -234,7 +235,7 @@ func (ss *Default) newResult( ) (res *filtering.Result, err error) { res = &filtering.Result{ Rules: []*filtering.ResultRule{{ - FilterListID: filtering.SafeSearchListID, + FilterListID: rulelist.URLFilterIDSafeSearch, }}, Reason: filtering.FilteredSafeSearch, IsFiltered: true, @@ -368,7 +369,7 @@ func (ss *Default) Update(conf filtering.SafeSearchConfig) (err error) { ss.mu.Lock() defer ss.mu.Unlock() - err = ss.resetEngine(filtering.SafeSearchListID, conf) + err = ss.resetEngine(rulelist.URLFilterIDSafeSearch, conf) if err != nil { // Don't wrap the error, because it's informative enough as is. return err diff --git a/internal/filtering/safesearch/safesearch_test.go b/internal/filtering/safesearch/safesearch_test.go index 16b720d1..1ad4bf11 100644 --- a/internal/filtering/safesearch/safesearch_test.go +++ b/internal/filtering/safesearch/safesearch_test.go @@ -9,6 +9,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" @@ -69,7 +70,7 @@ func TestDefault_CheckHost_yandex(t *testing.T) { require.Len(t, res.Rules, 1) assert.Equal(t, yandexIP, res.Rules[0].IP) - assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID) + assert.EqualValues(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID) } } @@ -89,7 +90,7 @@ func TestDefault_CheckHost_yandexAAAA(t *testing.T) { require.Len(t, res.Rules, 1) assert.Empty(t, res.Rules[0].IP) - assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID) + assert.EqualValues(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID) } func TestDefault_CheckHost_google(t *testing.T) { @@ -128,7 +129,7 @@ func TestDefault_CheckHost_google(t *testing.T) { require.Len(t, res.Rules, 1) assert.Equal(t, wantIP, res.Rules[0].IP) - assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID) + assert.EqualValues(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID) }) } } @@ -180,7 +181,7 @@ func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) { require.Len(t, res.Rules, 1) assert.Empty(t, res.Rules[0].IP) - assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID) + assert.EqualValues(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID) } func TestDefault_Update(t *testing.T) {