diff --git a/CHANGELOG.md b/CHANGELOG.md index f86766af..a8c5b424 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ released by then. ### Fixed +- Custom upstreams selection for clients with client IDs in DNS-over-TLS and + DNS-over-HTTP ([#3186]). - Incorrect client-based filtering applying logic ([#2875]). ### Removed @@ -40,6 +42,7 @@ released by then. [#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184 [#3185]: https://github.com/AdguardTeam/AdGuardHome/issues/3185 +[#3186]: https://github.com/AdguardTeam/AdGuardHome/issues/3186 diff --git a/internal/aghnet/addr.go b/internal/aghnet/addr.go index 09971955..336cd136 100644 --- a/internal/aghnet/addr.go +++ b/internal/aghnet/addr.go @@ -10,6 +10,19 @@ import ( "golang.org/x/net/idna" ) +// IPFromAddr returns an IP address from addr. If addr is neither +// a *net.TCPAddr nor a *net.UDPAddr, it returns nil. +func IPFromAddr(addr net.Addr) (ip net.IP) { + switch addr := addr.(type) { + case *net.TCPAddr: + return addr.IP + case *net.UDPAddr: + return addr.IP + } + + return nil +} + // IsValidHostOuterRune returns true if r is a valid initial or final rune for // a hostname label. func IsValidHostOuterRune(r rune) (ok bool) { diff --git a/internal/aghnet/addr_test.go b/internal/aghnet/addr_test.go index 514f8706..df9cd740 100644 --- a/internal/aghnet/addr_test.go +++ b/internal/aghnet/addr_test.go @@ -9,6 +9,14 @@ import ( "github.com/stretchr/testify/require" ) +func TestIPFromAddr(t *testing.T) { + ip := net.IP{1, 2, 3, 4} + assert.Equal(t, net.IP(nil), IPFromAddr(nil)) + assert.Equal(t, net.IP(nil), IPFromAddr(struct{ net.Addr }{})) + assert.Equal(t, ip, IPFromAddr(&net.TCPAddr{IP: ip})) + assert.Equal(t, ip, IPFromAddr(&net.UDPAddr{IP: ip})) +} + func TestValidateHardwareAddress(t *testing.T) { testCases := []struct { name string diff --git a/internal/aghstrings/strings.go b/internal/aghstrings/strings.go index c4993fd5..201a319c 100644 --- a/internal/aghstrings/strings.go +++ b/internal/aghstrings/strings.go @@ -19,6 +19,19 @@ func CloneSlice(a []string) (b []string) { return CloneSliceOrEmpty(a) } +// Coalesce returns the first non-empty string. It is named after the function +// COALESCE in SQL except that since strings in Go are non-nullable, it uses an +// empty string as a NULL value. If strs is empty, it returns an empty string. +func Coalesce(strs ...string) (res string) { + for _, s := range strs { + if s != "" { + return s + } + } + + return "" +} + // FilterOut returns a copy of strs with all strings for which f returned true // removed. func FilterOut(strs []string, f func(s string) (ok bool)) (filtered []string) { diff --git a/internal/aghstrings/strings_test.go b/internal/aghstrings/strings_test.go index 3cb5723f..78cb2923 100644 --- a/internal/aghstrings/strings_test.go +++ b/internal/aghstrings/strings_test.go @@ -36,6 +36,14 @@ func TestCloneSlice_family(t *testing.T) { }) } +func TestCoalesce(t *testing.T) { + assert.Equal(t, "", Coalesce()) + assert.Equal(t, "a", Coalesce("a")) + assert.Equal(t, "a", Coalesce("", "a")) + assert.Equal(t, "a", Coalesce("a", "")) + assert.Equal(t, "a", Coalesce("a", "b")) +} + func TestFilterOut(t *testing.T) { strs := []string{ "1.2.3.4", diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 0fea1bd9..47339f7b 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -8,7 +8,9 @@ import ( "net/http" "os" "sort" + "strings" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" @@ -27,11 +29,10 @@ type FilteringConfig struct { // FilterHandler is an optional additional filtering callback. FilterHandler func(clientAddr net.IP, clientID string, settings *filtering.Settings) `yaml:"-"` - // GetCustomUpstreamByClient - a callback function that returns upstreams configuration - // based on the client IP address. Returns nil if there are no custom upstreams for the client - // - // TODO(e.burkov): Replace argument type with net.IP. - GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"` + // GetCustomUpstreamByClient is a callback that returns upstreams + // configuration based on the client IP address or ClientID. It returns + // nil if there are no custom upstreams for the client. + GetCustomUpstreamByClient func(id string) (conf *proxy.UpstreamConfig, err error) `yaml:"-"` // Protection configuration // -- @@ -384,10 +385,51 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error { return nil } +// isInSorted returns true if s is in the sorted slice strs. +func isInSorted(strs []string, s string) (ok bool) { + i := sort.SearchStrings(strs, s) + if i == len(strs) || strs[i] != s { + return false + } + + return true +} + +// isWildcard returns true if host is a wildcard hostname. +func isWildcard(host string) (ok bool) { + return len(host) >= 2 && host[0] == '*' && host[1] == '.' +} + +// matchesDomainWildcard returns true if host matches the domain wildcard +// pattern pat. +func matchesDomainWildcard(host, pat string) (ok bool) { + return isWildcard(pat) && strings.HasSuffix(host, pat[1:]) +} + +// anyNameMatches returns true if sni, the client's SNI value, matches any of +// the DNS names and patterns from certificate. dnsNames must be sorted. +func anyNameMatches(dnsNames []string, sni string) (ok bool) { + if aghnet.ValidateDomainName(sni) != nil { + return false + } + + if isInSorted(dnsNames, sni) { + return true + } + + for _, dn := range dnsNames { + if matchesDomainWildcard(sni, dn) { + return true + } + } + + return false +} + // Called by 'tls' package when Client Hello is received // If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake. func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { - if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) { + if s.conf.StrictSNICheck && !anyNameMatches(s.conf.dnsNames, ch.ServerName) { log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName) return nil, fmt.Errorf("invalid SNI") } diff --git a/internal/dnsforward/config_test.go b/internal/dnsforward/config_test.go new file mode 100644 index 00000000..f98e2c22 --- /dev/null +++ b/internal/dnsforward/config_test.go @@ -0,0 +1,53 @@ +package dnsforward + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAnyNameMatches(t *testing.T) { + dnsNames := []string{"host1", "*.host2", "1.2.3.4"} + sort.Strings(dnsNames) + + testCases := []struct { + name string + dnsName string + want bool + }{{ + name: "match", + dnsName: "host1", + want: true, + }, { + name: "match", + dnsName: "a.host2", + want: true, + }, { + name: "match", + dnsName: "b.a.host2", + want: true, + }, { + name: "match", + dnsName: "1.2.3.4", + want: true, + }, { + name: "mismatch", + dnsName: "host2", + want: false, + }, { + name: "mismatch", + dnsName: "", + want: false, + }, { + name: "mismatch", + dnsName: "*.host2", + want: false, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, anyNameMatches(dnsNames, tc.dnsName)) + }) + } +} diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 92780057..76196da2 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -6,6 +6,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" @@ -229,7 +230,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) { rc = resultCodeSuccess var ip net.IP - if ip = IPFromAddr(dctx.proxyCtx.Addr); ip == nil { + if ip = aghnet.IPFromAddr(dctx.proxyCtx.Addr); ip == nil { return rc } @@ -489,6 +490,15 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { return resultCodeSuccess } +// ipStringFromAddr extracts an IP address string from net.Addr. +func ipStringFromAddr(addr net.Addr) (ipStr string) { + if ip := aghnet.IPFromAddr(addr); ip != nil { + return ip.String() + } + + return "" +} + // processUpstream passes request to upstream servers and handles the response. func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) { d := ctx.proxyCtx @@ -497,9 +507,13 @@ func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) { } if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil { - clientIP := IPStringFromAddr(d.Addr) - if upsConf := s.conf.GetCustomUpstreamByClient(clientIP); upsConf != nil { - log.Debug("dns: using custom upstreams for client %s", clientIP) + // Use the clientID first, since it has a higher priority. + id := aghstrings.Coalesce(ctx.clientID, ipStringFromAddr(d.Addr)) + upsConf, err := s.conf.GetCustomUpstreamByClient(id) + if err != nil { + log.Error("dns: getting custom upstreams for client %s: %s", id, err) + } else if upsConf != nil { + log.Debug("dns: using custom upstreams for client %s", id) d.CustomUpstreamConfig = upsConf } } diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index 06a3c0e1..565844b5 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -379,3 +379,18 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { require.Empty(t, proxyCtx.Res.Answer) }) } + +func TestIPStringFromAddr(t *testing.T) { + t.Run("not_nil", func(t *testing.T) { + addr := net.UDPAddr{ + IP: net.ParseIP("1:2:3::4"), + Port: 12345, + Zone: "eth0", + } + assert.Equal(t, ipStringFromAddr(&addr), addr.IP.String()) + }) + + t.Run("nil", func(t *testing.T) { + assert.Empty(t, ipStringFromAddr(nil)) + }) +} diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 6ced4805..55b9904f 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -12,7 +12,6 @@ import ( "math/big" "net" "os" - "sort" "sync" "testing" "time" @@ -521,16 +520,16 @@ func TestServerCustomClientUpstream(t *testing.T) { }, } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) - s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig { - return &proxy.UpstreamConfig{ - Upstreams: []upstream.Upstream{ - &aghtest.TestUpstream{ - IPv4: map[string][]net.IP{ - "host.": {{192, 168, 0, 1}}, - }, - }, + s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) { + ups := &aghtest.TestUpstream{ + IPv4: map[string][]net.IP{ + "host.": {{192, 168, 0, 1}}, }, } + + return &proxy.UpstreamConfig{ + Upstreams: []upstream.Upstream{ups}, + }, nil } startDeferStop(t, s) @@ -962,65 +961,6 @@ func publicKey(priv interface{}) interface{} { } } -func TestIPStringFromAddr(t *testing.T) { - t.Run("not_nil", func(t *testing.T) { - addr := net.UDPAddr{ - IP: net.ParseIP("1:2:3::4"), - Port: 12345, - Zone: "eth0", - } - assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String()) - }) - - t.Run("nil", func(t *testing.T) { - assert.Empty(t, IPStringFromAddr(nil)) - }) -} - -func TestMatchDNSName(t *testing.T) { - dnsNames := []string{"host1", "*.host2", "1.2.3.4"} - sort.Strings(dnsNames) - - testCases := []struct { - name string - dnsName string - want bool - }{{ - name: "match", - dnsName: "host1", - want: true, - }, { - name: "match", - dnsName: "a.host2", - want: true, - }, { - name: "match", - dnsName: "b.a.host2", - want: true, - }, { - name: "match", - dnsName: "1.2.3.4", - want: true, - }, { - name: "mismatch", - dnsName: "host2", - want: false, - }, { - name: "mismatch", - dnsName: "", - want: false, - }, { - name: "mismatch", - dnsName: "*.host2", - want: false, - }} - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, matchDNSName(dnsNames, tc.dnsName)) - }) - } -} - type testDHCP struct{} func (d *testDHCP) Enabled() (ok bool) { return true } diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 3dd3530d..69ffa9a6 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -4,15 +4,15 @@ import ( "fmt" "strings" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/log" - "github.com/miekg/dns" ) func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) { - ip := IPFromAddr(d.Addr) + ip := aghnet.IPFromAddr(d.Addr) disallowed, _ := s.access.IsBlockedIP(ip) if disallowed { log.Tracef("Client IP %s is blocked by settings", ip) @@ -39,7 +39,7 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings { setts := s.dnsFilter.GetConfig() if s.conf.FilterHandler != nil { - s.conf.FilterHandler(IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts) + s.conf.FilterHandler(aghnet.IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts) } return &setts diff --git a/internal/dnsforward/stats.go b/internal/dnsforward/stats.go index d0eb39ca..dc57d930 100644 --- a/internal/dnsforward/stats.go +++ b/internal/dnsforward/stats.go @@ -4,6 +4,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/stats" @@ -37,7 +38,7 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) { OrigAnswer: ctx.origResp, Result: ctx.result, Elapsed: elapsed, - ClientIP: IPFromAddr(pctx.Addr), + ClientIP: aghnet.IPFromAddr(pctx.Addr), ClientID: ctx.clientID, } @@ -79,7 +80,7 @@ func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filteri if clientID := ctx.clientID; clientID != "" { e.Client = clientID - } else if ip := IPFromAddr(pctx.Addr); ip != nil { + } else if ip := aghnet.IPFromAddr(pctx.Addr); ip != nil { e.Client = ip.String() } diff --git a/internal/dnsforward/util.go b/internal/dnsforward/util.go deleted file mode 100644 index 5fdde967..00000000 --- a/internal/dnsforward/util.go +++ /dev/null @@ -1,69 +0,0 @@ -package dnsforward - -import ( - "net" - "sort" - "strings" - - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" -) - -// IPFromAddr gets IP address from addr. -func IPFromAddr(addr net.Addr) (ip net.IP) { - switch addr := addr.(type) { - case *net.UDPAddr: - return addr.IP - case *net.TCPAddr: - return addr.IP - } - return nil -} - -// IPStringFromAddr extracts IP address from net.Addr. -// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone: -// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261 -func IPStringFromAddr(addr net.Addr) (ipStr string) { - if ip := IPFromAddr(addr); ip != nil { - return ip.String() - } - - return "" -} - -// Find value in a sorted array -func findSorted(ar []string, val string) int { - i := sort.SearchStrings(ar, val) - if i == len(ar) || ar[i] != val { - return -1 - } - return i -} - -func isWildcard(host string) bool { - return len(host) >= 2 && - host[0] == '*' && host[1] == '.' -} - -// Return TRUE if host name matches a wildcard pattern -func matchDomainWildcard(host, wildcard string) bool { - return isWildcard(wildcard) && - strings.HasSuffix(host, wildcard[1:]) -} - -// Return TRUE if client's SNI value matches DNS names from certificate -func matchDNSName(dnsNames []string, sni string) bool { - if aghnet.ValidateDomainName(sni) != nil { - return false - } - - if findSorted(dnsNames, sni) != -1 { - return true - } - - for _, dn := range dnsNames { - if matchDomainWildcard(sni, dn) { - return true - } - } - return false -} diff --git a/internal/dnsforward/util_test.go b/internal/dnsforward/util_test.go deleted file mode 100644 index 09cdf714..00000000 --- a/internal/dnsforward/util_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package dnsforward - -import ( - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -// fakeAddr is a mock implementation of net.Addr interface to simplify testing. -type fakeAddr struct { - // Addr is embedded here simply to make fakeAddr a net.Addr without - // actually implementing all methods. - net.Addr -} - -func TestIPFromAddr(t *testing.T) { - supIPv4 := net.IP{1, 2, 3, 4} - supIPv6 := net.ParseIP("2a00:1450:400c:c06::93") - - testCases := []struct { - name string - addr net.Addr - want net.IP - }{{ - name: "ipv4_tcp", - addr: &net.TCPAddr{ - IP: supIPv4, - }, - want: supIPv4, - }, { - name: "ipv6_tcp", - addr: &net.TCPAddr{ - IP: supIPv6, - }, - want: supIPv6, - }, { - name: "ipv4_udp", - addr: &net.UDPAddr{ - IP: supIPv4, - }, - want: supIPv4, - }, { - name: "ipv6_udp", - addr: &net.UDPAddr{ - IP: supIPv6, - }, - want: supIPv6, - }, { - name: "non-ip_addr", - addr: &fakeAddr{}, - want: nil, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, IPFromAddr(tc.addr)) - }) - } -} diff --git a/internal/home/clients.go b/internal/home/clients.go index 5ccb73ff..30c508e4 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -335,37 +335,44 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) { return c, true } -// FindUpstreams looks for upstreams configured for the client -// If no client found for this IP, or if no custom upstreams are configured, -// this method returns nil -func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig { +// findUpstreams returns upstreams configured for the client, identified either +// by its IP address or its ClientID. upsConf is nil if the client isn't found +// or if the client has no custom upstreams. +func (clients *clientsContainer) findUpstreams( + id string, +) (upsConf *proxy.UpstreamConfig, err error) { clients.lock.Lock() defer clients.lock.Unlock() - c, ok := clients.findLocked(ip) + c, ok := clients.findLocked(id) if !ok { - return nil + return nil, nil } upstreams := aghstrings.FilterOut(c.Upstreams, aghstrings.IsCommentOrEmpty) if len(upstreams) == 0 { - return nil + return nil, nil } - if c.upstreamConfig == nil { - conf, err := proxy.ParseUpstreamsConfig( - upstreams, - upstream.Options{ - Bootstrap: config.DNS.BootstrapDNS, - Timeout: dnsforward.DefaultTimeout, - }, - ) - if err == nil { - c.upstreamConfig = &conf - } + if c.upstreamConfig != nil { + return c.upstreamConfig, nil } - return c.upstreamConfig + var conf proxy.UpstreamConfig + conf, err = proxy.ParseUpstreamsConfig( + upstreams, + upstream.Options{ + Bootstrap: config.DNS.BootstrapDNS, + Timeout: dnsforward.DefaultTimeout, + }, + ) + if err != nil { + return nil, err + } + + c.upstreamConfig = &conf + + return &conf, nil } // findLocked searches for a client by its ID. For internal use only. diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index aa040a58..f8c34fbb 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -25,7 +25,7 @@ func TestClients(t *testing.T) { } ok, err := clients.Add(c) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, ok) c = &Client{ @@ -34,7 +34,7 @@ func TestClients(t *testing.T) { } ok, err = clients.Add(c) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, ok) c, ok = clients.Find("1.1.1.1") @@ -59,7 +59,7 @@ func TestClients(t *testing.T) { IDs: []string{"1.2.3.5"}, Name: "client1", }) - require.Nil(t, err) + require.NoError(t, err) assert.False(t, ok) }) @@ -68,7 +68,7 @@ func TestClients(t *testing.T) { IDs: []string{"2.2.2.2"}, Name: "client3", }) - require.NotNil(t, err) + require.Error(t, err) assert.False(t, ok) }) @@ -77,13 +77,13 @@ func TestClients(t *testing.T) { IDs: []string{"1.2.3.0"}, Name: "client3", }) - require.NotNil(t, err) + require.Error(t, err) err = clients.Update("client3", &Client{ IDs: []string{"1.2.3.0"}, Name: "client2", }) - assert.NotNil(t, err) + assert.Error(t, err) }) t.Run("update_fail_ip", func(t *testing.T) { @@ -91,7 +91,7 @@ func TestClients(t *testing.T) { IDs: []string{"2.2.2.2"}, Name: "client1", }) - assert.NotNil(t, err) + assert.Error(t, err) }) t.Run("update_success", func(t *testing.T) { @@ -99,7 +99,7 @@ func TestClients(t *testing.T) { IDs: []string{"1.1.1.2"}, Name: "client1", }) - require.Nil(t, err) + require.NoError(t, err) assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) @@ -109,7 +109,7 @@ func TestClients(t *testing.T) { Name: "client1-renamed", UseOwnSettings: true, }) - require.Nil(t, err) + require.NoError(t, err) c, ok := clients.Find("1.1.1.2") require.True(t, ok) @@ -137,15 +137,15 @@ func TestClients(t *testing.T) { t.Run("addhost_success", func(t *testing.T) { ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, ok) ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, ok) ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, ok) assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) @@ -153,7 +153,7 @@ func TestClients(t *testing.T) { t.Run("addhost_fail", func(t *testing.T) { ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS) - require.Nil(t, err) + require.NoError(t, err) assert.False(t, ok) }) } @@ -181,7 +181,7 @@ func TestClientsWhois(t *testing.T) { t.Run("existing_auto-client", func(t *testing.T) { ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, ok) clients.SetWhoisInfo("1.1.1.1", whois) @@ -198,7 +198,7 @@ func TestClientsWhois(t *testing.T) { IDs: []string{"1.1.1.2"}, Name: "client1", }) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, ok) clients.SetWhoisInfo("1.1.1.2", whois) @@ -219,12 +219,12 @@ func TestClientsAddExisting(t *testing.T) { IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"}, Name: "client1", }) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, ok) // Now add an auto-client with the same IP. ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, ok) }) @@ -253,14 +253,14 @@ func TestClientsAddExisting(t *testing.T) { Hostname: "testhost", Expiry: time.Now().Add(time.Hour), }) - require.Nil(t, err) + require.NoError(t, err) // Add a new client with the same IP as for a client with MAC. ok, err := clients.Add(&Client{ IDs: []string{testIP.String()}, Name: "client2", }) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, ok) // Add a new client with the IP from the first client's IP @@ -269,7 +269,7 @@ func TestClientsAddExisting(t *testing.T) { IDs: []string{"2.2.2.2"}, Name: "client3", }) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, ok) }) } @@ -289,14 +289,16 @@ func TestClientsCustomUpstream(t *testing.T) { "[/example.org/]8.8.8.8", }, }) - require.Nil(t, err) + require.NoError(t, err) assert.True(t, ok) - config := clients.FindUpstreams("1.2.3.4") + config, err := clients.findUpstreams("1.2.3.4") assert.Nil(t, config) + assert.NoError(t, err) - config = clients.FindUpstreams("1.1.1.1") + config, err = clients.findUpstreams("1.1.1.1") require.NotNil(t, config) + assert.NoError(t, err) assert.Len(t, config.Upstreams, 1) assert.Len(t, config.DomainReservedUpstreams, 1) } diff --git a/internal/home/dns.go b/internal/home/dns.go index 82ed4fb2..8f4741d3 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -8,6 +8,7 @@ import ( "path/filepath" "strconv" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" @@ -106,7 +107,7 @@ func isRunning() bool { } func onDNSRequest(d *proxy.DNSContext) { - ip := dnsforward.IPFromAddr(d.Addr) + ip := aghnet.IPFromAddr(d.Addr) if ip == nil { // This would be quite weird if we get here. return @@ -197,7 +198,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) { newConf.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH newConf.FilterHandler = applyAdditionalFiltering - newConf.GetCustomUpstreamByClient = Context.clients.FindUpstreams + newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams newConf.ResolveClients = dnsConf.ResolveClients newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS diff --git a/internal/home/whois.go b/internal/home/whois.go index e5b0ca08..bb7bc0e8 100644 --- a/internal/home/whois.go +++ b/internal/home/whois.go @@ -10,6 +10,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghio" + "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" @@ -66,19 +67,6 @@ func trimValue(s string) string { return s[:maxValueLength-3] + "..." } -// coalesceStr returns the first non-empty string. -// -// TODO(a.garipov): Move to aghstrings? -func coalesceStr(strs ...string) (res string) { - for _, s := range strs { - if s != "" { - return s - } - } - - return "" -} - // isWhoisComment returns true if the string is empty or is a WHOIS comment. func isWhoisComment(s string) (ok bool) { return len(s) == 0 || s[0] == '#' || s[0] == '%' @@ -119,7 +107,7 @@ func whoisParse(data string) (m strmap) { v = trimValue(v) case "descr", "netname": k = "orgname" - v = coalesceStr(orgname, v) + v = aghstrings.Coalesce(orgname, v) orgname = v case "whois": k = "whois"