Fix race conditions found by go's race detector
This commit is contained in:
parent
2c33905a79
commit
2244c21b76
|
@ -27,7 +27,7 @@ type configuration struct {
|
||||||
Filters []filter `yaml:"filters"`
|
Filters []filter `yaml:"filters"`
|
||||||
UserRules []string `yaml:"user_rules"`
|
UserRules []string `yaml:"user_rules"`
|
||||||
|
|
||||||
sync.Mutex `yaml:"-"`
|
sync.RWMutex `yaml:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type coreDNSConfig struct {
|
type coreDNSConfig struct {
|
||||||
|
|
|
@ -789,10 +789,11 @@ func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
"enabled": config.CoreDNS.FilteringEnabled,
|
"enabled": config.CoreDNS.FilteringEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config.RLock()
|
||||||
data["filters"] = config.Filters
|
data["filters"] = config.Filters
|
||||||
data["user_rules"] = config.UserRules
|
data["user_rules"] = config.UserRules
|
||||||
|
|
||||||
json, err := json.Marshal(data)
|
json, err := json.Marshal(data)
|
||||||
|
config.RUnlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
|
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
|
||||||
|
@ -1122,7 +1123,6 @@ func runFilterRefreshers() {
|
||||||
func refreshFiltersIfNeccessary() int {
|
func refreshFiltersIfNeccessary() int {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
config.Lock()
|
config.Lock()
|
||||||
defer config.Unlock()
|
|
||||||
|
|
||||||
// deduplicate
|
// deduplicate
|
||||||
// TODO: move it somewhere else
|
// TODO: move it somewhere else
|
||||||
|
@ -1154,6 +1154,7 @@ func refreshFiltersIfNeccessary() int {
|
||||||
updateCount++
|
updateCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
config.Unlock()
|
||||||
|
|
||||||
if updateCount > 0 {
|
if updateCount > 0 {
|
||||||
err := writeFilterFile()
|
err := writeFilterFile()
|
||||||
|
@ -1237,6 +1238,7 @@ func writeFilterFile() error {
|
||||||
log.Printf("Writing filter file: %s", filterpath)
|
log.Printf("Writing filter file: %s", filterpath)
|
||||||
// TODO: check if file contents have modified
|
// TODO: check if file contents have modified
|
||||||
data := []byte{}
|
data := []byte{}
|
||||||
|
config.RLock()
|
||||||
filters := config.Filters
|
filters := config.Filters
|
||||||
for _, filter := range filters {
|
for _, filter := range filters {
|
||||||
if !filter.Enabled {
|
if !filter.Enabled {
|
||||||
|
@ -1249,6 +1251,7 @@ func writeFilterFile() error {
|
||||||
data = append(data, []byte(rule)...)
|
data = append(data, []byte(rule)...)
|
||||||
data = append(data, '\n')
|
data = append(data, '\n')
|
||||||
}
|
}
|
||||||
|
config.RUnlock()
|
||||||
err := ioutil.WriteFile(filterpath+".tmp", data, 0644)
|
err := ioutil.WriteFile(filterpath+".tmp", data, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Couldn't write filter file: %s", err)
|
log.Printf("Couldn't write filter file: %s", err)
|
||||||
|
|
|
@ -55,6 +55,8 @@ type plug struct {
|
||||||
ParentalBlockHost string
|
ParentalBlockHost string
|
||||||
QueryLogEnabled bool
|
QueryLogEnabled bool
|
||||||
BlockedTTL uint32 // in seconds, default 3600
|
BlockedTTL uint32 // in seconds, default 3600
|
||||||
|
|
||||||
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultPlugin = plug{
|
var defaultPlugin = plug{
|
||||||
|
@ -246,17 +248,21 @@ func (p *plug) parseEtcHosts(text string) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *plug) onShutdown() error {
|
func (p *plug) onShutdown() error {
|
||||||
|
p.Lock()
|
||||||
p.d.Destroy()
|
p.d.Destroy()
|
||||||
p.d = nil
|
p.d = nil
|
||||||
|
p.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *plug) onFinalShutdown() error {
|
func (p *plug) onFinalShutdown() error {
|
||||||
|
logBufferLock.Lock()
|
||||||
err := flushToFile(logBuffer)
|
err := flushToFile(logBuffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to flush to file: %s", err)
|
log.Printf("failed to flush to file: %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
logBufferLock.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -293,9 +299,11 @@ func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *d
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *plug) doStats(ch interface{}, doFunc statsFunc) {
|
func (p *plug) doStats(ch interface{}, doFunc statsFunc) {
|
||||||
|
p.RLock()
|
||||||
stats := p.d.GetStats()
|
stats := p.d.GetStats()
|
||||||
doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing)
|
doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing)
|
||||||
doStatsLookup(ch, doFunc, "parental", &stats.Parental)
|
doStatsLookup(ch, doFunc, "parental", &stats.Parental)
|
||||||
|
p.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Describe is called by prometheus handler to know stat types
|
// Describe is called by prometheus handler to know stat types
|
||||||
|
@ -365,12 +373,12 @@ func (p *plug) genSOA(r *dns.Msg) []dns.RR {
|
||||||
}
|
}
|
||||||
Ns := "fake-for-negative-caching.adguard.com."
|
Ns := "fake-for-negative-caching.adguard.com."
|
||||||
|
|
||||||
soa := defaultSOA
|
soa := *defaultSOA
|
||||||
soa.Hdr = header
|
soa.Hdr = header
|
||||||
soa.Mbox = Mbox
|
soa.Mbox = Mbox
|
||||||
soa.Ns = Ns
|
soa.Ns = Ns
|
||||||
soa.Serial = uint32(time.Now().Unix())
|
soa.Serial = 100500 // faster than uint32(time.Now().Unix())
|
||||||
return []dns.RR{soa}
|
return []dns.RR{&soa}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
@ -397,13 +405,17 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
|
||||||
for _, question := range r.Question {
|
for _, question := range r.Question {
|
||||||
host := strings.ToLower(strings.TrimSuffix(question.Name, "."))
|
host := strings.ToLower(strings.TrimSuffix(question.Name, "."))
|
||||||
// is it a safesearch domain?
|
// is it a safesearch domain?
|
||||||
|
p.RLock()
|
||||||
if val, ok := p.d.SafeSearchDomain(host); ok {
|
if val, ok := p.d.SafeSearchDomain(host); ok {
|
||||||
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
|
rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
p.RUnlock()
|
||||||
return rcode, dnsfilter.Result{}, err
|
return rcode, dnsfilter.Result{}, err
|
||||||
}
|
}
|
||||||
|
p.RUnlock()
|
||||||
return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err
|
return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err
|
||||||
}
|
}
|
||||||
|
p.RUnlock()
|
||||||
|
|
||||||
// is it in hosts?
|
// is it in hosts?
|
||||||
if val, ok := p.hosts[host]; ok {
|
if val, ok := p.hosts[host]; ok {
|
||||||
|
@ -425,11 +437,14 @@ func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dn
|
||||||
}
|
}
|
||||||
|
|
||||||
// needs to be filtered instead
|
// needs to be filtered instead
|
||||||
|
p.RLock()
|
||||||
result, err := p.d.CheckHost(host)
|
result, err := p.d.CheckHost(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("plugin/dnsfilter: %s\n", err)
|
log.Printf("plugin/dnsfilter: %s\n", err)
|
||||||
|
p.RUnlock()
|
||||||
return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err)
|
return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err)
|
||||||
}
|
}
|
||||||
|
p.RUnlock()
|
||||||
|
|
||||||
if result.IsFiltered {
|
if result.IsFiltered {
|
||||||
switch result.Reason {
|
switch result.Reason {
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdguardDNS/dnsfilter"
|
"github.com/AdguardTeam/AdguardDNS/dnsfilter"
|
||||||
|
@ -23,6 +24,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
logBufferLock sync.RWMutex
|
||||||
logBuffer []logEntry
|
logBuffer []logEntry
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -65,11 +67,13 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
|
||||||
}
|
}
|
||||||
var flushBuffer []logEntry
|
var flushBuffer []logEntry
|
||||||
|
|
||||||
|
logBufferLock.Lock()
|
||||||
logBuffer = append(logBuffer, entry)
|
logBuffer = append(logBuffer, entry)
|
||||||
if len(logBuffer) >= logBufferCap {
|
if len(logBuffer) >= logBufferCap {
|
||||||
flushBuffer = logBuffer
|
flushBuffer = logBuffer
|
||||||
logBuffer = nil
|
logBuffer = nil
|
||||||
}
|
}
|
||||||
|
logBufferLock.Unlock()
|
||||||
if len(flushBuffer) > 0 {
|
if len(flushBuffer) > 0 {
|
||||||
// write to file
|
// write to file
|
||||||
// do it in separate goroutine -- we are stalling DNS response this whole time
|
// do it in separate goroutine -- we are stalling DNS response this whole time
|
||||||
|
@ -81,7 +85,9 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela
|
||||||
func handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
func handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
||||||
// TODO: fetch values from disk if len(logBuffer) < queryLogSize
|
// TODO: fetch values from disk if len(logBuffer) < queryLogSize
|
||||||
// TODO: cache output
|
// TODO: cache output
|
||||||
|
logBufferLock.RLock()
|
||||||
values := logBuffer
|
values := logBuffer
|
||||||
|
logBufferLock.RUnlock()
|
||||||
var data = []map[string]interface{}{}
|
var data = []map[string]interface{}{}
|
||||||
for _, entry := range values {
|
for _, entry := range values {
|
||||||
var q *dns.Msg
|
var q *dns.Msg
|
||||||
|
|
10
stats.go
10
stats.go
|
@ -12,6 +12,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -57,6 +58,7 @@ type stats struct {
|
||||||
PerDay periodicStats
|
PerDay periodicStats
|
||||||
|
|
||||||
LastSeen statsEntry
|
LastSeen statsEntry
|
||||||
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
var statistics stats
|
var statistics stats
|
||||||
|
@ -71,10 +73,12 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func purgeStats() {
|
func purgeStats() {
|
||||||
|
statistics.Lock()
|
||||||
initPeriodicStats(&statistics.PerSecond)
|
initPeriodicStats(&statistics.PerSecond)
|
||||||
initPeriodicStats(&statistics.PerMinute)
|
initPeriodicStats(&statistics.PerMinute)
|
||||||
initPeriodicStats(&statistics.PerHour)
|
initPeriodicStats(&statistics.PerHour)
|
||||||
initPeriodicStats(&statistics.PerDay)
|
initPeriodicStats(&statistics.PerDay)
|
||||||
|
statistics.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func runStatsCollectors() {
|
func runStatsCollectors() {
|
||||||
|
@ -121,10 +125,12 @@ func statsRotate(periodic *periodicStats, now time.Time, rotations int64) {
|
||||||
// called every second, accumulates stats for each second, minute, hour and day
|
// called every second, accumulates stats for each second, minute, hour and day
|
||||||
func collectStats() {
|
func collectStats() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
statistics.Lock()
|
||||||
statsRotate(&statistics.PerSecond, now, int64(now.Sub(statistics.PerSecond.LastRotate)/time.Second))
|
statsRotate(&statistics.PerSecond, now, int64(now.Sub(statistics.PerSecond.LastRotate)/time.Second))
|
||||||
statsRotate(&statistics.PerMinute, now, int64(now.Sub(statistics.PerMinute.LastRotate)/time.Minute))
|
statsRotate(&statistics.PerMinute, now, int64(now.Sub(statistics.PerMinute.LastRotate)/time.Minute))
|
||||||
statsRotate(&statistics.PerHour, now, int64(now.Sub(statistics.PerHour.LastRotate)/time.Hour))
|
statsRotate(&statistics.PerHour, now, int64(now.Sub(statistics.PerHour.LastRotate)/time.Hour))
|
||||||
statsRotate(&statistics.PerDay, now, int64(now.Sub(statistics.PerDay.LastRotate)/time.Hour/24))
|
statsRotate(&statistics.PerDay, now, int64(now.Sub(statistics.PerDay.LastRotate)/time.Hour/24))
|
||||||
|
statistics.Unlock()
|
||||||
|
|
||||||
// grab HTTP from prometheus
|
// grab HTTP from prometheus
|
||||||
resp, err := client.Get("http://127.0.0.1:9153/metrics")
|
resp, err := client.Get("http://127.0.0.1:9153/metrics")
|
||||||
|
@ -191,6 +197,7 @@ func collectStats() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculate delta
|
// calculate delta
|
||||||
|
statistics.Lock()
|
||||||
delta := calcDelta(entry, statistics.LastSeen)
|
delta := calcDelta(entry, statistics.LastSeen)
|
||||||
|
|
||||||
// apply delta to second/minute/hour/day
|
// apply delta to second/minute/hour/day
|
||||||
|
@ -201,6 +208,7 @@ func collectStats() {
|
||||||
|
|
||||||
// save last seen
|
// save last seen
|
||||||
statistics.LastSeen = entry
|
statistics.LastSeen = entry
|
||||||
|
statistics.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func calcDelta(current, seen statsEntry) statsEntry {
|
func calcDelta(current, seen statsEntry) statsEntry {
|
||||||
|
@ -245,7 +253,9 @@ func loadStats() error {
|
||||||
func writeStats() error {
|
func writeStats() error {
|
||||||
statsFile := filepath.Join(config.ourBinaryDir, "stats.json")
|
statsFile := filepath.Join(config.ourBinaryDir, "stats.json")
|
||||||
log.Printf("Writing JSON file: %s", statsFile)
|
log.Printf("Writing JSON file: %s", statsFile)
|
||||||
|
statistics.RLock()
|
||||||
json, err := json.MarshalIndent(statistics, "", " ")
|
json, err := json.MarshalIndent(statistics, "", " ")
|
||||||
|
statistics.RUnlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Couldn't generate JSON: %s", err)
|
log.Printf("Couldn't generate JSON: %s", err)
|
||||||
return err
|
return err
|
||||||
|
|
Loading…
Reference in New Issue