cherry-pick: 4358 fix stats

Merge in DNS/adguard-home from 4358-fix-stats to master

Updates #4358.
Updates #4342.

Squashed commit of the following:

commit 5683cb304688ea639e5ba7f219a7bf12370211a4
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Aug 4 18:20:54 2022 +0300

    stats: rm races test

commit 63dd67650ed64eaf9685b955a4fdf3c0067a7f8c
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Aug 4 17:13:36 2022 +0300

    stats: try to imp test

commit 59a0f249fc00566872db62e362c87bc0c201b333
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Aug 4 16:38:57 2022 +0300

    stats: fix nil ptr deref

commit 7fc3ff18a34a1d0e0fec3ca83a33f499ac752572
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Apr 7 16:02:51 2022 +0300

    stats: fix races finally, imp tests

commit c63f5f4e7929819fe79b3a1e392f6b91cd630846
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Aug 4 00:56:49 2022 +0300

    aghhttp: add register func

commit 61adc7f0e95279c1b7f4a0c0af5ab387ee461411
Merge: edbdb2d4 9b3adac1
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Aug 4 00:36:01 2022 +0300

    Merge branch 'master' into 4358-fix-stats

commit edbdb2d4c6a06dcbf8107a28c4c3a61ba394e907
Merge: a91e4d7a a481ff4c
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Aug 3 21:00:42 2022 +0300

    Merge branch 'master' into 4358-fix-stats

commit a91e4d7af13591eeef45cb7980d1ebc1650a5cb7
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Aug 3 18:46:19 2022 +0300

    stats: imp code, docs

commit c5f3814c5c1a734ca8ff6726cc9ffc1177a055cf
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Aug 3 18:16:13 2022 +0300

    all: log changes

commit 5e6caafc771dddc4c6be07c34658de359106fbe5
Merge: 091ba756 eb8e8166
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Aug 3 18:09:10 2022 +0300

    Merge branch 'master' into 4358-fix-stats

commit 091ba75618d3689b9c04f05431283417c8cc52f9
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Aug 3 18:07:39 2022 +0300

    stats: imp docs, code

commit f2b2de77ce5f0448d6df9232a614a3710f1e2e8a
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Aug 2 17:09:30 2022 +0300

    all: refactor stats & add mutexes

commit b3f11c455ceaa3738ec20eefc46f866ff36ed046
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Apr 27 15:30:09 2022 +0300

    WIP
This commit is contained in:
Eugene Burkov 2022-08-04 19:05:28 +03:00 committed by Ainar Garipov
parent 56dc3eab02
commit 39b404be19
15 changed files with 434 additions and 300 deletions

View File

