Pull request: all: fix client upstreams, imp code

Updates #3186.

Squashed commit of the following:

commit a8dd0e2cda3039839d069fe71a5bd0f9635ec064
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri May 28 12:54:07 2021 +0300

    all: imp code, names

commit 98f86c21ae23b665095075feb4a59dcfcc622bc7
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu May 27 21:11:37 2021 +0300

    all: fix client upstreams, imp code
This commit is contained in:
Ainar Garipov 2021-05-28 13:02:59 +03:00
parent 48b8579703
commit 3be783bd34
18 changed files with 249 additions and 270 deletions

View File

@ -32,6 +32,8 @@ released by then.
### Fixed ### Fixed
- Custom upstreams selection for clients with client IDs in DNS-over-TLS and
DNS-over-HTTP ([#3186]).
- Incorrect client-based filtering applying logic ([#2875]). - Incorrect client-based filtering applying logic ([#2875]).
### Removed ### Removed
@ -40,6 +42,7 @@ released by then.
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184 [#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184
[#3185]: https://github.com/AdguardTeam/AdGuardHome/issues/3185 [#3185]: https://github.com/AdguardTeam/AdGuardHome/issues/3185
[#3186]: https://github.com/AdguardTeam/AdGuardHome/issues/3186

View File

@ -10,6 +10,19 @@ import (
"golang.org/x/net/idna" "golang.org/x/net/idna"
) )
// IPFromAddr returns an IP address from addr. If addr is neither
// a *net.TCPAddr nor a *net.UDPAddr, it returns nil.
func IPFromAddr(addr net.Addr) (ip net.IP) {
switch addr := addr.(type) {
case *net.TCPAddr:
return addr.IP
case *net.UDPAddr:
return addr.IP
}
return nil
}
// IsValidHostOuterRune returns true if r is a valid initial or final rune for // IsValidHostOuterRune returns true if r is a valid initial or final rune for
// a hostname label. // a hostname label.
func IsValidHostOuterRune(r rune) (ok bool) { func IsValidHostOuterRune(r rune) (ok bool) {

View File

@ -9,6 +9,14 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestIPFromAddr(t *testing.T) {
ip := net.IP{1, 2, 3, 4}
assert.Equal(t, net.IP(nil), IPFromAddr(nil))
assert.Equal(t, net.IP(nil), IPFromAddr(struct{ net.Addr }{}))
assert.Equal(t, ip, IPFromAddr(&net.TCPAddr{IP: ip}))
assert.Equal(t, ip, IPFromAddr(&net.UDPAddr{IP: ip}))
}
func TestValidateHardwareAddress(t *testing.T) { func TestValidateHardwareAddress(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string

View File

@ -19,6 +19,19 @@ func CloneSlice(a []string) (b []string) {
return CloneSliceOrEmpty(a) return CloneSliceOrEmpty(a)
} }
// Coalesce returns the first non-empty string. It is named after the function
// COALESCE in SQL except that since strings in Go are non-nullable, it uses an
// empty string as a NULL value. If strs is empty, it returns an empty string.
func Coalesce(strs ...string) (res string) {
for _, s := range strs {
if s != "" {
return s
}
}
return ""
}
// FilterOut returns a copy of strs with all strings for which f returned true // FilterOut returns a copy of strs with all strings for which f returned true
// removed. // removed.
func FilterOut(strs []string, f func(s string) (ok bool)) (filtered []string) { func FilterOut(strs []string, f func(s string) (ok bool)) (filtered []string) {

View File

@ -36,6 +36,14 @@ func TestCloneSlice_family(t *testing.T) {
}) })
} }
func TestCoalesce(t *testing.T) {
assert.Equal(t, "", Coalesce())
assert.Equal(t, "a", Coalesce("a"))
assert.Equal(t, "a", Coalesce("", "a"))
assert.Equal(t, "a", Coalesce("a", ""))
assert.Equal(t, "a", Coalesce("a", "b"))
}
func TestFilterOut(t *testing.T) { func TestFilterOut(t *testing.T) {
strs := []string{ strs := []string{
"1.2.3.4", "1.2.3.4",

View File

@ -8,7 +8,9 @@ import (
"net/http" "net/http"
"os" "os"
"sort" "sort"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
@ -27,11 +29,10 @@ type FilteringConfig struct {
// FilterHandler is an optional additional filtering callback. // FilterHandler is an optional additional filtering callback.
FilterHandler func(clientAddr net.IP, clientID string, settings *filtering.Settings) `yaml:"-"` FilterHandler func(clientAddr net.IP, clientID string, settings *filtering.Settings) `yaml:"-"`
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration // GetCustomUpstreamByClient is a callback that returns upstreams
// based on the client IP address. Returns nil if there are no custom upstreams for the client // configuration based on the client IP address or ClientID. It returns
// // nil if there are no custom upstreams for the client.
// TODO(e.burkov): Replace argument type with net.IP. GetCustomUpstreamByClient func(id string) (conf *proxy.UpstreamConfig, err error) `yaml:"-"`
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
// Protection configuration // Protection configuration
// -- // --
@ -384,10 +385,51 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
return nil return nil
} }
// isInSorted returns true if s is in the sorted slice strs.
func isInSorted(strs []string, s string) (ok bool) {
i := sort.SearchStrings(strs, s)
if i == len(strs) || strs[i] != s {
return false
}
return true
}
// isWildcard returns true if host is a wildcard hostname.
func isWildcard(host string) (ok bool) {
return len(host) >= 2 && host[0] == '*' && host[1] == '.'
}
// matchesDomainWildcard returns true if host matches the domain wildcard
// pattern pat.
func matchesDomainWildcard(host, pat string) (ok bool) {
return isWildcard(pat) && strings.HasSuffix(host, pat[1:])
}
// anyNameMatches returns true if sni, the client's SNI value, matches any of
// the DNS names and patterns from certificate. dnsNames must be sorted.
func anyNameMatches(dnsNames []string, sni string) (ok bool) {
if aghnet.ValidateDomainName(sni) != nil {
return false
}
if isInSorted(dnsNames, sni) {
return true
}
for _, dn := range dnsNames {
if matchesDomainWildcard(sni, dn) {
return true
}
}
return false
}
// Called by 'tls' package when Client Hello is received // Called by 'tls' package when Client Hello is received
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake. // If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
if s.conf.StrictSNICheck && !matchDNSName(s.conf.dnsNames, ch.ServerName) { if s.conf.StrictSNICheck && !anyNameMatches(s.conf.dnsNames, ch.ServerName) {
log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName) log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName)
return nil, fmt.Errorf("invalid SNI") return nil, fmt.Errorf("invalid SNI")
} }

View File

@ -0,0 +1,53 @@
package dnsforward
import (
"sort"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAnyNameMatches(t *testing.T) {
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
sort.Strings(dnsNames)
testCases := []struct {
name string
dnsName string
want bool
}{{
name: "match",
dnsName: "host1",
want: true,
}, {
name: "match",
dnsName: "a.host2",
want: true,
}, {
name: "match",
dnsName: "b.a.host2",
want: true,
}, {
name: "match",
dnsName: "1.2.3.4",
want: true,
}, {
name: "mismatch",
dnsName: "host2",
want: false,
}, {
name: "mismatch",
dnsName: "",
want: false,
}, {
name: "mismatch",
dnsName: "*.host2",
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, anyNameMatches(dnsNames, tc.dnsName))
})
}
}

View File

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
@ -229,7 +230,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
rc = resultCodeSuccess rc = resultCodeSuccess
var ip net.IP var ip net.IP
if ip = IPFromAddr(dctx.proxyCtx.Addr); ip == nil { if ip = aghnet.IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
return rc return rc
} }
@ -489,6 +490,15 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
// ipStringFromAddr extracts an IP address string from net.Addr.
func ipStringFromAddr(addr net.Addr) (ipStr string) {
if ip := aghnet.IPFromAddr(addr); ip != nil {
return ip.String()
}
return ""
}
// processUpstream passes request to upstream servers and handles the response. // processUpstream passes request to upstream servers and handles the response.
func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) { func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx d := ctx.proxyCtx
@ -497,9 +507,13 @@ func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) {
} }
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil { if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
clientIP := IPStringFromAddr(d.Addr) // Use the clientID first, since it has a higher priority.
if upsConf := s.conf.GetCustomUpstreamByClient(clientIP); upsConf != nil { id := aghstrings.Coalesce(ctx.clientID, ipStringFromAddr(d.Addr))
log.Debug("dns: using custom upstreams for client %s", clientIP) upsConf, err := s.conf.GetCustomUpstreamByClient(id)
if err != nil {
log.Error("dns: getting custom upstreams for client %s: %s", id, err)
} else if upsConf != nil {
log.Debug("dns: using custom upstreams for client %s", id)
d.CustomUpstreamConfig = upsConf d.CustomUpstreamConfig = upsConf
} }
} }

View File

@ -379,3 +379,18 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
require.Empty(t, proxyCtx.Res.Answer) require.Empty(t, proxyCtx.Res.Answer)
}) })
} }
func TestIPStringFromAddr(t *testing.T) {
t.Run("not_nil", func(t *testing.T) {
addr := net.UDPAddr{
IP: net.ParseIP("1:2:3::4"),
Port: 12345,
Zone: "eth0",
}
assert.Equal(t, ipStringFromAddr(&addr), addr.IP.String())
})
t.Run("nil", func(t *testing.T) {
assert.Empty(t, ipStringFromAddr(nil))
})
}

