Pull request: 2846 cover aghnet vol.1

Merge in DNS/adguard-home from 2846-cover-aghnet-vol.1 to master

Updates #2846.

Squashed commit of the following:

commit 368e75b0bacb290f9929b8a5a682b06f2d75df6a
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Fri Jan 21 19:11:59 2022 +0300

    aghnet: imp tests

commit 8bb3e2a1680fd30294f7c82693891ffb19474c6a
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Fri Jan 21 18:27:06 2022 +0300

    aghnet: rm unused test

commit 28d8e64880f845810d0af629e5d1f06b9bde5b28
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Fri Jan 21 18:18:22 2022 +0300

    aghnet: cover with tests
This commit is contained in:
Eugene Burkov 2022-01-21 19:21:38 +03:00
parent f7ff02f3b1
commit 3f5605c42e
7 changed files with 259 additions and 160 deletions

View File

@ -19,7 +19,8 @@ import (
"github.com/insomniacslk/dhcp/iana" "github.com/insomniacslk/dhcp/iana"
) )
// defaultDiscoverTime is the // defaultDiscoverTime is the default timeout of checking another DHCP server
// response.
const defaultDiscoverTime = 3 * time.Second const defaultDiscoverTime = 3 * time.Second
func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) { func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {

View File

@ -343,113 +343,93 @@ func TestHostsContainer(t *testing.T) {
testdata := os.DirFS("./testdata") testdata := os.DirFS("./testdata")
nRewrites := func(t *testing.T, res *urlfilter.DNSResult, n int) (rws []*rules.DNSRewrite) {
rewrites := res.DNSRewrites()
require.Len(t, rewrites, n)
for _, rewrite := range rewrites {
require.Equal(t, listID, rewrite.FilterListID)
rw := rewrite.DNSRewrite
require.NotNil(t, rw)
rws = append(rws, rw)
}
return rws
}
testCases := []struct { testCases := []struct {
testTail func(t *testing.T, res *urlfilter.DNSResult) want []*rules.DNSRewrite
name string name string
req urlfilter.DNSRequest req urlfilter.DNSRequest
}{{ }{{
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
Value: net.IPv4(1, 0, 0, 1),
RRType: dns.TypeA,
}, {
RCode: dns.RcodeSuccess,
Value: net.IP(append((&[15]byte{})[:], byte(1))),
RRType: dns.TypeAAAA,
}},
name: "simple", name: "simple",
req: urlfilter.DNSRequest{ req: urlfilter.DNSRequest{
Hostname: "simplehost", Hostname: "simplehost",
DNSType: dns.TypeA, DNSType: dns.TypeA,
}, },
testTail: func(t *testing.T, res *urlfilter.DNSResult) {
rws := nRewrites(t, res, 2)
v, ok := rws[0].Value.(net.IP)
require.True(t, ok)
assert.True(t, net.IP{1, 0, 0, 1}.Equal(v))
v, ok = rws[1].Value.(net.IP)
require.True(t, ok)
// It's ::1.
assert.True(t, net.IP(append((&[15]byte{})[:], byte(1))).Equal(v))
},
}, { }, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
NewCNAME: "hello",
}},
name: "hello_alias", name: "hello_alias",
req: urlfilter.DNSRequest{ req: urlfilter.DNSRequest{
Hostname: "hello.world", Hostname: "hello.world",
DNSType: dns.TypeA, DNSType: dns.TypeA,
}, },
testTail: func(t *testing.T, res *urlfilter.DNSResult) {
assert.Equal(t, "hello", nRewrites(t, res, 1)[0].NewCNAME)
},
}, { }, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
NewCNAME: "hello",
}},
name: "other_line_alias", name: "other_line_alias",
req: urlfilter.DNSRequest{ req: urlfilter.DNSRequest{
Hostname: "hello.world.again", Hostname: "hello.world.again",
DNSType: dns.TypeA, DNSType: dns.TypeA,
}, },
testTail: func(t *testing.T, res *urlfilter.DNSResult) {
assert.Equal(t, "hello", nRewrites(t, res, 1)[0].NewCNAME)
},
}, { }, {
want: []*rules.DNSRewrite{},
name: "hello_subdomain", name: "hello_subdomain",
req: urlfilter.DNSRequest{ req: urlfilter.DNSRequest{
Hostname: "say.hello", Hostname: "say.hello",
DNSType: dns.TypeA, DNSType: dns.TypeA,
}, },
testTail: func(t *testing.T, res *urlfilter.DNSResult) {
assert.Empty(t, res.DNSRewrites())
},
}, { }, {
want: []*rules.DNSRewrite{},
name: "hello_alias_subdomain", name: "hello_alias_subdomain",
req: urlfilter.DNSRequest{ req: urlfilter.DNSRequest{
Hostname: "say.hello.world", Hostname: "say.hello.world",
DNSType: dns.TypeA, DNSType: dns.TypeA,
}, },
testTail: func(t *testing.T, res *urlfilter.DNSResult) {
assert.Empty(t, res.DNSRewrites())
},
}, { }, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
NewCNAME: "a.whole",
}},
name: "lots_of_aliases", name: "lots_of_aliases",
req: urlfilter.DNSRequest{ req: urlfilter.DNSRequest{
Hostname: "for.testing", Hostname: "for.testing",
DNSType: dns.TypeA, DNSType: dns.TypeA,
}, },
testTail: func(t *testing.T, res *urlfilter.DNSResult) {
assert.Equal(t, "a.whole", nRewrites(t, res, 1)[0].NewCNAME)
},
}, { }, {
want: []*rules.DNSRewrite{{
RCode: dns.RcodeSuccess,
RRType: dns.TypePTR,
Value: "simplehost.",
}},
name: "reverse", name: "reverse",
req: urlfilter.DNSRequest{ req: urlfilter.DNSRequest{
Hostname: "1.0.0.1.in-addr.arpa", Hostname: "1.0.0.1.in-addr.arpa",
DNSType: dns.TypePTR, DNSType: dns.TypePTR,
}, },
testTail: func(t *testing.T, res *urlfilter.DNSResult) {
rws := nRewrites(t, res, 1)
assert.Equal(t, dns.TypePTR, rws[0].RRType)
assert.Equal(t, "simplehost.", rws[0].Value)
},
}, { }, {
want: []*rules.DNSRewrite{},
name: "non-existing", name: "non-existing",
req: urlfilter.DNSRequest{ req: urlfilter.DNSRequest{
Hostname: "nonexisting", Hostname: "nonexisting",
DNSType: dns.TypeA, DNSType: dns.TypeA,
}, },
testTail: func(t *testing.T, res *urlfilter.DNSResult) { }, {
require.NotNil(t, res) want: nil,
name: "bad_type",
assert.Nil(t, res.DNSRewrites()) req: urlfilter.DNSRequest{
Hostname: "1.0.0.1.in-addr.arpa",
DNSType: dns.TypeSRV,
}, },
}} }}
@ -466,9 +446,26 @@ func TestHostsContainer(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
res, ok := hc.MatchRequest(tc.req) res, ok := hc.MatchRequest(tc.req)
require.False(t, ok) require.False(t, ok)
if tc.want == nil {
assert.Nil(t, res)
return
}
require.NotNil(t, res) require.NotNil(t, res)
tc.testTail(t, res) rewrites := res.DNSRewrites()
require.Len(t, rewrites, len(tc.want))
for i, rewrite := range rewrites {
require.Equal(t, listID, rewrite.FilterListID)
rw := rewrite.DNSRewrite
require.NotNil(t, rw)
assert.Equal(t, tc.want[i], rw)
}
}) })
} }
} }

