Pull request 1812: AG-21286

Merge in DNS/adguard-home from AG-21286 to master

Squashed commit of the following:

commit 587b4a3704fd63aa3da6c1be83f8a49bf4e27b00
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Apr 11 14:15:05 2023 +0300

    all: fix negative pause duration
This commit is contained in:
Ainar Garipov 2023-04-11 15:02:29 +03:00
parent 9e14d5f99f
commit 950ecb1f5e
4 changed files with 41 additions and 34 deletions

View File

@ -648,17 +648,17 @@ func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, er
// UpdatedProtectionStatus updates protection state, if the protection was // UpdatedProtectionStatus updates protection state, if the protection was
// disabled temporarily. Returns the updated state of protection. // disabled temporarily. Returns the updated state of protection.
func (s *Server) UpdatedProtectionStatus() (enabled bool) { func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Time) {
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
disabledUntil := s.conf.ProtectionDisabledUntil disabledUntil = s.conf.ProtectionDisabledUntil
if disabledUntil == nil { if disabledUntil == nil {
return s.conf.ProtectionEnabled return s.conf.ProtectionEnabled, nil
} }
if time.Now().Before(*disabledUntil) { if time.Now().Before(*disabledUntil) {
return false return false, disabledUntil
} }
// Update the values in a separate goroutine, unless an update is already in // Update the values in a separate goroutine, unless an update is already in
@ -671,7 +671,7 @@ func (s *Server) UpdatedProtectionStatus() (enabled bool) {
go s.enableProtectionAfterPause() go s.enableProtectionAfterPause()
} }
return true return true, nil
} }
// enableProtectionAfterPause sets the protection configuration to enabled // enableProtectionAfterPause sets the protection configuration to enabled

View File

@ -206,7 +206,7 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
dctx.clientID = string(s.clientIDCache.Get(key[:])) dctx.clientID = string(s.clientIDCache.Get(key[:]))
// Get the client-specific filtering settings. // Get the client-specific filtering settings.
dctx.protectionEnabled = s.UpdatedProtectionStatus() dctx.protectionEnabled, _ = s.UpdatedProtectionStatus()
dctx.setts = s.getClientRequestFilteringSettings(dctx) dctx.setts = s.getClientRequestFilteringSettings(dctx)
return resultCodeSuccess return resultCodeSuccess
@ -460,7 +460,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
} }
// indexFirstV4Label returns the index at which the reversed IPv4 address // indexFirstV4Label returns the index at which the reversed IPv4 address
// starts, assuiming the domain is pre-validated ARPA domain having in-addr and // starts, assuming the domain is pre-validated ARPA domain having in-addr and
// arpa labels removed. // arpa labels removed.
func indexFirstV4Label(domain string) (idx int) { func indexFirstV4Label(domain string) (idx int) {
idx = len(domain) idx = len(domain)
@ -478,7 +478,7 @@ func indexFirstV4Label(domain string) (idx int) {
} }
// indexFirstV6Label returns the index at which the reversed IPv6 address // indexFirstV6Label returns the index at which the reversed IPv6 address
// starts, assuiming the domain is pre-validated ARPA domain having ip6 and arpa // starts, assuming the domain is pre-validated ARPA domain having ip6 and arpa
// labels removed. // labels removed.
func indexFirstV6Label(domain string) (idx int) { func indexFirstV6Label(domain string) (idx int) {
idx = len(domain) idx = len(domain)

View File

@ -101,7 +101,7 @@ type jsonDNSConfig struct {
} }
func (s *Server) getDNSConfig() (c *jsonDNSConfig) { func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
protectionEnabled := s.UpdatedProtectionStatus() protectionEnabled, protectionDisabledUntil := s.UpdatedProtectionStatus()
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
@ -128,12 +128,6 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
usePrivateRDNS := s.conf.UsePrivateRDNS usePrivateRDNS := s.conf.UsePrivateRDNS
localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers) localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers)
var disabledUntil *time.Time
if s.conf.ProtectionDisabledUntil != nil {
t := *s.conf.ProtectionDisabledUntil
disabledUntil = &t
}
var upstreamMode string var upstreamMode string
if s.conf.FastestAddr { if s.conf.FastestAddr {
upstreamMode = "fastest_addr" upstreamMode = "fastest_addr"
@ -169,7 +163,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
UsePrivateRDNS: &usePrivateRDNS, UsePrivateRDNS: &usePrivateRDNS,
LocalPTRUpstreams: &localPTRUpstreams, LocalPTRUpstreams: &localPTRUpstreams,
DefaultLocalPTRUpstreams: defLocalPTRUps, DefaultLocalPTRUpstreams: defLocalPTRUps,
DisabledUntil: disabledUntil, DisabledUntil: protectionDisabledUntil,
} }
} }