View File

@ -12,7 +12,6 @@ import (
"math/big" "math/big"
"net" "net"
"os" "os"
"sort"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -521,16 +520,16 @@ func TestServerCustomClientUpstream(t *testing.T) {
}, },
} }
s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig { s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
return &proxy.UpstreamConfig{ ups := &aghtest.TestUpstream{
Upstreams: []upstream.Upstream{ IPv4: map[string][]net.IP{
&aghtest.TestUpstream{ "host.": {{192, 168, 0, 1}},
IPv4: map[string][]net.IP{
"host.": {{192, 168, 0, 1}},
},
},
}, },
} }
return &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
}, nil
} }
startDeferStop(t, s) startDeferStop(t, s)
@ -962,65 +961,6 @@ func publicKey(priv interface{}) interface{} {
} }
} }
func TestIPStringFromAddr(t *testing.T) {
t.Run("not_nil", func(t *testing.T) {
addr := net.UDPAddr{
IP: net.ParseIP("1:2:3::4"),
Port: 12345,
Zone: "eth0",
}
assert.Equal(t, IPStringFromAddr(&addr), addr.IP.String())
})
t.Run("nil", func(t *testing.T) {
assert.Empty(t, IPStringFromAddr(nil))
})
}
func TestMatchDNSName(t *testing.T) {
dnsNames := []string{"host1", "*.host2", "1.2.3.4"}
sort.Strings(dnsNames)
testCases := []struct {
name string
dnsName string
want bool
}{{
name: "match",
dnsName: "host1",
want: true,
}, {
name: "match",
dnsName: "a.host2",
want: true,
}, {
name: "match",
dnsName: "b.a.host2",
want: true,
}, {
name: "match",
dnsName: "1.2.3.4",
want: true,
}, {
name: "mismatch",
dnsName: "host2",
want: false,
}, {
name: "mismatch",
dnsName: "",
want: false,
}, {
name: "mismatch",
dnsName: "*.host2",
want: false,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, matchDNSName(dnsNames, tc.dnsName))
})
}
}
type testDHCP struct{} type testDHCP struct{}
func (d *testDHCP) Enabled() (ok bool) { return true } func (d *testDHCP) Enabled() (ok bool) { return true }

