diff --git a/home/filter.go b/home/filter.go index 6b6af9f7..494d98e5 100644 --- a/home/filter.go +++ b/home/filter.go @@ -24,12 +24,6 @@ var ( nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID ) -// type FilteringConf struct { -// BlockLists []filter -// AllowLists []filter -// UserRules []string -// } - // Filtering - module object type Filtering struct { // conf FilteringConf @@ -447,8 +441,9 @@ func (f *Filtering) refreshFiltersIfNecessary(flags int) (int, bool) { } // Allows printable UTF-8 text with CR, LF, TAB characters -func isPrintableText(data []byte) bool { - for _, c := range data { +func isPrintableText(data []byte, len int) bool { + for i := 0; i < len; i++ { + c := data[i] if (c >= ' ' && c != 0x7f) || c == '\n' || c == '\r' || c == '\t' { continue } @@ -549,7 +544,7 @@ func (f *Filtering) updateIntl(filter *filter) (bool, error) { firstChunkLen += copied if firstChunkLen == len(firstChunk) || err == io.EOF { - if !isPrintableText(firstChunk) { + if !isPrintableText(firstChunk, firstChunkLen) { return false, fmt.Errorf("data contains non-printable characters") } diff --git a/home/filter_test.go b/home/filter_test.go index 7449f037..e4f10f1f 100644 --- a/home/filter_test.go +++ b/home/filter_test.go @@ -1,6 +1,8 @@ package home import ( + "fmt" + "net" "net/http" "os" "testing" @@ -9,7 +11,28 @@ import ( "github.com/stretchr/testify/assert" ) +func testStartFilterListener() net.Listener { + http.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) { + content := `||example.org^$third-party +||example.com^$third-party +0.0.0.0 example.com +` + _, _ = w.Write([]byte(content)) + }) + + listener, err := net.Listen("tcp", ":0") + if err != nil { + panic(err) + } + + go func() { _ = http.Serve(listener, nil) }() + return listener +} + func TestFilters(t *testing.T) { + l := testStartFilterListener() + defer func() { _ = l.Close() }() + dir := prepareTestDir() defer func() { _ = os.RemoveAll(dir) }() Context = homeContext{} @@ -20,13 +43,14 @@ func TestFilters(t *testing.T) { Context.filters.Init() f := filter{ - URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", + URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port), } // download ok, err := Context.filters.update(&f) assert.Equal(t, nil, err) assert.True(t, ok) + assert.Equal(t, 3, f.RulesCount) // refresh ok, err = Context.filters.update(&f)