diff --git a/CHANGELOG.md b/CHANGELOG.md index ab942ad5..f1f1904d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to ### Added +- The ability to set the timeout for querying the upstream servers ([#2280]). - The ability to change group and user ID on startup on Unix ([#2763]). - Experimental OpenBSD support for AMD64 and 64-bit ARM CPUs ([#2439]). - Support for custom port in DNS-over-HTTPS profiles for Apple's devices @@ -60,6 +61,7 @@ released by then. - Go 1.15 support. +[#2280]: https://github.com/AdguardTeam/AdGuardHome/issues/2280 [#2439]: https://github.com/AdguardTeam/AdGuardHome/issues/2439 [#2441]: https://github.com/AdguardTeam/AdGuardHome/issues/2441 [#2443]: https://github.com/AdguardTeam/AdGuardHome/issues/2443 diff --git a/HACKING.md b/HACKING.md index 151375ff..4c2ee095 100644 --- a/HACKING.md +++ b/HACKING.md @@ -197,13 +197,24 @@ attributes to make it work in Markdown renderers that strip "id". --> ### Formatting - * Decorate `break`, `continue`, `fallthrough`, `return`, and other function - exit points with empty lines unless it's the only statement in that block. + * Decorate `break`, `continue`, `fallthrough`, `return`, and other terminating + statements with empty lines unless it's the only statement in that block. * Don't group type declarations together. Unlike with blocks of `const`s, where a `iota` may be used or where all constants belong to a certain type, there is no reason to group `type`s. + * Group `require.*` blocks together with the presceding related statements, + but separate from the following `assert.*` and unrelated requirements. + + ```go + val, ok := valMap[key] + require.True(t, ok) + require.NotNil(t, val) + + assert.Equal(t, expected, val) + ``` + * Use `gofumpt --extra -s`. * Write slices of struct like this: diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 47339f7b..905e5cec 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -9,6 +9,7 @@ import ( "os" "sort" "strings" + "time" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" @@ -142,6 +143,9 @@ type ServerConfig struct { DNSCryptConfig TLSAllowUnencryptedDOH bool + // UpstreamTimeout is the timeout for querying upstream servers. + UpstreamTimeout time.Duration + TLSv12Roots *x509.CertPool // list of root CAs for TLSv1.2 TLSCiphers []uint16 // list of TLS ciphers to use @@ -261,6 +265,10 @@ func (s *Server) initDefaultSettings() { if len(s.conf.BlockedHosts) == 0 { s.conf.BlockedHosts = defaultBlockedHosts } + + if s.conf.UpstreamTimeout == 0 { + s.conf.UpstreamTimeout = DefaultTimeout + } } // prepareUpstreamSettings - prepares upstream DNS server settings @@ -299,7 +307,7 @@ func (s *Server) prepareUpstreamSettings() error { upstreams, upstream.Options{ Bootstrap: s.conf.BootstrapDNS, - Timeout: DefaultTimeout, + Timeout: s.conf.UpstreamTimeout, }, ) if err != nil { @@ -313,7 +321,7 @@ func (s *Server) prepareUpstreamSettings() error { defaultDNS, upstream.Options{ Bootstrap: s.conf.BootstrapDNS, - Timeout: DefaultTimeout, + Timeout: s.conf.UpstreamTimeout, }, ) if err != nil { diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 55b9904f..a03cc87f 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -279,6 +279,34 @@ func TestServer(t *testing.T) { } } +func TestServer_timeout(t *testing.T) { + const timeout time.Duration = time.Second + + t.Run("custom", func(t *testing.T) { + srvConf := &ServerConfig{ + UpstreamTimeout: timeout, + } + + s, err := NewServer(DNSCreateParams{}) + require.NoError(t, err) + + err = s.Prepare(srvConf) + require.NoError(t, err) + + assert.Equal(t, timeout, s.conf.UpstreamTimeout) + }) + + t.Run("default", func(t *testing.T) { + s, err := NewServer(DNSCreateParams{}) + require.NoError(t, err) + + err = s.Prepare(nil) + require.NoError(t, err) + + assert.Equal(t, DefaultTimeout, s.conf.UpstreamTimeout) + }) +} + func TestServerWithProtectionDisabled(t *testing.T) { s := createTestServer(t, &filtering.Config{}, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 28c0dd47..06baa9a9 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -7,6 +7,7 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" @@ -529,7 +530,7 @@ func checkPrivateUpstreamExc(u upstream.Upstream) (err error) { return nil } -func checkDNS(input string, bootstrap []string, ef excFunc) (err error) { +func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFunc) (err error) { if aghstrings.IsCommentOrEmpty(input) { return nil } @@ -557,7 +558,7 @@ func checkDNS(input string, bootstrap []string, ef excFunc) (err error) { var u upstream.Upstream u, err = upstream.AddressToUpstream(input, upstream.Options{ Bootstrap: bootstrap, - Timeout: DefaultTimeout, + Timeout: timeout, }) if err != nil { return fmt.Errorf("failed to choose upstream for %q: %w", input, err) @@ -584,8 +585,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { result := map[string]string{} bootstraps := req.BootstrapDNS + timeout := s.conf.UpstreamTimeout for _, host := range req.Upstreams { - err = checkDNS(host, bootstraps, checkDNSUpstreamExc) + err = checkDNS(host, bootstraps, timeout, checkDNSUpstreamExc) if err != nil { log.Info("%v", err) result[host] = err.Error() @@ -597,7 +599,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { } for _, host := range req.PrivateUpstreams { - err = checkDNS(host, bootstraps, checkPrivateUpstreamExc) + err = checkDNS(host, bootstraps, timeout, checkPrivateUpstreamExc) if err != nil { log.Info("%v", err) // TODO(e.burkov): If passed upstream have already diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index fc34a3b9..99fece08 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -18,7 +18,8 @@ import ( "github.com/stretchr/testify/require" ) -// fakeSystemResolvers is a mock aghnet.SystemResolvers implementation for tests. +// fakeSystemResolvers is a mock aghnet.SystemResolvers implementation for +// tests. type fakeSystemResolvers struct { // SystemResolvers is embedded here simply to make *fakeSystemResolvers // an aghnet.SystemResolvers without actually implementing all methods. diff --git a/internal/home/authratelimiter.go b/internal/home/authratelimiter.go index c0b3da40..acdee35c 100644 --- a/internal/home/authratelimiter.go +++ b/internal/home/authratelimiter.go @@ -72,7 +72,7 @@ func (ab *authRateLimiter) check(usrID string) (left time.Duration) { // incLocked increments the number of unsuccessful attempts for attempter with // ip and updates it's blocking moment if needed. For internal use only. func (ab *authRateLimiter) incLocked(usrID string, now time.Time) { - var until time.Time = now.Add(failedAuthTTL) + until := now.Add(failedAuthTTL) var attNum uint = 1 a, ok := ab.failedAuths[usrID] diff --git a/internal/home/clients.go b/internal/home/clients.go index c1a7d499..9ef4a9ee 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -361,7 +361,7 @@ func (clients *clientsContainer) findUpstreams( upstreams, upstream.Options{ Bootstrap: config.DNS.BootstrapDNS, - Timeout: dnsforward.DefaultTimeout, + Timeout: config.DNS.UpstreamTimeout.Duration, }, ) if err != nil { diff --git a/internal/home/config.go b/internal/home/config.go index 82f89178..d660ddb8 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -114,6 +114,9 @@ type dnsConfig struct { FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours) DnsfilterConf filtering.Config `yaml:",inline"` + // UpstreamTimeout is the timeout for querying upstream servers. + UpstreamTimeout Duration `yaml:"upstream_timeout"` + // LocalDomainName is the domain name used for known internal hosts. // For example, a machine called "myhost" can be addressed as // "myhost.lan" when LocalDomainName is "lan". @@ -182,6 +185,7 @@ var config = configuration{ }, FilteringEnabled: true, // whether or not use filter lists FiltersUpdateIntervalHours: 24, + UpstreamTimeout: Duration{dnsforward.DefaultTimeout}, LocalDomainName: "lan", ResolveClients: true, UsePrivateRDNS: true, @@ -276,6 +280,10 @@ func parseConfig() error { config.DNS.FiltersUpdateIntervalHours = 24 } + if config.DNS.UpstreamTimeout.Duration == 0 { + config.DNS.UpstreamTimeout = Duration{dnsforward.DefaultTimeout} + } + return nil } diff --git a/internal/home/dns.go b/internal/home/dns.go index f4092a95..0c4236dd 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -202,6 +202,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) { newConf.ResolveClients = dnsConf.ResolveClients newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers + newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration return newConf, nil } diff --git a/internal/home/duration.go b/internal/home/duration.go new file mode 100644 index 00000000..c5a2a751 --- /dev/null +++ b/internal/home/duration.go @@ -0,0 +1,28 @@ +package home + +import ( + "time" + + "github.com/AdguardTeam/golibs/errors" +) + +// Duration is a wrapper for time.Duration providing functionality for encoding. +type Duration struct { + // time.Duration is embedded here to avoid implementing all the methods. + time.Duration +} + +// MarshalText implements the encoding.TextMarshaler interface for Duration. +func (d Duration) MarshalText() (text []byte, err error) { + return []byte(d.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface for +// *Duration. +func (d *Duration) UnmarshalText(b []byte) (err error) { + defer func() { err = errors.Annotate(err, "unmarshalling duration: %w") }() + + d.Duration, err = time.ParseDuration(string(b)) + + return err +} diff --git a/internal/home/duration_test.go b/internal/home/duration_test.go new file mode 100644 index 00000000..8a9ad215 --- /dev/null +++ b/internal/home/duration_test.go @@ -0,0 +1,193 @@ +package home + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + yaml "gopkg.in/yaml.v2" +) + +// durationEncodingTester is a helper struct to simplify testing different +// Duration marshalling and unmarshalling cases. +type durationEncodingTester struct { + PtrMap map[string]*Duration `json:"ptr_map" yaml:"ptr_map"` + PtrSlice []*Duration `json:"ptr_slice" yaml:"ptr_slice"` + PtrValue *Duration `json:"ptr_value" yaml:"ptr_value"` + PtrArray [1]*Duration `json:"ptr_array" yaml:"ptr_array"` + Map map[string]Duration `json:"map" yaml:"map"` + Slice []Duration `json:"slice" yaml:"slice"` + Value Duration `json:"value" yaml:"value"` + Array [1]Duration `json:"array" yaml:"array"` +} + +const nl = "\n" +const ( + jsonStr = `{` + + `"ptr_map":{"dur":"1ms"},` + + `"ptr_slice":["1ms"],` + + `"ptr_value":"1ms",` + + `"ptr_array":["1ms"],` + + `"map":{"dur":"1ms"},` + + `"slice":["1ms"],` + + `"value":"1ms",` + + `"array":["1ms"]` + + `}` + yamlStr = `ptr_map:` + nl + + ` dur: 1ms` + nl + + `ptr_slice:` + nl + + `- 1ms` + nl + + `ptr_value: 1ms` + nl + + `ptr_array:` + nl + + `- 1ms` + nl + + `map:` + nl + + ` dur: 1ms` + nl + + `slice:` + nl + + `- 1ms` + nl + + `value: 1ms` + nl + + `array:` + nl + + `- 1ms` +) + +// defaultTestDur is the default time.Duration value to be used throughout the tests of +// Duration. +const defaultTestDur = time.Millisecond + +// checkFields verifies m's fields. It expects the m to be unmarshalled from +// one of the constant strings above. +func (m *durationEncodingTester) checkFields(t *testing.T, d Duration) { + t.Run("pointers_map", func(t *testing.T) { + require.NotNil(t, m.PtrMap) + + fromPtrMap, ok := m.PtrMap["dur"] + require.True(t, ok) + require.NotNil(t, fromPtrMap) + + assert.Equal(t, d, *fromPtrMap) + }) + + t.Run("pointers_slice", func(t *testing.T) { + require.Len(t, m.PtrSlice, 1) + + fromPtrSlice := m.PtrSlice[0] + require.NotNil(t, fromPtrSlice) + + assert.Equal(t, d, *fromPtrSlice) + }) + + t.Run("pointers_array", func(t *testing.T) { + fromPtrArray := m.PtrArray[0] + require.NotNil(t, fromPtrArray) + + assert.Equal(t, d, *fromPtrArray) + }) + + t.Run("pointer_value", func(t *testing.T) { + require.NotNil(t, m.PtrValue) + + assert.Equal(t, d, *m.PtrValue) + }) + + t.Run("map", func(t *testing.T) { + fromMap, ok := m.Map["dur"] + require.True(t, ok) + + assert.Equal(t, d, fromMap) + }) + + t.Run("slice", func(t *testing.T) { + require.Len(t, m.Slice, 1) + + assert.Equal(t, d, m.Slice[0]) + }) + + t.Run("array", func(t *testing.T) { + assert.Equal(t, d, m.Array[0]) + }) + + t.Run("value", func(t *testing.T) { + assert.Equal(t, d, m.Value) + }) +} + +func TestDuration_MarshalText(t *testing.T) { + d := Duration{defaultTestDur} + dPtr := &d + + v := durationEncodingTester{ + PtrMap: map[string]*Duration{"dur": dPtr}, + PtrSlice: []*Duration{dPtr}, + PtrValue: dPtr, + PtrArray: [1]*Duration{dPtr}, + Map: map[string]Duration{"dur": d}, + Slice: []Duration{d}, + Value: d, + Array: [1]Duration{d}, + } + + b := &bytes.Buffer{} + t.Run("json", func(t *testing.T) { + t.Cleanup(b.Reset) + err := json.NewEncoder(b).Encode(v) + require.NoError(t, err) + + assert.JSONEq(t, jsonStr, b.String()) + }) + + t.Run("yaml", func(t *testing.T) { + t.Cleanup(b.Reset) + err := yaml.NewEncoder(b).Encode(v) + require.NoError(t, err) + + assert.YAMLEq(t, yamlStr, b.String(), b.String()) + }) + + t.Run("direct", func(t *testing.T) { + data, err := d.MarshalText() + require.NoError(t, err) + + assert.EqualValues(t, []byte(defaultTestDur.String()), data) + }) +} + +func TestDuration_UnmarshalText(t *testing.T) { + d := Duration{defaultTestDur} + var v *durationEncodingTester + + t.Run("json", func(t *testing.T) { + v = &durationEncodingTester{} + + r := strings.NewReader(jsonStr) + err := json.NewDecoder(r).Decode(v) + require.NoError(t, err) + + v.checkFields(t, d) + }) + + t.Run("yaml", func(t *testing.T) { + v = &durationEncodingTester{} + + r := strings.NewReader(yamlStr) + err := yaml.NewDecoder(r).Decode(v) + require.NoError(t, err) + + v.checkFields(t, d) + }) + + t.Run("direct", func(t *testing.T) { + dd := &Duration{} + + err := dd.UnmarshalText([]byte(d.String())) + require.NoError(t, err) + + assert.Equal(t, d, *dd) + }) + + t.Run("bad_data", func(t *testing.T) { + assert.Error(t, (&Duration{}).UnmarshalText([]byte(`abc`))) + }) +}