View File

@ -4,15 +4,15 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) { func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
ip := IPFromAddr(d.Addr) ip := aghnet.IPFromAddr(d.Addr)
disallowed, _ := s.access.IsBlockedIP(ip) disallowed, _ := s.access.IsBlockedIP(ip)
if disallowed { if disallowed {
log.Tracef("Client IP %s is blocked by settings", ip) log.Tracef("Client IP %s is blocked by settings", ip)
@ -39,7 +39,7 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings { func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
setts := s.dnsFilter.GetConfig() setts := s.dnsFilter.GetConfig()
if s.conf.FilterHandler != nil { if s.conf.FilterHandler != nil {
s.conf.FilterHandler(IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts) s.conf.FilterHandler(aghnet.IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
} }
return &setts return &setts

View File

@ -4,6 +4,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
@ -37,7 +38,7 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
OrigAnswer: ctx.origResp, OrigAnswer: ctx.origResp,
Result: ctx.result, Result: ctx.result,
Elapsed: elapsed, Elapsed: elapsed,
ClientIP: IPFromAddr(pctx.Addr), ClientIP: aghnet.IPFromAddr(pctx.Addr),
ClientID: ctx.clientID, ClientID: ctx.clientID,
} }
@ -79,7 +80,7 @@ func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filteri
if clientID := ctx.clientID; clientID != "" { if clientID := ctx.clientID; clientID != "" {
e.Client = clientID e.Client = clientID
} else if ip := IPFromAddr(pctx.Addr); ip != nil { } else if ip := aghnet.IPFromAddr(pctx.Addr); ip != nil {
e.Client = ip.String() e.Client = ip.String()
} }

View File

@ -1,69 +0,0 @@
package dnsforward
import (
"net"
"sort"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
)
// IPFromAddr gets IP address from addr.
func IPFromAddr(addr net.Addr) (ip net.IP) {
switch addr := addr.(type) {
case *net.UDPAddr:
return addr.IP
case *net.TCPAddr:
return addr.IP
}
return nil
}
// IPStringFromAddr extracts IP address from net.Addr.
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
func IPStringFromAddr(addr net.Addr) (ipStr string) {
if ip := IPFromAddr(addr); ip != nil {
return ip.String()
}
return ""
}
// Find value in a sorted array
func findSorted(ar []string, val string) int {
i := sort.SearchStrings(ar, val)
if i == len(ar) || ar[i] != val {
return -1
}
return i
}
func isWildcard(host string) bool {
return len(host) >= 2 &&
host[0] == '*' && host[1] == '.'
}
// Return TRUE if host name matches a wildcard pattern
func matchDomainWildcard(host, wildcard string) bool {
return isWildcard(wildcard) &&
strings.HasSuffix(host, wildcard[1:])
}
// Return TRUE if client's SNI value matches DNS names from certificate
func matchDNSName(dnsNames []string, sni string) bool {
if aghnet.ValidateDomainName(sni) != nil {
return false
}
if findSorted(dnsNames, sni) != -1 {
return true
}
for _, dn := range dnsNames {
if matchDomainWildcard(sni, dn) {
return true
}
}
return false
}

View File

@ -1,60 +0,0 @@
package dnsforward
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
// fakeAddr is a mock implementation of net.Addr interface to simplify testing.
type fakeAddr struct {
// Addr is embedded here simply to make fakeAddr a net.Addr without
// actually implementing all methods.
net.Addr
}
func TestIPFromAddr(t *testing.T) {
supIPv4 := net.IP{1, 2, 3, 4}
supIPv6 := net.ParseIP("2a00:1450:400c:c06::93")
testCases := []struct {
name string
addr net.Addr
want net.IP
}{{
name: "ipv4_tcp",
addr: &net.TCPAddr{
IP: supIPv4,
},
want: supIPv4,
}, {
name: "ipv6_tcp",
addr: &net.TCPAddr{
IP: supIPv6,
},
want: supIPv6,
}, {
name: "ipv4_udp",
addr: &net.UDPAddr{
IP: supIPv4,
},
want: supIPv4,
}, {
name: "ipv6_udp",
addr: &net.UDPAddr{
IP: supIPv6,
},
want: supIPv6,
}, {
name: "non-ip_addr",
addr: &fakeAddr{},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, IPFromAddr(tc.addr))
})
}
}

View File

@ -335,37 +335,44 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
return c, true return c, true
} }
// FindUpstreams looks for upstreams configured for the client // findUpstreams returns upstreams configured for the client, identified either
// If no client found for this IP, or if no custom upstreams are configured, // by its IP address or its ClientID. upsConf is nil if the client isn't found
// this method returns nil // or if the client has no custom upstreams.
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig { func (clients *clientsContainer) findUpstreams(
id string,
) (upsConf *proxy.UpstreamConfig, err error) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
c, ok := clients.findLocked(ip) c, ok := clients.findLocked(id)
if !ok { if !ok {
return nil return nil, nil
} }
upstreams := aghstrings.FilterOut(c.Upstreams, aghstrings.IsCommentOrEmpty) upstreams := aghstrings.FilterOut(c.Upstreams, aghstrings.IsCommentOrEmpty)
if len(upstreams) == 0 { if len(upstreams) == 0 {
return nil return nil, nil
} }
if c.upstreamConfig == nil { if c.upstreamConfig != nil {
conf, err := proxy.ParseUpstreamsConfig( return c.upstreamConfig, nil
upstreams,
upstream.Options{
Bootstrap: config.DNS.BootstrapDNS,
Timeout: dnsforward.DefaultTimeout,
},
)
if err == nil {
c.upstreamConfig = &conf
}
} }
return c.upstreamConfig var conf proxy.UpstreamConfig
conf, err = proxy.ParseUpstreamsConfig(
upstreams,
upstream.Options{
Bootstrap: config.DNS.BootstrapDNS,
Timeout: dnsforward.DefaultTimeout,
},
)
if err != nil {
return nil, err
}
c.upstreamConfig = &conf
return &conf, nil
} }
// findLocked searches for a client by its ID. For internal use only. // findLocked searches for a client by its ID. For internal use only.