View File

@ -25,6 +25,13 @@ type NetIface interface {
// IfaceIPAddrs returns the interface's IP addresses. // IfaceIPAddrs returns the interface's IP addresses.
func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) { func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) {
switch ipv {
case IPVersion4, IPVersion6:
// Go on.
default:
return nil, fmt.Errorf("invalid ip version %d", ipv)
}
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
return nil, err return nil, err
@ -41,20 +48,16 @@ func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) {
continue continue
} }
// Assume that net.(*Interface).Addrs can only return valid IPv4 // Assume that net.(*Interface).Addrs can only return valid IPv4 and
// and IPv6 addresses. Thus, if it isn't an IPv4 address, it // IPv6 addresses. Thus, if it isn't an IPv4 address, it must be an
// must be an IPv6 one. // IPv6 one.
switch ipv { ip4 := ip.To4()
case IPVersion4: if ipv == IPVersion4 {
if ip4 := ip.To4(); ip4 != nil { if ip4 != nil {
ips = append(ips, ip4) ips = append(ips, ip4)
} }
case IPVersion6: } else if ip4 == nil {
if ip6 := ip.To4(); ip6 == nil { ips = append(ips, ip)
ips = append(ips, ip)
}
default:
return nil, fmt.Errorf("invalid ip version %d", ipv)
} }
} }
@ -96,16 +99,16 @@ func IfaceDNSIPAddrs(
switch len(addrs) { switch len(addrs) {
case 0: case 0:
// Don't return errors in case the users want to try and enable // Don't return errors in case the users want to try and enable the DHCP
// the DHCP server later. // server later.
t := time.Duration(n) * backoff t := time.Duration(n) * backoff
log.Error("dhcpv%d: no ip for iface after %d attempts and %s", ipv, n, t) log.Error("dhcpv%d: no ip for iface after %d attempts and %s", ipv, n, t)
return nil, nil return nil, nil
case 1: case 1:
// Some Android devices use 8.8.8.8 if there is not a secondary // Some Android devices use 8.8.8.8 if there is not a secondary DNS
// DNS server. Fix that by setting the secondary DNS address to // server. Fix that by setting the secondary DNS address to the same
// the same address. // address.
// //
// See https://github.com/AdguardTeam/AdGuardHome/issues/1708. // See https://github.com/AdguardTeam/AdGuardHome/issues/1708.
log.Debug("dhcpv%d: setting secondary dns ip to itself", ipv) log.Debug("dhcpv%d: setting secondary dns ip to itself", ipv)

View File

@ -5,13 +5,15 @@ import (
"testing" "testing"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// fakeIface is a stub implementation of aghnet.NetIface to simplify testing.
type fakeIface struct { type fakeIface struct {
addrs []net.Addr
err error err error
addrs []net.Addr
} }
// Addrs implements the NetIface interface for *fakeIface. // Addrs implements the NetIface interface for *fakeIface.
@ -33,61 +35,86 @@ func TestIfaceIPAddrs(t *testing.T) {
addr6 := &net.IPNet{IP: ip6} addr6 := &net.IPNet{IP: ip6}
testCases := []struct { testCases := []struct {
name string iface NetIface
iface NetIface name string
ipv IPVersion wantErrMsg string
want []net.IP want []net.IP
wantErr error ipv IPVersion
}{{ }{{
name: "ipv4_success", iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil},
iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil}, name: "ipv4_success",
ipv: IPVersion4, wantErrMsg: "",
want: []net.IP{ip4}, want: []net.IP{ip4},
wantErr: nil, ipv: IPVersion4,
}, { }, {
name: "ipv4_success_with_ipv6", iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil},
iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil}, name: "ipv4_success_with_ipv6",
ipv: IPVersion4, wantErrMsg: "",
want: []net.IP{ip4}, want: []net.IP{ip4},
wantErr: nil, ipv: IPVersion4,
}, { }, {
name: "ipv4_error", iface: &fakeIface{addrs: []net.Addr{addr4}, err: errTest},
iface: &fakeIface{addrs: []net.Addr{addr4}, err: errTest}, name: "ipv4_error",
ipv: IPVersion4, wantErrMsg: errTest.Error(),
want: nil, want: nil,
wantErr: errTest, ipv: IPVersion4,
}, { }, {
name: "ipv6_success", iface: &fakeIface{addrs: []net.Addr{addr6}, err: nil},
iface: &fakeIface{addrs: []net.Addr{addr6}, err: nil}, name: "ipv6_success",
ipv: IPVersion6, wantErrMsg: "",
want: []net.IP{ip6}, want: []net.IP{ip6},
wantErr: nil, ipv: IPVersion6,
}, { }, {
name: "ipv6_success_with_ipv4", iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil},
iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil}, name: "ipv6_success_with_ipv4",
ipv: IPVersion6, wantErrMsg: "",
want: []net.IP{ip6}, want: []net.IP{ip6},
wantErr: nil, ipv: IPVersion6,
}, { }, {
name: "ipv6_error", iface: &fakeIface{addrs: []net.Addr{addr6}, err: errTest},
iface: &fakeIface{addrs: []net.Addr{addr6}, err: errTest}, name: "ipv6_error",
ipv: IPVersion6, wantErrMsg: errTest.Error(),
want: nil, want: nil,
wantErr: errTest, ipv: IPVersion6,
}, {
iface: &fakeIface{addrs: nil, err: nil},
name: "bad_proto",
wantErrMsg: "invalid ip version 10",
want: nil,
ipv: IPVersion6 + IPVersion4,
}, {
iface: &fakeIface{addrs: []net.Addr{&net.IPAddr{IP: ip4}}, err: nil},
name: "ipaddr_v4",
wantErrMsg: "",
want: []net.IP{ip4},
ipv: IPVersion4,
}, {
iface: &fakeIface{addrs: []net.Addr{&net.IPAddr{IP: ip6, Zone: ""}}, err: nil},
name: "ipaddr_v6",
wantErrMsg: "",
want: []net.IP{ip6},
ipv: IPVersion6,
}, {
iface: &fakeIface{addrs: []net.Addr{&net.UnixAddr{}}, err: nil},
name: "non-ipv4",
wantErrMsg: "",
want: nil,
ipv: IPVersion4,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
got, gotErr := IfaceIPAddrs(tc.iface, tc.ipv) got, err := IfaceIPAddrs(tc.iface, tc.ipv)
require.True(t, errors.Is(gotErr, tc.wantErr)) testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got) assert.Equal(t, tc.want, got)
}) })
} }
} }
type waitingFakeIface struct { type waitingFakeIface struct {
addrs []net.Addr
err error err error
addrs []net.Addr
n int n int
} }
@ -116,11 +143,11 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
addr6 := &net.IPNet{IP: ip6} addr6 := &net.IPNet{IP: ip6}
testCases := []struct { testCases := []struct {
name string
iface NetIface iface NetIface
ipv IPVersion
want []net.IP
wantErr error wantErr error
name string
want []net.IP
ipv IPVersion
}{{ }{{
name: "ipv4_success", name: "ipv4_success",
iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil}, iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil},
@ -169,12 +196,25 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
ipv: IPVersion6, ipv: IPVersion6,
want: []net.IP{ip6, ip6}, want: []net.IP{ip6, ip6},
wantErr: nil, wantErr: nil,
}, {
name: "empty",
iface: &fakeIface{addrs: nil, err: nil},
ipv: IPVersion4,
want: nil,
wantErr: nil,
}, {
name: "many",
iface: &fakeIface{addrs: []net.Addr{addr4, addr4}},
ipv: IPVersion4,
want: []net.IP{ip4, ip4},
wantErr: nil,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
got, gotErr := IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0) got, err := IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0)
require.True(t, errors.Is(gotErr, tc.wantErr)) require.ErrorIs(t, err, tc.wantErr)
assert.Equal(t, tc.want, got) assert.Equal(t, tc.want, got)
}) })
} }