View File

@ -15,6 +15,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/NYTimes/gziphandler" "github.com/NYTimes/gziphandler"
) )
@ -98,14 +99,17 @@ func collectDNSAddresses() (addrs []string, err error) {
// statusResponse is a response for /control/status endpoint. // statusResponse is a response for /control/status endpoint.
type statusResponse struct { type statusResponse struct {
Version string `json:"version"` Version string `json:"version"`
Language string `json:"language"` Language string `json:"language"`
DNSAddrs []string `json:"dns_addresses"` DNSAddrs []string `json:"dns_addresses"`
DNSPort int `json:"dns_port"` DNSPort int `json:"dns_port"`
HTTPPort int `json:"http_port"` HTTPPort int `json:"http_port"`
IsProtectionEnabled bool `json:"protection_enabled"`
// ProtectionDisabledDuration is a pause duration in milliseconds. // ProtectionDisabledDuration is the duration of the protection pause in
// milliseconds.
ProtectionDisabledDuration int64 `json:"protection_disabled_duration"` ProtectionDisabledDuration int64 `json:"protection_disabled_duration"`
ProtectionEnabled bool `json:"protection_enabled"`
// TODO(e.burkov): Inspect if front-end doesn't requires this field as // TODO(e.burkov): Inspect if front-end doesn't requires this field as
// openapi.yaml declares. // openapi.yaml declares.
IsDHCPAvailable bool `json:"dhcp_available"` IsDHCPAvailable bool `json:"dhcp_available"`
@ -122,12 +126,15 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
return return
} }
isProtectionEnabled := false var (
var c *dnsforward.FilteringConfig fltConf *dnsforward.FilteringConfig
protectionDisabledUntil *time.Time
protectionEnabled bool
)
if Context.dnsServer != nil { if Context.dnsServer != nil {
c = &dnsforward.FilteringConfig{} fltConf = &dnsforward.FilteringConfig{}
Context.dnsServer.WriteDiskConfig(c) Context.dnsServer.WriteDiskConfig(fltConf)
isProtectionEnabled = Context.dnsServer.UpdatedProtectionStatus() protectionEnabled, protectionDisabledUntil = Context.dnsServer.UpdatedProtectionStatus()
} }
var resp statusResponse var resp statusResponse
@ -135,20 +142,26 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
config.RLock() config.RLock()
defer config.RUnlock() defer config.RUnlock()
var pauseDuration int64 var protectionDisabledDuration int64
if until := config.DNS.ProtectionDisabledUntil; until != nil { if protectionDisabledUntil != nil {
pauseDuration = time.Until(*until).Milliseconds() // Make sure that we don't send negative numbers to the frontend,
// since enough time might have passed to make the difference less
// than zero.
protectionDisabledDuration = mathutil.Max(
0,
time.Until(*protectionDisabledUntil).Milliseconds(),
)
} }
resp = statusResponse{ resp = statusResponse{
Version: version.Version(), Version: version.Version(),
Language: config.Language,
DNSAddrs: dnsAddrs, DNSAddrs: dnsAddrs,
DNSPort: config.DNS.Port, DNSPort: config.DNS.Port,
HTTPPort: config.BindPort, HTTPPort: config.BindPort,
Language: config.Language, ProtectionDisabledDuration: protectionDisabledDuration,
ProtectionEnabled: protectionEnabled,
IsRunning: isRunning(), IsRunning: isRunning(),
ProtectionDisabledDuration: pauseDuration,
IsProtectionEnabled: isProtectionEnabled,
} }
}() }()