@ -9,6 +9,12 @@ import (
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
// RegisterFunc is the function that sets the handler to handle the URL for the
// method.
//
// TODO(e.burkov, a.garipov): Get rid of it.
type RegisterFunc func(method, url string, handler http.HandlerFunc)
// OK responds with word OK. // OK responds with word OK.
func OK(w http.ResponseWriter) { func OK(w http.ResponseWriter) {
if _, err := io.WriteString(w, "OK\n"); err != nil { if _, err := io.WriteString(w, "OK\n"); err != nil {

View File

@ -5,11 +5,11 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"net/http"
"path/filepath" "path/filepath"
"runtime" "runtime"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
) )
@ -126,7 +126,7 @@ type ServerConfig struct {
ConfigModified func() `yaml:"-"` ConfigModified func() `yaml:"-"`
// Register an HTTP handler // Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"` HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
InterfaceName string `yaml:"interface_name"` InterfaceName string `yaml:"interface_name"`

View File

@ -5,12 +5,12 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"net" "net"
"net/http"
"os" "os"
"sort" "sort"
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
@ -191,7 +191,7 @@ type ServerConfig struct {
ConfigModified func() ConfigModified func()
// Register an HTTP handler // Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) HTTPRegister aghhttp.RegisterFunc
// ResolveClients signals if the RDNS should resolve clients' addresses. // ResolveClients signals if the RDNS should resolve clients' addresses.
ResolveClients bool ResolveClients bool

View File

@ -61,7 +61,7 @@ type Server struct {
dnsFilter *filtering.DNSFilter // DNS filter instance dnsFilter *filtering.DNSFilter // DNS filter instance
dhcpServer dhcpd.ServerInterface // DHCP server instance (optional) dhcpServer dhcpd.ServerInterface // DHCP server instance (optional)
queryLog querylog.QueryLog // Query log instance queryLog querylog.QueryLog // Query log instance
stats stats.Stats stats stats.Interface
access *accessCtx access *accessCtx
// localDomainSuffix is the suffix used to detect internal hosts. It // localDomainSuffix is the suffix used to detect internal hosts. It
@ -107,7 +107,7 @@ const defaultLocalDomainSuffix = "lan"
// DNSCreateParams are parameters to create a new server. // DNSCreateParams are parameters to create a new server.
type DNSCreateParams struct { type DNSCreateParams struct {
DNSFilter *filtering.DNSFilter DNSFilter *filtering.DNSFilter
Stats stats.Stats Stats stats.Interface
QueryLog querylog.QueryLog QueryLog querylog.QueryLog
DHCPServer dhcpd.ServerInterface DHCPServer dhcpd.ServerInterface
PrivateNets netutil.SubnetSet PrivateNets netutil.SubnetSet

View File

@ -34,7 +34,7 @@ func (l *testQueryLog) Add(p *querylog.AddParams) {
type testStats struct { type testStats struct {
// Stats is embedded here simply to make testStats a stats.Stats without // Stats is embedded here simply to make testStats a stats.Stats without
// actually implementing all methods. // actually implementing all methods.
stats.Stats stats.Interface
lastEntry stats.Entry lastEntry stats.Entry
} }

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"net" "net"
"net/http"
"os" "os"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
@ -14,6 +13,7 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
@ -94,7 +94,7 @@ type Config struct {
ConfigModified func() `yaml:"-"` ConfigModified func() `yaml:"-"`
// Register an HTTP handler // Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"` HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
// CustomResolver is the resolver used by DNSFilter. // CustomResolver is the resolver used by DNSFilter.
CustomResolver Resolver `yaml:"-"` CustomResolver Resolver `yaml:"-"`

View File

@ -2,6 +2,7 @@ package home
import ( import (
"bytes" "bytes"
"encoding"
"fmt" "fmt"
"net" "net"
"sort" "sort"
@ -60,6 +61,33 @@ const (
ClientSourceHostsFile ClientSourceHostsFile
) )
var _ fmt.Stringer = clientSource(0)
// String returns a human-readable name of cs.
func (cs clientSource) String() (s string) {
switch cs {
case ClientSourceWHOIS:
return "WHOIS"
case ClientSourceARP:
return "ARP"
case ClientSourceRDNS:
return "rDNS"
case ClientSourceDHCP:
return "DHCP"
case ClientSourceHostsFile:
return "etc/hosts"
default:
return ""
}
}
var _ encoding.TextMarshaler = clientSource(0)
// MarshalText implements encoding.TextMarshaler for the clientSource.
func (cs clientSource) MarshalText() (text []byte, err error) {
return []byte(cs.String()), nil
}
// clientSourceConf is used to configure where the runtime clients will be // clientSourceConf is used to configure where the runtime clients will be
// obtained from. // obtained from.
type clientSourcesConf struct { type clientSourcesConf struct {
@ -397,6 +425,7 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
c.Tags = stringutil.CloneSlice(c.Tags) c.Tags = stringutil.CloneSlice(c.Tags)
c.BlockedServices = stringutil.CloneSlice(c.BlockedServices) c.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
c.Upstreams = stringutil.CloneSlice(c.Upstreams) c.Upstreams = stringutil.CloneSlice(c.Upstreams)
return c, true return c, true
} }

View File

@ -47,9 +47,9 @@ type clientJSON struct {
type runtimeClientJSON struct { type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"` WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
Name string `json:"name"` Name string `json:"name"`
Source string `json:"source"` Source clientSource `json:"source"`
IP net.IP `json:"ip"` IP net.IP `json:"ip"`
} }
type clientListJSON struct { type clientListJSON struct {
@ -81,20 +81,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
cj := runtimeClientJSON{ cj := runtimeClientJSON{
WHOISInfo: rc.WHOISInfo, WHOISInfo: rc.WHOISInfo,
Name: rc.Host, Name: rc.Host,
IP: ip, Source: rc.Source,
} IP: ip,
cj.Source = "etc/hosts"
switch rc.Source {
case ClientSourceDHCP:
cj.Source = "DHCP"
case ClientSourceRDNS:
cj.Source = "rDNS"
case ClientSourceARP:
cj.Source = "ARP"
case ClientSourceWHOIS:
cj.Source = "WHOIS"
} }
data.RuntimeClients = append(data.RuntimeClients, cj) data.RuntimeClients = append(data.RuntimeClients, cj)
@ -107,13 +96,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data) e := json.NewEncoder(w).Encode(data)
if e != nil { if e != nil {
aghhttp.Error( aghhttp.Error(r, w, http.StatusInternalServerError, "failed to encode to json: %v", e)
r,
w,
http.StatusInternalServerError,
"Failed to encode to json: %v",
e,
)
return return
} }
@ -279,9 +262,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) { func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
rc, ok := clients.FindRuntimeClient(ip) rc, ok := clients.FindRuntimeClient(ip)
if !ok { if !ok {
// It is still possible that the IP used to be in the runtime // It is still possible that the IP used to be in the runtime clients
// clients list, but then the server was reloaded. So, check // list, but then the server was reloaded. So, check the DNS server's
// the DNS server's blocked IP list. // blocked IP list.
// //
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428. // See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)

View File

@ -189,7 +189,7 @@ func registerControlHandlers() {
RegisterAuthHandlers() RegisterAuthHandlers()
} }
func httpRegister(method, url string, handler func(http.ResponseWriter, *http.Request)) { func httpRegister(method, url string, handler http.HandlerFunc) {
if method == "" { if method == "" {
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method // "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
Context.mux.HandleFunc(url, postInstall(handler)) Context.mux.HandleFunc(url, postInstall(handler))

View File

@ -46,7 +46,7 @@ type homeContext struct {
// -- // --
clients clientsContainer // per-client-settings module clients clientsContainer // per-client-settings module
stats stats.Stats // statistics module stats stats.Interface // statistics module
queryLog querylog.QueryLog // query log module queryLog querylog.QueryLog // query log module
dnsServer *dnsforward.Server // DNS module dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module rdns *RDNS // rDNS module

View File

@ -2,10 +2,10 @@ package querylog
import ( import (
"net" "net"
"net/http"
"path/filepath" "path/filepath"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@ -38,7 +38,7 @@ type Config struct {
ConfigModified func() ConfigModified func()
// HTTPRegister registers an HTTP handler. // HTTPRegister registers an HTTP handler.
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) HTTPRegister aghhttp.RegisterFunc
// FindClient returns client information by their IDs. // FindClient returns client information by their IDs.
FindClient func(ids []string) (c *Client, err error) FindClient func(ids []string) (c *Client, err error)

View File

@ -39,34 +39,21 @@ type statsResponse struct {
} }
// handleStats is a handler for getting statistics. // handleStats is a handler for getting statistics.
func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) { func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
start := time.Now() start := time.Now()
var resp statsResponse var resp statsResponse
if s.conf.limit == 0 { var ok bool
resp = statsResponse{ resp, ok = s.getData()
TimeUnits: "days",
TopBlocked: []topAddrs{}, log.Debug("stats: prepared data in %v", time.Since(start))
TopClients: []topAddrs{},
TopQueried: []topAddrs{},
BlockedFiltering: []uint64{}, if !ok {
DNSQueries: []uint64{}, // Don't bring the message to the lower case since it's a part of UI
ReplacedParental: []uint64{}, // text for the moment.
ReplacedSafebrowsing: []uint64{}, aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get statistics data")
}
} else {
var ok bool
resp, ok = s.getData()
log.Debug("stats: prepared data in %v", time.Since(start)) return
if !ok {
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get statistics data")
return
}
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -84,9 +71,9 @@ type config struct {
} }
// Get configuration // Get configuration
func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) { func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
resp := config{} resp := config{}
resp.IntervalDays = s.conf.limit / 24 resp.IntervalDays = s.limitHours / 24
data, err := json.Marshal(resp) data, err := json.Marshal(resp)
if err != nil { if err != nil {
@ -102,7 +89,7 @@ func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
} }
// Set configuration // Set configuration
func (s *statsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) { func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
reqData := config{} reqData := config{}
err := json.NewDecoder(r.Body).Decode(&reqData) err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil { if err != nil {
@ -118,22 +105,22 @@ func (s *statsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
} }
s.setLimit(int(reqData.IntervalDays)) s.setLimit(int(reqData.IntervalDays))
s.conf.ConfigModified() s.configModified()
} }
// Reset data // Reset data
func (s *statsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) { func (s *StatsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) {
s.clear() s.clear()
} }
// Register web handlers // Register web handlers
func (s *statsCtx) initWeb() { func (s *StatsCtx) initWeb() {
if s.conf.HTTPRegister == nil { if s.httpRegister == nil {
return return
} }
s.conf.HTTPRegister(http.MethodGet, "/control/stats", s.handleStats) s.httpRegister(http.MethodGet, "/control/stats", s.handleStats)
s.conf.HTTPRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset) s.httpRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset)
s.conf.HTTPRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig) s.httpRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
s.conf.HTTPRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo) s.httpRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo)
} }

View File

@ -4,75 +4,85 @@ package stats
import ( import (
"net" "net"
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
) )
type unitIDCallback func() uint32 // UnitIDGenFunc is the signature of a function that generates a unique ID for
// the statistics unit.
type UnitIDGenFunc func() (id uint32)
// DiskConfig - configuration settings that are stored on disk // DiskConfig is the configuration structure that is stored in file.
type DiskConfig struct { type DiskConfig struct {
Interval uint32 `yaml:"statistics_interval"` // time interval for statistics (in days) // Interval is the number of days for which the statistics are collected
// before flushing to the database.
Interval uint32 `yaml:"statistics_interval"`
} }
// Config - module configuration // Config is the configuration structure for the statistics collecting.
type Config struct { type Config struct {
Filename string // database file name // UnitID is the function to generate the identifier for current unit. If
LimitDays uint32 // time limit (in days) // nil, the default function is used, see newUnitID.
UnitID unitIDCallback // user function to get the current unit ID. If nil, the current time hour is used. UnitID UnitIDGenFunc
// Called when the configuration is changed by HTTP request // ConfigModified will be called each time the configuration changed via web
// interface.
ConfigModified func() ConfigModified func()
// Register an HTTP handler // HTTPRegister is the function that registers handlers for the stats
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) // endpoints.
HTTPRegister aghhttp.RegisterFunc
limit uint32 // maximum time we need to keep data for (in hours) // Filename is the name of the database file.
Filename string
// LimitDays is the maximum number of days to collect statistics into the
// current unit.
LimitDays uint32
} }
// New - create object // Interface is the statistics interface to be used by other packages.
func New(conf Config) (Stats, error) { type Interface interface {
return createObject(conf) // Start begins the statistics collecting.
}
// Stats - main interface
type Stats interface {
Start() Start()
// Close object. // Close stops the statistics collecting.
// This function is not thread safe
// (can't be called in parallel with any other function of this interface).
Close() Close()
// Update counters // Update collects the incoming statistics data.
Update(e Entry) Update(e Entry)
// Get IP addresses of the clients with the most number of requests // GetTopClientIP returns at most limit IP addresses corresponding to the
// clients with the most number of requests.
GetTopClientsIP(limit uint) []net.IP GetTopClientsIP(limit uint) []net.IP
// WriteDiskConfig - write configuration // WriteDiskConfig puts the Interface's configuration to the dc.
WriteDiskConfig(dc *DiskConfig) WriteDiskConfig(dc *DiskConfig)
} }
// TimeUnit - time unit // TimeUnit is the unit of measuring time while aggregating the statistics.
type TimeUnit int type TimeUnit int
// Supported time units // Supported TimeUnit values.
const ( const (
Hours TimeUnit = iota Hours TimeUnit = iota
Days Days
) )
// Result of DNS request processing // Result is the resulting code of processing the DNS request.
type Result int type Result int
// Supported result values // Supported Result values.
//
// TODO(e.burkov): Think about better naming.
const ( const (
RNotFiltered Result = iota + 1 RNotFiltered Result = iota + 1
RFiltered RFiltered
RSafeBrowsing RSafeBrowsing
RSafeSearch RSafeSearch
RParental RParental
rLast
resultLast = RParental + 1
) )
// Entry is a statistics data entry. // Entry is a statistics data entry.
@ -82,7 +92,12 @@ type Entry struct {
// TODO(a.garipov): Make this a {net.IP, string} enum? // TODO(a.garipov): Make this a {net.IP, string} enum?
Client string Client string
// Domain is the domain name requested.
Domain string Domain string
// Result is the result of processing the request.
Result Result Result Result
Time uint32 // processing time (msec)
// Time is the duration of the request processing in milliseconds.
Time uint32
} }

View File

@ -37,7 +37,7 @@ func TestStats(t *testing.T) {
LimitDays: 1, LimitDays: 1,
} }
s, err := createObject(conf) s, err := New(conf)
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { testutil.CleanupAndRequireSuccess(t, func() (err error) {
s.clear() s.clear()
@ -110,7 +110,7 @@ func TestLargeNumbers(t *testing.T) {
LimitDays: 1, LimitDays: 1,
UnitID: newID, UnitID: newID,
} }
s, err := createObject(conf) s, err := New(conf)
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) { testutil.CleanupAndRequireSuccess(t, func() (err error) {
s.Close() s.Close()

View File

@ -9,11 +9,13 @@ import (
"os" "os"
"sort" "sort"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
bolt "go.etcd.io/bbolt" "go.etcd.io/bbolt"
) )
// TODO(a.garipov): Rewrite all of this. Add proper error handling and // TODO(a.garipov): Rewrite all of this. Add proper error handling and
@ -24,47 +26,130 @@ const (
maxClients = 100 // max number of top clients to store in file or return via Get() maxClients = 100 // max number of top clients to store in file or return via Get()
) )
// statsCtx - global context // StatsCtx collects the statistics and flushes it to the database. Its default
type statsCtx struct { // flushing interval is one hour.
// mu protects unit. //
mu *sync.Mutex // TODO(e.burkov): Use atomic.Pointer for accessing curr and db in go1.19.
// current is the actual statistics collection result. type StatsCtx struct {
current *unit // currMu protects the current unit.
currMu *sync.Mutex
// curr is the actual statistics collection result.
curr *unit
db *bolt.DB // dbMu protects db.
conf *Config dbMu *sync.Mutex
// db is the opened statistics database, if any.
db *bbolt.DB
// unitIDGen is the function that generates an identifier for the current
// unit. It's here for only testing purposes.
unitIDGen UnitIDGenFunc
// httpRegister is used to set HTTP handlers.
httpRegister aghhttp.RegisterFunc
// configModified is called whenever the configuration is modified via web
// interface.
configModified func()
// filename is the name of database file.
filename string
// limitHours is the maximum number of hours to collect statistics into the
// current unit.
limitHours uint32
} }
// data for 1 time unit // unit collects the statistics data for a specific period of time.
type unit struct { type unit struct {
id uint32 // unit ID. Default: absolute hour since Jan 1, 1970 // mu protects all the fields of a unit.
mu *sync.RWMutex
nTotal uint64 // total requests // id is the unique unit's identifier. It's set to an absolute hour number
nResult []uint64 // number of requests per one result // since the beginning of UNIX time by the default ID generating function.
timeSum uint64 // sum of processing time of all requests (usec) id uint32
// top: // nTotal stores the total number of requests.
domains map[string]uint64 // number of requests per domain nTotal uint64
blockedDomains map[string]uint64 // number of blocked requests per domain // nResult stores the number of requests grouped by it's result.
clients map[string]uint64 // number of requests per client nResult []uint64
// timeSum stores the sum of processing time in milliseconds of each request
// written by the unit.
timeSum uint64
// domains stores the number of requests for each domain.
domains map[string]uint64
// blockedDomains stores the number of requests for each domain that has
// been blocked.
blockedDomains map[string]uint64
// clients stores the number of requests from each client.
clients map[string]uint64
} }
// name-count pair // ongoing returns the current unit. It's safe for concurrent use.
//
// Note that the unit itself should be locked before accessing.
func (s *StatsCtx) ongoing() (u *unit) {
s.currMu.Lock()
defer s.currMu.Unlock()
return s.curr
}
// swapCurrent swaps the current unit with another and returns it. It's safe
// for concurrent use.
func (s *StatsCtx) swapCurrent(with *unit) (old *unit) {
s.currMu.Lock()
defer s.currMu.Unlock()
old, s.curr = s.curr, with
return old
}
// database returns the database if it's opened. It's safe for concurrent use.
func (s *StatsCtx) database() (db *bbolt.DB) {
s.dbMu.Lock()
defer s.dbMu.Unlock()
return s.db
}
// swapDatabase swaps the database with another one and returns it. It's safe
// for concurrent use.
func (s *StatsCtx) swapDatabase(with *bbolt.DB) (old *bbolt.DB) {
s.dbMu.Lock()
defer s.dbMu.Unlock()
old, s.db = s.db, with
return old
}
// countPair is a single name-number pair for deserializing statistics data into
// the database.
type countPair struct { type countPair struct {
Name string Name string
Count uint64 Count uint64
} }
// structure for storing data in file // unitDB is the structure for deserializing statistics data into the database.
type unitDB struct { type unitDB struct {
NTotal uint64 // NTotal is the total number of requests.
NTotal uint64
// NResult is the number of requests by the result's kind.
NResult []uint64 NResult []uint64
Domains []countPair // Domains is the number of requests for each domain name.
Domains []countPair
// BlockedDomains is the number of requests blocked for each domain name.
BlockedDomains []countPair BlockedDomains []countPair
Clients []countPair // Clients is the number of requests from each client.
Clients []countPair
TimeAvg uint32 // usec // TimeAvg is the average of processing times in milliseconds of all the
// requests in the unit.
TimeAvg uint32
} }
// withRecovered turns the value recovered from panic if any into an error and // withRecovered turns the value recovered from panic if any into an error and
@ -86,34 +171,40 @@ func withRecovered(orig *error) {
*orig = errors.WithDeferred(*orig, err) *orig = errors.WithDeferred(*orig, err)
} }
// createObject creates s from conf and properly initializes it. // isEnabled is a helper that check if the statistics collecting is enabled.
func createObject(conf Config) (s *statsCtx, err error) { func (s *StatsCtx) isEnabled() (ok bool) {
return atomic.LoadUint32(&s.limitHours) != 0
}
// New creates s from conf and properly initializes it. Don't use s before
// calling it's Start method.
func New(conf Config) (s *StatsCtx, err error) {
defer withRecovered(&err) defer withRecovered(&err)
s = &statsCtx{ s = &StatsCtx{
mu: &sync.Mutex{}, currMu: &sync.Mutex{},
dbMu: &sync.Mutex{},
filename: conf.Filename,
configModified: conf.ConfigModified,
httpRegister: conf.HTTPRegister,
} }
if !checkInterval(conf.LimitDays) { if s.limitHours = conf.LimitDays * 24; !checkInterval(conf.LimitDays) {
conf.LimitDays = 1 s.limitHours = 24
}
if s.unitIDGen = newUnitID; conf.UnitID != nil {
s.unitIDGen = conf.UnitID
} }
s.conf = &Config{} if err = s.dbOpen(); err != nil {
*s.conf = conf return nil, fmt.Errorf("opening database: %w", err)
s.conf.limit = conf.LimitDays * 24
if conf.UnitID == nil {
s.conf.UnitID = newUnitID
} }
if !s.dbOpen() { id := s.unitIDGen()
return nil, fmt.Errorf("open database") tx := beginTxn(s.db, true)
}
id := s.conf.UnitID()
tx := s.beginTxn(true)
var udb *unitDB var udb *unitDB
if tx != nil { if tx != nil {
log.Tracef("Deleting old units...") log.Tracef("Deleting old units...")
firstID := id - s.conf.limit - 1 firstID := id - s.limitHours - 1
unitDel := 0 unitDel := 0
err = tx.ForEach(newBucketWalker(tx, &unitDel, firstID)) err = tx.ForEach(newBucketWalker(tx, &unitDel, firstID))
@ -133,12 +224,11 @@ func createObject(conf Config) (s *statsCtx, err error) {
} }
} }
u := unit{} u := newUnit(id)
s.initUnit(&u, id) // This use of deserialize is safe since the accessed unit has just been
if udb != nil { // created.
deserialize(&u, udb) u.deserialize(udb)
} s.curr = u
s.current = &u
log.Debug("stats: initialized") log.Debug("stats: initialized")
@ -153,11 +243,11 @@ const errStop errors.Error = "stop iteration"
// integer that unitDelPtr points to is incremented for every successful // integer that unitDelPtr points to is incremented for every successful
// deletion. If the bucket isn't deleted, f returns errStop. // deletion. If the bucket isn't deleted, f returns errStop.
func newBucketWalker( func newBucketWalker(
tx *bolt.Tx, tx *bbolt.Tx,
unitDelPtr *int, unitDelPtr *int,
firstID uint32, firstID uint32,
) (f func(name []byte, b *bolt.Bucket) (err error)) { ) (f func(name []byte, b *bbolt.Bucket) (err error)) {
return func(name []byte, _ *bolt.Bucket) (err error) { return func(name []byte, _ *bbolt.Bucket) (err error) {
nameID, ok := unitNameToID(name) nameID, ok := unitNameToID(name)
if !ok || nameID < firstID { if !ok || nameID < firstID {
err = tx.DeleteBucket(name) err = tx.DeleteBucket(name)
@ -178,80 +268,92 @@ func newBucketWalker(
} }
} }
func (s *statsCtx) Start() { // Start makes s process the incoming data.
func (s *StatsCtx) Start() {
s.initWeb() s.initWeb()
go s.periodicFlush() go s.periodicFlush()
} }
func checkInterval(days uint32) bool { // checkInterval returns true if days is valid to be used as statistics
// retention interval. The valid values are 0, 1, 7, 30 and 90.
func checkInterval(days uint32) (ok bool) {
return days == 0 || days == 1 || days == 7 || days == 30 || days == 90 return days == 0 || days == 1 || days == 7 || days == 30 || days == 90
} }
func (s *statsCtx) dbOpen() bool { // dbOpen returns an error if the database can't be opened from the specified
var err error // file. It's safe for concurrent use.
func (s *StatsCtx) dbOpen() (err error) {
log.Tracef("db.Open...") log.Tracef("db.Open...")
s.db, err = bolt.Open(s.conf.Filename, 0o644, nil)
s.dbMu.Lock()
defer s.dbMu.Unlock()
s.db, err = bbolt.Open(s.filename, 0o644, nil)
if err != nil { if err != nil {
log.Error("stats: open DB: %s: %s", s.conf.Filename, err) log.Error("stats: open DB: %s: %s", s.filename, err)
if err.Error() == "invalid argument" { if err.Error() == "invalid argument" {
log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations") log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations")
} }
return false
return err
} }
log.Tracef("db.Open") log.Tracef("db.Open")
return true
return nil
} }
// Atomically swap the currently active unit with a new value // newUnitID is the default UnitIDGenFunc that generates the unique id hourly.
// Return old value func newUnitID() (id uint32) {
func (s *statsCtx) swapUnit(new *unit) (u *unit) { const secsInHour = int64(time.Hour / time.Second)
s.mu.Lock()
defer s.mu.Unlock()
u = s.current return uint32(time.Now().Unix() / secsInHour)
s.current = new
return u
} }
// Get unit ID for the current hour // newUnit allocates the new *unit.
func newUnitID() uint32 { func newUnit(id uint32) (u *unit) {
return uint32(time.Now().Unix() / (60 * 60)) return &unit{
mu: &sync.RWMutex{},
id: id,
nResult: make([]uint64, resultLast),
domains: make(map[string]uint64),
blockedDomains: make(map[string]uint64),
clients: make(map[string]uint64),
}
} }
// Initialize a unit // beginTxn opens a new database transaction. If writable is true, the
func (s *statsCtx) initUnit(u *unit, id uint32) { // transaction will be opened for writing, and for reading otherwise. It
u.id = id // returns nil if the transaction can't be created.
u.nResult = make([]uint64, rLast) func beginTxn(db *bbolt.DB, writable bool) (tx *bbolt.Tx) {
u.domains = make(map[string]uint64)
u.blockedDomains = make(map[string]uint64)
u.clients = make(map[string]uint64)
}
// Open a DB transaction
func (s *statsCtx) beginTxn(wr bool) *bolt.Tx {
db := s.db
if db == nil { if db == nil {
return nil return nil
} }
log.Tracef("db.Begin...") log.Tracef("opening a database transaction")
tx, err := db.Begin(wr)
tx, err := db.Begin(writable)
if err != nil { if err != nil {
log.Error("db.Begin: %s", err) log.Error("stats: opening a transaction: %s", err)
return nil return nil
} }
log.Tracef("db.Begin")
log.Tracef("transaction has been opened")
return tx return tx
} }
func (s *statsCtx) commitTxn(tx *bolt.Tx) { // commitTxn applies the changes made in tx to the database.
func (s *StatsCtx) commitTxn(tx *bbolt.Tx) {
err := tx.Commit() err := tx.Commit()
if err != nil { if err != nil {
log.Debug("tx.Commit: %s", err) log.Error("stats: committing a transaction: %s", err)
return return
} }
log.Tracef("tx.Commit")
log.Tracef("transaction has been committed")
} }
// bucketNameLen is the length of a bucket, a 64-bit unsigned integer. // bucketNameLen is the length of a bucket, a 64-bit unsigned integer.
@ -262,10 +364,10 @@ const bucketNameLen = 8
// idToUnitName converts a numerical ID into a database unit name. // idToUnitName converts a numerical ID into a database unit name.
func idToUnitName(id uint32) (name []byte) { func idToUnitName(id uint32) (name []byte) {
name = make([]byte, bucketNameLen) n := [bucketNameLen]byte{}
binary.BigEndian.PutUint64(name, uint64(id)) binary.BigEndian.PutUint64(n[:], uint64(id))
return name return n[:]
} }
// unitNameToID converts a database unit name into a numerical ID. ok is false // unitNameToID converts a database unit name into a numerical ID. ok is false
@ -278,13 +380,6 @@ func unitNameToID(name []byte) (id uint32, ok bool) {
return uint32(binary.BigEndian.Uint64(name)), true return uint32(binary.BigEndian.Uint64(name)), true
} }
func (s *statsCtx) ongoing() (u *unit) {
s.mu.Lock()
defer s.mu.Unlock()
return s.current
}
// Flush the current unit to DB and delete an old unit when a new hour is started // Flush the current unit to DB and delete an old unit when a new hour is started
// If a unit must be flushed: // If a unit must be flushed:
// . lock DB // . lock DB
@ -293,34 +388,29 @@ func (s *statsCtx) ongoing() (u *unit) {
// . write the unit to DB // . write the unit to DB
// . remove the stale unit from DB // . remove the stale unit from DB
// . unlock DB // . unlock DB
func (s *statsCtx) periodicFlush() { func (s *StatsCtx) periodicFlush() {
for { for ptr := s.ongoing(); ptr != nil; ptr = s.ongoing() {
ptr := s.ongoing() id := s.unitIDGen()
if ptr == nil { // Access the unit's ID with atomic to avoid locking the whole unit.
break if !s.isEnabled() || atomic.LoadUint32(&ptr.id) == id {
}
id := s.conf.UnitID()
if ptr.id == id || s.conf.limit == 0 {
time.Sleep(time.Second) time.Sleep(time.Second)
continue continue
} }
tx := s.beginTxn(true) tx := beginTxn(s.database(), true)
nu := unit{} nu := newUnit(id)
s.initUnit(&nu, id) u := s.swapCurrent(nu)
u := s.swapUnit(&nu) udb := u.serialize()
udb := serialize(u)
if tx == nil { if tx == nil {
continue continue
} }
ok1 := s.flushUnitToDB(tx, u.id, udb) flushOK := flushUnitToDB(tx, u.id, udb)
ok2 := s.deleteUnit(tx, id-s.conf.limit) delOK := s.deleteUnit(tx, id-atomic.LoadUint32(&s.limitHours))
if ok1 || ok2 { if flushOK || delOK {
s.commitTxn(tx) s.commitTxn(tx)
} else { } else {
_ = tx.Rollback() _ = tx.Rollback()
@ -330,8 +420,8 @@ func (s *statsCtx) periodicFlush() {
log.Tracef("periodicFlush() exited") log.Tracef("periodicFlush() exited")
} }
// Delete unit's data from file // deleteUnit removes the unit by it's id from the database the tx belongs to.
func (s *statsCtx) deleteUnit(tx *bolt.Tx, id uint32) bool { func (s *StatsCtx) deleteUnit(tx *bbolt.Tx, id uint32) bool {
err := tx.DeleteBucket(idToUnitName(id)) err := tx.DeleteBucket(idToUnitName(id))
if err != nil { if err != nil {
log.Tracef("stats: bolt DeleteBucket: %s", err) log.Tracef("stats: bolt DeleteBucket: %s", err)
@ -347,10 +437,7 @@ func (s *statsCtx) deleteUnit(tx *bolt.Tx, id uint32) bool {
func convertMapToSlice(m map[string]uint64, max int) []countPair { func convertMapToSlice(m map[string]uint64, max int) []countPair {
a := []countPair{} a := []countPair{}
for k, v := range m { for k, v := range m {
pair := countPair{} a = append(a, countPair{Name: k, Count: v})
pair.Name = k
pair.Count = v
a = append(a, pair)
} }
less := func(i, j int) bool { less := func(i, j int) bool {
return a[j].Count < a[i].Count return a[j].Count < a[i].Count
@ -370,41 +457,46 @@ func convertSliceToMap(a []countPair) map[string]uint64 {
return m return m
} }
func serialize(u *unit) *unitDB { // serialize converts u to the *unitDB. It's safe for concurrent use.
udb := unitDB{} func (u *unit) serialize() (udb *unitDB) {
udb.NTotal = u.nTotal u.mu.RLock()
defer u.mu.RUnlock()
udb.NResult = append(udb.NResult, u.nResult...)
var timeAvg uint32 = 0
if u.nTotal != 0 { if u.nTotal != 0 {
udb.TimeAvg = uint32(u.timeSum / u.nTotal) timeAvg = uint32(u.timeSum / u.nTotal)
} }
udb.Domains = convertMapToSlice(u.domains, maxDomains) return &unitDB{
udb.BlockedDomains = convertMapToSlice(u.blockedDomains, maxDomains) NTotal: u.nTotal,
udb.Clients = convertMapToSlice(u.clients, maxClients) NResult: append([]uint64{}, u.nResult...),
Domains: convertMapToSlice(u.domains, maxDomains),
return &udb BlockedDomains: convertMapToSlice(u.blockedDomains, maxDomains),
Clients: convertMapToSlice(u.clients, maxClients),
TimeAvg: timeAvg,
}
} }
func deserialize(u *unit, udb *unitDB) { // deserealize assigns the appropriate values from udb to u. u must not be nil.
// It's safe for concurrent use.
func (u *unit) deserialize(udb *unitDB) {
if udb == nil {
return
}
u.mu.Lock()
defer u.mu.Unlock()
u.nTotal = udb.NTotal u.nTotal = udb.NTotal
u.nResult = make([]uint64, resultLast)
n := len(udb.NResult) copy(u.nResult, udb.NResult)
if n < len(u.nResult) {
n = len(u.nResult) // n = min(len(udb.NResult), len(u.nResult))
}
for i := 1; i < n; i++ {
u.nResult[i] = udb.NResult[i]
}
u.domains = convertSliceToMap(udb.Domains) u.domains = convertSliceToMap(udb.Domains)
u.blockedDomains = convertSliceToMap(udb.BlockedDomains) u.blockedDomains = convertSliceToMap(udb.BlockedDomains)
u.clients = convertSliceToMap(udb.Clients) u.clients = convertSliceToMap(udb.Clients)
u.timeSum = uint64(udb.TimeAvg) * u.nTotal u.timeSum = uint64(udb.TimeAvg) * udb.NTotal
} }
func (s *statsCtx) flushUnitToDB(tx *bolt.Tx, id uint32, udb *unitDB) bool { func flushUnitToDB(tx *bbolt.Tx, id uint32, udb *unitDB) bool {
log.Tracef("Flushing unit %d", id) log.Tracef("Flushing unit %d", id)
bkt, err := tx.CreateBucketIfNotExists(idToUnitName(id)) bkt, err := tx.CreateBucketIfNotExists(idToUnitName(id))
@ -430,7 +522,7 @@ func (s *statsCtx) flushUnitToDB(tx *bolt.Tx, id uint32, udb *unitDB) bool {
return true return true
} }
func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB { func (s *StatsCtx) loadUnitFromDB(tx *bbolt.Tx, id uint32) *unitDB {
bkt := tx.Bucket(idToUnitName(id)) bkt := tx.Bucket(idToUnitName(id))
if bkt == nil { if bkt == nil {
return nil return nil
@ -451,44 +543,44 @@ func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB {
return &udb return &udb
} }
func convertTopSlice(a []countPair) []map[string]uint64 { func convertTopSlice(a []countPair) (m []map[string]uint64) {
m := []map[string]uint64{} m = make([]map[string]uint64, 0, len(a))
for _, it := range a { for _, it := range a {
ent := map[string]uint64{} m = append(m, map[string]uint64{it.Name: it.Count})
ent[it.Name] = it.Count
m = append(m, ent)
} }
return m return m
} }
func (s *statsCtx) setLimit(limitDays int) { func (s *StatsCtx) setLimit(limitDays int) {
s.conf.limit = uint32(limitDays) * 24 atomic.StoreUint32(&s.limitHours, uint32(24*limitDays))
if limitDays == 0 { if limitDays == 0 {
s.clear() s.clear()
} }
log.Debug("stats: set limit: %d", limitDays) log.Debug("stats: set limit: %d days", limitDays)
} }
func (s *statsCtx) WriteDiskConfig(dc *DiskConfig) { func (s *StatsCtx) WriteDiskConfig(dc *DiskConfig) {
dc.Interval = s.conf.limit / 24 dc.Interval = atomic.LoadUint32(&s.limitHours) / 24
} }
func (s *statsCtx) Close() { func (s *StatsCtx) Close() {
u := s.swapUnit(nil) u := s.swapCurrent(nil)
udb := serialize(u)
tx := s.beginTxn(true) db := s.database()
if tx != nil { if tx := beginTxn(db, true); tx != nil {
if s.flushUnitToDB(tx, u.id, udb) { udb := u.serialize()
if flushUnitToDB(tx, u.id, udb) {
s.commitTxn(tx) s.commitTxn(tx)
} else { } else {
_ = tx.Rollback() _ = tx.Rollback()
} }
} }
if s.db != nil { if db != nil {
log.Tracef("db.Close...") log.Tracef("db.Close...")
_ = s.db.Close() _ = db.Close()
log.Tracef("db.Close") log.Tracef("db.Close")
} }
@ -496,11 +588,11 @@ func (s *statsCtx) Close() {
} }
// Reset counters and clear database // Reset counters and clear database
func (s *statsCtx) clear() { func (s *StatsCtx) clear() {
tx := s.beginTxn(true) db := s.database()
tx := beginTxn(db, true)
if tx != nil { if tx != nil {
db := s.db _ = s.swapDatabase(nil)
s.db = nil
_ = tx.Rollback() _ = tx.Rollback()
// the active transactions can continue using database, // the active transactions can continue using database,
// but no new transactions will be opened // but no new transactions will be opened
@ -509,11 +601,10 @@ func (s *statsCtx) clear() {
// all active transactions are now closed // all active transactions are now closed
} }
u := unit{} u := newUnit(s.unitIDGen())
s.initUnit(&u, s.conf.UnitID()) _ = s.swapCurrent(u)
_ = s.swapUnit(&u)
err := os.Remove(s.conf.Filename) err := os.Remove(s.filename)
if err != nil { if err != nil {
log.Error("os.Remove: %s", err) log.Error("os.Remove: %s", err)
} }
@ -523,13 +614,13 @@ func (s *statsCtx) clear() {
log.Debug("stats: cleared") log.Debug("stats: cleared")
} }
func (s *statsCtx) Update(e Entry) { func (s *StatsCtx) Update(e Entry) {
if s.conf.limit == 0 { if !s.isEnabled() {
return return
} }
if e.Result == 0 || if e.Result == 0 ||
e.Result >= rLast || e.Result >= resultLast ||
e.Domain == "" || e.Domain == "" ||
e.Client == "" { e.Client == "" {
return return
@ -540,13 +631,15 @@ func (s *statsCtx) Update(e Entry) {
clientID = ip.String() clientID = ip.String()
} }
s.mu.Lock() u := s.ongoing()
defer s.mu.Unlock() if u == nil {
return
}
u := s.current u.mu.Lock()
defer u.mu.Unlock()
u.nResult[e.Result]++ u.nResult[e.Result]++
if e.Result == RNotFiltered { if e.Result == RNotFiltered {
u.domains[e.Domain]++ u.domains[e.Domain]++
} else { } else {
@ -558,14 +651,19 @@ func (s *statsCtx) Update(e Entry) {
u.nTotal++ u.nTotal++
} }
func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) { func (s *StatsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) {
tx := s.beginTxn(false) tx := beginTxn(s.database(), false)
if tx == nil { if tx == nil {
return nil, 0 return nil, 0
} }
cur := s.ongoing() cur := s.ongoing()
curID := cur.id var curID uint32
if cur != nil {
curID = atomic.LoadUint32(&cur.id)
} else {
curID = s.unitIDGen()
}
// Per-hour units. // Per-hour units.
units := []*unitDB{} units := []*unitDB{}
@ -574,14 +672,16 @@ func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) {
u := s.loadUnitFromDB(tx, i) u := s.loadUnitFromDB(tx, i)
if u == nil { if u == nil {
u = &unitDB{} u = &unitDB{}
u.NResult = make([]uint64, rLast) u.NResult = make([]uint64, resultLast)
} }
units = append(units, u) units = append(units, u)
} }
_ = tx.Rollback() _ = tx.Rollback()
units = append(units, serialize(cur)) if cur != nil {
units = append(units, cur.serialize())
}
if len(units) != int(limit) { if len(units) != int(limit) {
log.Fatalf("len(units) != limit: %d %d", len(units), limit) log.Fatalf("len(units) != limit: %d %d", len(units), limit)
@ -628,13 +728,13 @@ func statsCollector(units []*unitDB, firstID uint32, timeUnit TimeUnit, ng numsG
// pairsGetter is a signature for topsCollector argument. // pairsGetter is a signature for topsCollector argument.
type pairsGetter func(u *unitDB) (pairs []countPair) type pairsGetter func(u *unitDB) (pairs []countPair)
// topsCollector collects statistics about highest values fro the given *unitDB // topsCollector collects statistics about highest values from the given *unitDB
// slice using pg to retrieve data. // slice using pg to retrieve data.
func topsCollector(units []*unitDB, max int, pg pairsGetter) []map[string]uint64 { func topsCollector(units []*unitDB, max int, pg pairsGetter) []map[string]uint64 {
m := map[string]uint64{} m := map[string]uint64{}
for _, u := range units { for _, u := range units {
for _, it := range pg(u) { for _, cp := range pg(u) {
m[it.Name] += it.Count m[cp.Name] += cp.Count
} }
} }
a2 := convertMapToSlice(m, max) a2 := convertMapToSlice(m, max)
@ -668,8 +768,22 @@ func topsCollector(units []*unitDB, max int, pg pairsGetter) []map[string]uint64
* parental-blocked * parental-blocked
These values are just the sum of data for all units. These values are just the sum of data for all units.
*/ */
func (s *statsCtx) getData() (statsResponse, bool) { func (s *StatsCtx) getData() (statsResponse, bool) {
limit := s.conf.limit limit := atomic.LoadUint32(&s.limitHours)
if limit == 0 {
return statsResponse{
TimeUnits: "days",
TopBlocked: []topAddrs{},
TopClients: []topAddrs{},
TopQueried: []topAddrs{},
BlockedFiltering: []uint64{},
DNSQueries: []uint64{},
ReplacedParental: []uint64{},
ReplacedSafebrowsing: []uint64{},
}, true
}
timeUnit := Hours timeUnit := Hours
if limit/24 > 7 { if limit/24 > 7 {
@ -698,7 +812,7 @@ func (s *statsCtx) getData() (statsResponse, bool) {
// Total counters: // Total counters:
sum := unitDB{ sum := unitDB{
NResult: make([]uint64, rLast), NResult: make([]uint64, resultLast),
} }
timeN := 0 timeN := 0
for _, u := range units { for _, u := range units {
@ -731,12 +845,12 @@ func (s *statsCtx) getData() (statsResponse, bool) {
return data, true return data, true
} }
func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP { func (s *StatsCtx) GetTopClientsIP(maxCount uint) []net.IP {
if s.conf.limit == 0 { if !s.isEnabled() {
return nil return nil
} }
units, _ := s.loadUnits(s.conf.limit) units, _ := s.loadUnits(atomic.LoadUint32(&s.limitHours))
if units == nil { if units == nil {
return nil return nil
} }