View File

@ -0,0 +1,44 @@
package aghnet
import (
"net"
"testing"
"github.com/AdguardTeam/golibs/netutil"
"github.com/stretchr/testify/assert"
)
func TestIPMut(t *testing.T) {
testIPs := []net.IP{{
127, 0, 0, 1,
}, {
192, 168, 0, 1,
}, {
8, 8, 8, 8,
}}
t.Run("nil_no_mut", func(t *testing.T) {
ipmut := NewIPMut(nil)
ips := netutil.CloneIPs(testIPs)
for i := range ips {
ipmut.Load()(ips[i])
assert.True(t, ips[i].Equal(testIPs[i]))
}
})
t.Run("not_nil_mut", func(t *testing.T) {
ipmut := NewIPMut(func(ip net.IP) {
for i := range ip {
ip[i] = 0
}
})
want := netutil.IPv4Zero()
ips := netutil.CloneIPs(testIPs)
for i := range ips {
ipmut.Load()(ips[i])
assert.True(t, ips[i].Equal(want))
}
})
}

View File

@ -42,8 +42,7 @@ func GatewayIP(ifaceName string) net.IP {
fields := strings.Fields(string(d)) fields := strings.Fields(string(d))
// The meaningful "ip route" command output should contain the word // The meaningful "ip route" command output should contain the word
// "default" at first field and default gateway IP address at third // "default" at first field and default gateway IP address at third field.
// field.
if len(fields) < 3 || fields[0] != "default" { if len(fields) < 3 || fields[0] != "default" {
return nil return nil
} }
@ -218,28 +217,6 @@ func IsAddrInUse(err error) (ok bool) {
return isAddrInUse(sysErr) return isAddrInUse(sysErr)
} }
// SplitHost is a wrapper for net.SplitHostPort for the cases when the hostport
// does not necessarily contain a port.
func SplitHost(hostport string) (host string, err error) {
host, _, err = net.SplitHostPort(hostport)
if err != nil {
// Check for the missing port error. If it is that error, just
// use the host as is.
//
// See the source code for net.SplitHostPort.
const missingPort = "missing port in address"
addrErr := &net.AddrError{}
if !errors.As(err, &addrErr) || addrErr.Err != missingPort {
return "", err
}
host = hostport
}
return host, nil
}
// CollectAllIfacesAddrs returns the slice of all network interfaces IP // CollectAllIfacesAddrs returns the slice of all network interfaces IP
// addresses without port number. // addresses without port number.
func CollectAllIfacesAddrs() (addrs []string, err error) { func CollectAllIfacesAddrs() (addrs []string, err error) {

View File

@ -15,12 +15,20 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
func TestGetValidNetInterfacesForWeb(t *testing.T) { func TestGetInterfaceByIP(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb() ifaces, err := GetValidNetInterfacesForWeb()
require.NoErrorf(t, err, "cannot get net interfaces: %s", err) require.NoError(t, err)
require.NotEmpty(t, ifaces, "no net interfaces found") require.NotEmpty(t, ifaces)
for _, iface := range ifaces { for _, iface := range ifaces {
require.NotEmptyf(t, iface.Addresses, "no addresses found for %s", iface.Name) t.Run(iface.Name, func(t *testing.T) {
require.NotEmpty(t, iface.Addresses)
for _, ip := range iface.Addresses {
ifaceName := GetInterfaceByIP(ip)
require.Equal(t, iface.Name, ifaceName)
}
})
} }
} }
@ -73,18 +81,47 @@ func TestBroadcastFromIPNet(t *testing.T) {
} }
func TestCheckPort(t *testing.T) { func TestCheckPort(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:") t.Run("tcp_bound", func(t *testing.T) {
require.NoError(t, err) l, err := net.Listen("tcp", "127.0.0.1:")
testutil.CleanupAndRequireSuccess(t, l.Close) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
ipp := netutil.IPPortFromAddr(l.Addr()) ipp := netutil.IPPortFromAddr(l.Addr())
require.NotNil(t, ipp) require.NotNil(t, ipp)
require.NotNil(t, ipp.IP) require.NotNil(t, ipp.IP)
require.NotZero(t, ipp.Port) require.NotZero(t, ipp.Port)
err = CheckPort("tcp", ipp.IP, ipp.Port) err = CheckPort("tcp", ipp.IP, ipp.Port)
target := &net.OpError{} target := &net.OpError{}
require.ErrorAs(t, err, &target) require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op) assert.Equal(t, "listen", target.Op)
})
t.Run("udp_bound", func(t *testing.T) {
conn, err := net.ListenPacket("udp", "127.0.0.1:")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, conn.Close)
ipp := netutil.IPPortFromAddr(conn.LocalAddr())
require.NotNil(t, ipp)
require.NotNil(t, ipp.IP)
require.NotZero(t, ipp.Port)
err = CheckPort("udp", ipp.IP, ipp.Port)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("bad_network", func(t *testing.T) {
err := CheckPort("bad_network", nil, 0)
assert.NoError(t, err)
})
t.Run("can_bind", func(t *testing.T) {
err := CheckPort("udp", net.IP{0, 0, 0, 0}, 0)
assert.NoError(t, err)
})
} }