diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index fd5953e3..977691f9 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -22,6 +22,7 @@ import ( "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/mathutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter/filterlist" @@ -287,12 +288,7 @@ func (r Reason) In(reasons ...Reason) (ok bool) { return slices.Contains(reasons // SetEnabled sets the status of the *DNSFilter. func (d *DNSFilter) SetEnabled(enabled bool) { - var i int32 - if enabled { - i = 1 - } - - atomic.StoreUint32(&d.enabled, uint32(i)) + atomic.StoreUint32(&d.enabled, mathutil.BoolToNumber[uint32](enabled)) } // GetConfig - get configuration diff --git a/internal/filtering/safebrowsing.go b/internal/filtering/safebrowsing.go index c6b7c34c..672ca8a3 100644 --- a/internal/filtering/safebrowsing.go +++ b/internal/filtering/safebrowsing.go @@ -10,6 +10,7 @@ import ( "net/http" "sort" "strings" + "sync" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" @@ -369,13 +370,35 @@ func (d *DNSFilter) checkParental( return check(sctx, res, d.parentalUpstream) } +// setProtectedBool sets the value of a boolean pointer under a lock. l must +// protect the value under ptr. +// +// TODO(e.burkov): Make it generic? +func setProtectedBool(mu *sync.RWMutex, ptr *bool, val bool) { + mu.Lock() + defer mu.Unlock() + + *ptr = val +} + +// protectedBool gets the value of a boolean pointer under a read lock. l must +// protect the value under ptr. +// +// TODO(e.burkov): Make it generic? +func protectedBool(mu *sync.RWMutex, ptr *bool) (val bool) { + mu.RLock() + defer mu.RUnlock() + + return *ptr +} + func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) { - d.Config.SafeBrowsingEnabled = true + setProtectedBool(&d.confLock, &d.Config.SafeBrowsingEnabled, true) d.Config.ConfigModified() } func (d *DNSFilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) { - d.Config.SafeBrowsingEnabled = false + setProtectedBool(&d.confLock, &d.Config.SafeBrowsingEnabled, false) d.Config.ConfigModified() } @@ -383,19 +406,19 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ resp := &struct { Enabled bool `json:"enabled"` }{ - Enabled: d.Config.SafeBrowsingEnabled, + Enabled: protectedBool(&d.confLock, &d.Config.SafeBrowsingEnabled), } _ = aghhttp.WriteJSONResponse(w, r, resp) } func (d *DNSFilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) { - d.Config.ParentalEnabled = true + setProtectedBool(&d.confLock, &d.Config.ParentalEnabled, true) d.Config.ConfigModified() } func (d *DNSFilter) handleParentalDisable(w http.ResponseWriter, r *http.Request) { - d.Config.ParentalEnabled = false + setProtectedBool(&d.confLock, &d.Config.ParentalEnabled, false) d.Config.ConfigModified() } @@ -403,7 +426,7 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) resp := &struct { Enabled bool `json:"enabled"` }{ - Enabled: d.Config.ParentalEnabled, + Enabled: protectedBool(&d.confLock, &d.Config.ParentalEnabled), } _ = aghhttp.WriteJSONResponse(w, r, resp) diff --git a/internal/filtering/safesearch.go b/internal/filtering/safesearch.go index 8b3dcb9b..f7661dd6 100644 --- a/internal/filtering/safesearch.go +++ b/internal/filtering/safesearch.go @@ -135,12 +135,12 @@ func (d *DNSFilter) checkSafeSearch( } func (d *DNSFilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) { - d.Config.SafeSearchEnabled = true + setProtectedBool(&d.confLock, &d.Config.SafeSearchEnabled, true) d.Config.ConfigModified() } func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { - d.Config.SafeSearchEnabled = false + setProtectedBool(&d.confLock, &d.Config.SafeSearchEnabled, false) d.Config.ConfigModified() } @@ -148,7 +148,7 @@ func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Reques resp := &struct { Enabled bool `json:"enabled"` }{ - Enabled: d.Config.SafeSearchEnabled, + Enabled: protectedBool(&d.confLock, &d.Config.SafeSearchEnabled), } _ = aghhttp.WriteJSONResponse(w, r, resp)