View File

@ -25,7 +25,7 @@ func TestClients(t *testing.T) {
} }
ok, err := clients.Add(c) ok, err := clients.Add(c)
require.Nil(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
c = &Client{ c = &Client{
@ -34,7 +34,7 @@ func TestClients(t *testing.T) {
} }
ok, err = clients.Add(c) ok, err = clients.Add(c)
require.Nil(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
c, ok = clients.Find("1.1.1.1") c, ok = clients.Find("1.1.1.1")
@ -59,7 +59,7 @@ func TestClients(t *testing.T) {
IDs: []string{"1.2.3.5"}, IDs: []string{"1.2.3.5"},
Name: "client1", Name: "client1",
}) })
require.Nil(t, err) require.NoError(t, err)
assert.False(t, ok) assert.False(t, ok)
}) })
@ -68,7 +68,7 @@ func TestClients(t *testing.T) {
IDs: []string{"2.2.2.2"}, IDs: []string{"2.2.2.2"},
Name: "client3", Name: "client3",
}) })
require.NotNil(t, err) require.Error(t, err)
assert.False(t, ok) assert.False(t, ok)
}) })
@ -77,13 +77,13 @@ func TestClients(t *testing.T) {
IDs: []string{"1.2.3.0"}, IDs: []string{"1.2.3.0"},
Name: "client3", Name: "client3",
}) })
require.NotNil(t, err) require.Error(t, err)
err = clients.Update("client3", &Client{ err = clients.Update("client3", &Client{
IDs: []string{"1.2.3.0"}, IDs: []string{"1.2.3.0"},
Name: "client2", Name: "client2",
}) })
assert.NotNil(t, err) assert.Error(t, err)
}) })
t.Run("update_fail_ip", func(t *testing.T) { t.Run("update_fail_ip", func(t *testing.T) {
@ -91,7 +91,7 @@ func TestClients(t *testing.T) {
IDs: []string{"2.2.2.2"}, IDs: []string{"2.2.2.2"},
Name: "client1", Name: "client1",
}) })
assert.NotNil(t, err) assert.Error(t, err)
}) })
t.Run("update_success", func(t *testing.T) { t.Run("update_success", func(t *testing.T) {
@ -99,7 +99,7 @@ func TestClients(t *testing.T) {
IDs: []string{"1.1.1.2"}, IDs: []string{"1.1.1.2"},
Name: "client1", Name: "client1",
}) })
require.Nil(t, err) require.NoError(t, err)
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
@ -109,7 +109,7 @@ func TestClients(t *testing.T) {
Name: "client1-renamed", Name: "client1-renamed",
UseOwnSettings: true, UseOwnSettings: true,
}) })
require.Nil(t, err) require.NoError(t, err)
c, ok := clients.Find("1.1.1.2") c, ok := clients.Find("1.1.1.2")
require.True(t, ok) require.True(t, ok)
@ -137,15 +137,15 @@ func TestClients(t *testing.T) {
t.Run("addhost_success", func(t *testing.T) { t.Run("addhost_success", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP) ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
require.Nil(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP) ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
require.Nil(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile) ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
require.Nil(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
@ -153,7 +153,7 @@ func TestClients(t *testing.T) {
t.Run("addhost_fail", func(t *testing.T) { t.Run("addhost_fail", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS) ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
require.Nil(t, err) require.NoError(t, err)
assert.False(t, ok) assert.False(t, ok)
}) })
} }
@ -181,7 +181,7 @@ func TestClientsWhois(t *testing.T) {
t.Run("existing_auto-client", func(t *testing.T) { t.Run("existing_auto-client", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
require.Nil(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
clients.SetWhoisInfo("1.1.1.1", whois) clients.SetWhoisInfo("1.1.1.1", whois)
@ -198,7 +198,7 @@ func TestClientsWhois(t *testing.T) {
IDs: []string{"1.1.1.2"}, IDs: []string{"1.1.1.2"},
Name: "client1", Name: "client1",
}) })
require.Nil(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
clients.SetWhoisInfo("1.1.1.2", whois) clients.SetWhoisInfo("1.1.1.2", whois)
@ -219,12 +219,12 @@ func TestClientsAddExisting(t *testing.T) {
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"}, IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
Name: "client1", Name: "client1",
}) })
require.Nil(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
// Now add an auto-client with the same IP. // Now add an auto-client with the same IP.
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS) ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
require.Nil(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
}) })
@ -253,14 +253,14 @@ func TestClientsAddExisting(t *testing.T) {
Hostname: "testhost", Hostname: "testhost",
Expiry: time.Now().Add(time.Hour), Expiry: time.Now().Add(time.Hour),
}) })
require.Nil(t, err) require.NoError(t, err)
// Add a new client with the same IP as for a client with MAC. // Add a new client with the same IP as for a client with MAC.
ok, err := clients.Add(&Client{ ok, err := clients.Add(&Client{
IDs: []string{testIP.String()}, IDs: []string{testIP.String()},
Name: "client2", Name: "client2",
}) })
require.Nil(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
// Add a new client with the IP from the first client's IP // Add a new client with the IP from the first client's IP
@ -269,7 +269,7 @@ func TestClientsAddExisting(t *testing.T) {
IDs: []string{"2.2.2.2"}, IDs: []string{"2.2.2.2"},
Name: "client3", Name: "client3",
}) })
require.Nil(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
}) })
} }
@ -289,14 +289,16 @@ func TestClientsCustomUpstream(t *testing.T) {
"[/example.org/]8.8.8.8", "[/example.org/]8.8.8.8",
}, },
}) })
require.Nil(t, err) require.NoError(t, err)
assert.True(t, ok) assert.True(t, ok)
config := clients.FindUpstreams("1.2.3.4") config, err := clients.findUpstreams("1.2.3.4")
assert.Nil(t, config) assert.Nil(t, config)
assert.NoError(t, err)
config = clients.FindUpstreams("1.1.1.1") config, err = clients.findUpstreams("1.1.1.1")
require.NotNil(t, config) require.NotNil(t, config)
assert.NoError(t, err)
assert.Len(t, config.Upstreams, 1) assert.Len(t, config.Upstreams, 1)
assert.Len(t, config.DomainReservedUpstreams, 1) assert.Len(t, config.DomainReservedUpstreams, 1)
} }

