Pull request 1790: 5624-fix-filter-add

Merge in DNS/adguard-home from 5624-fix-filter-add to master

Updates #5624.

Squashed commit of the following:

commit 211100409d2c711a5ccb5aeafbe16115388aaff7
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Mar 29 20:46:48 2023 +0500

    filtering: imp names

commit b42ed3748e5d4310a9f8a6a37cee5bf56104917f
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Mar 29 17:41:49 2023 +0500

    filtering: imp logging, lock properly
This commit is contained in:
Eugene Burkov 2023-03-29 19:09:54 +03:00
parent c576d5059e
commit da9008aba3
2 changed files with 104 additions and 63 deletions

View File

@ -176,13 +176,16 @@ func (d *DNSFilter) filterExistsLocked(url string) (ok bool) {
// Add a filter // Add a filter
// Return FALSE if a filter with this URL exists // Return FALSE if a filter with this URL exists
func (d *DNSFilter) filterAdd(flt FilterYAML) bool { func (d *DNSFilter) filterAdd(flt FilterYAML) (err error) {
// Defer annotating to unlock sooner.
defer func() { err = errors.Annotate(err, "adding filter: %w") }()
d.filtersMu.Lock() d.filtersMu.Lock()
defer d.filtersMu.Unlock() defer d.filtersMu.Unlock()
// Check for duplicates // Check for duplicates.
if d.filterExistsLocked(flt.URL) { if d.filterExistsLocked(flt.URL) {
return false return errFilterExists
} }
if flt.white { if flt.white {
@ -190,7 +193,8 @@ func (d *DNSFilter) filterAdd(flt FilterYAML) bool {
} else { } else {
d.Filters = append(d.Filters, flt) d.Filters = append(d.Filters, flt)
} }
return true
return nil
} }
// Load filters from the disk // Load filters from the disk
@ -238,6 +242,7 @@ func updateUniqueFilterID(filters []FilterYAML) {
} }
} }
// TODO(e.burkov): Improve this inexhaustible source of races.
func assignUniqueFilterID() int64 { func assignUniqueFilterID() int64 {
value := nextFilterID value := nextFilterID
nextFilterID++ nextFilterID++
@ -343,29 +348,31 @@ func (d *DNSFilter) refreshFiltersArray(filters *[]FilterYAML, force bool) (int,
} }
updateCount := 0 updateCount := 0
d.filtersMu.Lock()
defer d.filtersMu.Unlock()
for i := range updateFilters { for i := range updateFilters {
uf := &updateFilters[i] uf := &updateFilters[i]
updated := updateFlags[i] updated := updateFlags[i]
d.filtersMu.Lock()
for k := range *filters { for k := range *filters {
f := &(*filters)[k] f := &(*filters)[k]
if f.ID != uf.ID || f.URL != uf.URL { if f.ID != uf.ID || f.URL != uf.URL {
continue continue
} }
f.LastUpdated = uf.LastUpdated f.LastUpdated = uf.LastUpdated
if !updated { if !updated {
continue continue
} }
log.Info("Updated filter #%d. Rules: %d -> %d", log.Info("Updated filter #%d. Rules: %d -> %d", f.ID, f.RulesCount, uf.RulesCount)
f.ID, f.RulesCount, uf.RulesCount)
f.Name = uf.Name f.Name = uf.Name
f.RulesCount = uf.RulesCount f.RulesCount = uf.RulesCount
f.checksum = uf.checksum f.checksum = uf.checksum
updateCount++ updateCount++
} }
d.filtersMu.Unlock()
} }
return updateCount, updateFilters, updateFlags, false return updateCount, updateFilters, updateFlags, false
@ -421,11 +428,16 @@ func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) {
if !updated { if !updated {
continue continue
} }
_ = os.Remove(uf.Path(d.DataDir) + ".old")
p := uf.Path(d.DataDir)
err := os.Remove(p + ".old")
if err != nil {
log.Debug("filtering: removing old filter file %q: %s", p, err)
}
} }
} }
log.Debug("filtering: update finished") log.Debug("filtering: update finished: %d lists updated", updNum)
return updNum, false return updNum, false
} }
@ -467,8 +479,8 @@ func scanLinesWithBreak(data []byte, atEOF bool) (advance int, token []byte, err
} }
// parseFilter copies filter's content from src to dst and returns the number of // parseFilter copies filter's content from src to dst and returns the number of
// rules, name, number of bytes written, checksum, and title of the parsed list. // rules, number of bytes written, checksum, and title of the parsed list. dst
// dst must not be nil. // must not be nil.
func (d *DNSFilter) parseFilter( func (d *DNSFilter) parseFilter(
src io.Reader, src io.Reader,
dst io.Writer, dst io.Writer,
@ -550,14 +562,18 @@ func isHTML(line string) (ok bool) {
return strings.HasPrefix(line, "<html") || strings.HasPrefix(line, "<!doctype") return strings.HasPrefix(line, "<html") || strings.HasPrefix(line, "<!doctype")
} }
// Perform upgrade on a filter and update LastUpdated value // update refreshes filter's content and a/mtimes of its file.
func (d *DNSFilter) update(filter *FilterYAML) (bool, error) { func (d *DNSFilter) update(filter *FilterYAML) (b bool, err error) {
b, err := d.updateIntl(filter) b, err = d.updateIntl(filter)
filter.LastUpdated = time.Now() filter.LastUpdated = time.Now()
if !b { if !b {
e := os.Chtimes(filter.Path(d.DataDir), filter.LastUpdated, filter.LastUpdated) chErr := os.Chtimes(
if e != nil { filter.Path(d.DataDir),
log.Error("os.Chtimes(): %v", e) filter.LastUpdated,
filter.LastUpdated,
)
if chErr != nil {
log.Error("os.Chtimes(): %v", chErr)
} }
} }
@ -591,11 +607,13 @@ func (d *DNSFilter) finalizeUpdate(
return os.Remove(tmpFileName) return os.Remove(tmpFileName)
} }
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path(d.DataDir)) fltPath := flt.Path(d.DataDir)
log.Printf("saving contents of filter #%d into %s", flt.ID, fltPath)
// Don't use renameio or maybe packages, since those will require loading the // Don't use renameio or maybe packages, since those will require loading the
// whole filter content to the memory on Windows. // whole filter content to the memory on Windows.
err = os.Rename(tmpFileName, flt.Path(d.DataDir)) err = os.Rename(tmpFileName, fltPath)
if err != nil { if err != nil {
return errors.WithDeferred(err, os.Remove(tmpFileName)) return errors.WithDeferred(err, os.Remove(tmpFileName))
} }
@ -620,10 +638,14 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
return false, err return false, err
} }
defer func() { defer func() {
err = errors.WithDeferred(err, d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs)) finErr := d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs)
if ok && err == nil { if ok && finErr == nil {
log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum) log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum)
return
} }
err = errors.WithDeferred(err, finErr)
}() }()
// Change the default 0o600 permission to something more acceptable by end // Change the default 0o600 permission to something more acceptable by end
@ -634,7 +656,7 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
return false, fmt.Errorf("changing file mode: %w", err) return false, fmt.Errorf("changing file mode: %w", err)
} }
var rc io.ReadCloser var r io.Reader
if !filepath.IsAbs(flt.URL) { if !filepath.IsAbs(flt.URL) {
var resp *http.Response var resp *http.Response
resp, err = d.HTTPClient.Get(flt.URL) resp, err = d.HTTPClient.Get(flt.URL)
@ -651,16 +673,19 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
return false, fmt.Errorf("got status code %d, want %d", resp.StatusCode, http.StatusOK) return false, fmt.Errorf("got status code %d, want %d", resp.StatusCode, http.StatusOK)
} }
rc = resp.Body r = resp.Body
} else { } else {
rc, err = os.Open(flt.URL) var f *os.File
f, err = os.Open(flt.URL)
if err != nil { if err != nil {
return false, fmt.Errorf("open file: %w", err) return false, fmt.Errorf("open file: %w", err)
} }
defer func() { err = errors.WithDeferred(err, rc.Close()) }() defer func() { err = errors.WithDeferred(err, f.Close()) }()
r = f
} }
rnum, n, cs, name, err = d.parseFilter(rc, tmpFile) rnum, n, cs, name, err = d.parseFilter(r, tmpFile)
return cs != flt.checksum && err == nil, err return cs != flt.checksum && err == nil, err
} }
@ -705,10 +730,11 @@ func (d *DNSFilter) EnableFilters(async bool) {
} }
func (d *DNSFilter) enableFiltersLocked(async bool) { func (d *DNSFilter) enableFiltersLocked(async bool) {
filters := []Filter{{ filters := make([]Filter, 1, len(d.Filters)+len(d.WhitelistFilters)+1)
filters[0] = Filter{
ID: CustomListID, ID: CustomListID,
Data: []byte(strings.Join(d.UserRules, "\n")), Data: []byte(strings.Join(d.UserRules, "\n")),
}} }
for _, filter := range d.Filters { for _, filter := range d.Filters {
if !filter.Enabled { if !filter.Enabled {

View File

@ -14,26 +14,33 @@ import (
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/exp/slices"
) )
// validateFilterURL validates the filter list URL or file name. // validateFilterURL validates the filter list URL or file name.
func validateFilterURL(urlStr string) (err error) { func validateFilterURL(urlStr string) (err error) {
defer func() { err = errors.Annotate(err, "checking filter: %w") }()
if filepath.IsAbs(urlStr) { if filepath.IsAbs(urlStr) {
_, err = os.Stat(urlStr) _, err = os.Stat(urlStr)
if err != nil { if err != nil {
return fmt.Errorf("checking filter file: %w", err) // Don't wrap the error since it's informative enough as is.
return err
} }
return nil return nil
} }
url, err := url.ParseRequestURI(urlStr) u, err := url.ParseRequestURI(urlStr)
if err != nil { if err != nil {
return fmt.Errorf("checking filter url: %w", err) // Don't wrap the error since it's informative enough as is.
} return err
} else if s := u.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS {
if s := url.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS { return &url.Error{
return fmt.Errorf("checking filter url: invalid scheme %q", s) Op: "Check scheme",
URL: urlStr,
Err: fmt.Errorf("only %v allowed", []string{aghhttp.SchemeHTTP, aghhttp.SchemeHTTPS}),
}
} }
return nil return nil
@ -63,7 +70,8 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// Check for duplicates // Check for duplicates
if d.filterExists(fj.URL) { if d.filterExists(fj.URL) {
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL) err = errFilterExists
aghhttp.Error(r, w, http.StatusBadRequest, "Filter with URL %q: %s", fj.URL, err)
return return
} }
@ -99,7 +107,7 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
r, r,
w, w,
http.StatusBadRequest, http.StatusBadRequest,
"Filter at the url %s is invalid (maybe it points to blank page?)", "Filter with URL %q is invalid (maybe it points to blank page?)",
filt.URL, filt.URL,
) )
@ -108,8 +116,9 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// URL is assumed valid so append it to filters, update config, write new // URL is assumed valid so append it to filters, update config, write new
// file and reload it to engines. // file and reload it to engines.
if !d.filterAdd(filt) { err = d.filterAdd(filt)
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL) if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Filter with URL %q: %s", filt.URL, err)
return return
} }
@ -137,31 +146,38 @@ func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
return return
} }
d.filtersMu.Lock()
filters := &d.Filters
if req.Whitelist {
filters = &d.WhitelistFilters
}
var deleted FilterYAML var deleted FilterYAML
var newFilters []FilterYAML func() {
for _, flt := range *filters { d.filtersMu.Lock()
if flt.URL != req.URL { defer d.filtersMu.Unlock()
newFilters = append(newFilters, flt)
continue filters := &d.Filters
if req.Whitelist {
filters = &d.WhitelistFilters
} }
deleted = flt delIdx := slices.IndexFunc(*filters, func(flt FilterYAML) bool {
path := flt.Path(d.DataDir) return flt.URL == req.URL
err = os.Rename(path, path+".old") })
if delIdx == -1 {
log.Error("deleting filter with url %q: %s", req.URL, errFilterNotExist)
return
}
deleted = (*filters)[delIdx]
p := deleted.Path(d.DataDir)
err = os.Rename(p, p+".old")
if err != nil { if err != nil {
log.Error("deleting filter %q: %s", path, err) log.Error("deleting filter %d: renaming file %q: %s", deleted.ID, p, err)
}
}
*filters = newFilters return
d.filtersMu.Unlock() }
*filters = slices.Delete(*filters, delIdx, delIdx+1)
log.Info("deleted filter %d", deleted.ID)
}()
d.ConfigModified() d.ConfigModified()
d.EnableFilters(true) d.EnableFilters(true)
@ -258,10 +274,6 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
type Req struct { type Req struct {
White bool `json:"whitelist"` White bool `json:"whitelist"`
} }
type Resp struct {
Updated int `json:"updated"`
}
resp := Resp{}
var err error var err error
req := Req{} req := Req{}
@ -273,6 +285,9 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
} }
var ok bool var ok bool
resp := struct {
Updated int `json:"updated"`
}{}
resp.Updated, _, ok = d.tryRefreshFilters(!req.White, req.White, true) resp.Updated, _, ok = d.tryRefreshFilters(!req.White, req.White, true)
if !ok { if !ok {
aghhttp.Error( aghhttp.Error(