View File

@ -8,6 +8,7 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
@ -106,7 +107,7 @@ func isRunning() bool {
} }
func onDNSRequest(d *proxy.DNSContext) { func onDNSRequest(d *proxy.DNSContext) {
ip := dnsforward.IPFromAddr(d.Addr) ip := aghnet.IPFromAddr(d.Addr)
if ip == nil { if ip == nil {
// This would be quite weird if we get here. // This would be quite weird if we get here.
return return
@ -197,7 +198,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
newConf.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH newConf.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
newConf.FilterHandler = applyAdditionalFiltering newConf.FilterHandler = applyAdditionalFiltering
newConf.GetCustomUpstreamByClient = Context.clients.FindUpstreams newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams
newConf.ResolveClients = dnsConf.ResolveClients newConf.ResolveClients = dnsConf.ResolveClients
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS

View File

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -66,19 +67,6 @@ func trimValue(s string) string {
return s[:maxValueLength-3] + "..." return s[:maxValueLength-3] + "..."
} }
// coalesceStr returns the first non-empty string.
//
// TODO(a.garipov): Move to aghstrings?
func coalesceStr(strs ...string) (res string) {
for _, s := range strs {
if s != "" {
return s
}
}
return ""
}
// isWhoisComment returns true if the string is empty or is a WHOIS comment. // isWhoisComment returns true if the string is empty or is a WHOIS comment.
func isWhoisComment(s string) (ok bool) { func isWhoisComment(s string) (ok bool) {
return len(s) == 0 || s[0] == '#' || s[0] == '%' return len(s) == 0 || s[0] == '#' || s[0] == '%'
@ -119,7 +107,7 @@ func whoisParse(data string) (m strmap) {
v = trimValue(v) v = trimValue(v)
case "descr", "netname": case "descr", "netname":
k = "orgname" k = "orgname"
v = coalesceStr(orgname, v) v = aghstrings.Coalesce(orgname, v)
orgname = v orgname = v
case "whois": case "whois":
k = "whois" k = "whois"