Pull request: 2508 ip conversion vol.1

Merge in DNS/adguard-home from 2508-ip-conversion to master

Updates #2508.

Squashed commit of the following:

commit 3f64709fbc73ef74c11b910997be1e9bc337193c
Merge: 5ac7faaaa 0d67aa251
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Jan 13 16:21:34 2021 +0300

    Merge branch 'master' into 2508-ip-conversion

commit 5ac7faaaa9dda570fdb872acad5d13d078f46b64
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Jan 13 12:00:11 2021 +0300

    all: replace conditions with appropriate functions in tests

commit 9e3fa9a115ed23024c57dd5192d5173477ddbf71
Merge: db992a42a bba74859e
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Jan 13 10:47:10 2021 +0300

    Merge branch 'master' into 2508-ip-conversion

commit db992a42a2c6f315421e78a6a0492e2bfb3ce89d
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 12 18:55:53 2021 +0300

    sysutil: fix linux tests

commit f629b15d62349323ce2da05e68dc9cc0b5f6e194
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 12 18:41:20 2021 +0300

    all: improve code quality

commit 3bf03a75524040738562298bd1de6db536af130f
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 12 17:33:26 2021 +0300

    sysutil: fix linux net.IP conversion

commit 5d5b6994916923636e635588631b63b7e7b74e5f
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 12 14:57:26 2021 +0300

    dnsforward: remove redundant net.IP <-> string conversion

commit 0b955d99b7fad40942f21d1dd8734adb99126195
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Jan 11 18:04:25 2021 +0300

    dhcpd: remove net.IP <-> string conversion
This commit is contained in:
Eugene Burkov 2021-01-13 16:56:05 +03:00
parent 0d67aa251d
commit e8c1f5c8d3
39 changed files with 409 additions and 435 deletions

View File

@ -28,27 +28,27 @@ func TestDB(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: testNotify,
}
s.srv4, err = v4Create(conf)
assert.True(t, err == nil)
assert.Nil(t, err)
s.srv6, err = v6Create(V6ServerConf{})
assert.True(t, err == nil)
assert.Nil(t, err)
l := Lease{}
l.IP = net.ParseIP("192.168.10.100").To4()
l.IP = net.IP{192, 168, 10, 100}
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
exp1 := time.Now().Add(time.Hour)
l.Expiry = exp1
s.srv4.(*v4Server).addLease(&l)
l2 := Lease{}
l2.IP = net.ParseIP("192.168.10.101").To4()
l2.IP = net.IP{192, 168, 10, 101}
l2.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:bb")
s.srv4.AddStaticLease(l2)
@ -62,7 +62,7 @@ func TestDB(t *testing.T) {
assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String())
assert.Equal(t, "192.168.10.101", ll[0].IP.String())
assert.Equal(t, int64(leaseExpireStatic), ll[0].Expiry.Unix())
assert.EqualValues(t, leaseExpireStatic, ll[0].Expiry.Unix())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String())
assert.Equal(t, "192.168.10.100", ll[1].IP.String())
@ -75,8 +75,8 @@ func TestIsValidSubnetMask(t *testing.T) {
assert.True(t, isValidSubnetMask([]byte{255, 255, 255, 0}))
assert.True(t, isValidSubnetMask([]byte{255, 255, 254, 0}))
assert.True(t, isValidSubnetMask([]byte{255, 255, 252, 0}))
assert.True(t, !isValidSubnetMask([]byte{255, 255, 253, 0}))
assert.True(t, !isValidSubnetMask([]byte{255, 255, 255, 1}))
assert.False(t, isValidSubnetMask([]byte{255, 255, 253, 0}))
assert.False(t, isValidSubnetMask([]byte{255, 255, 255, 1}))
}
func TestNormalizeLeases(t *testing.T) {
@ -100,7 +100,7 @@ func TestNormalizeLeases(t *testing.T) {
leases := normalizeLeases(staticLeases, dynLeases)
assert.True(t, len(leases) == 3)
assert.Len(t, leases, 3)
assert.True(t, bytes.Equal(leases[0].HWAddr, []byte{1, 2, 3, 4}))
assert.True(t, bytes.Equal(leases[0].IP, []byte{0, 2, 3, 4}))
assert.True(t, bytes.Equal(leases[1].HWAddr, []byte{2, 2, 3, 4}))
@ -109,22 +109,22 @@ func TestNormalizeLeases(t *testing.T) {
func TestOptions(t *testing.T) {
code, val := parseOptionString(" 12 hex abcdef ")
assert.Equal(t, uint8(12), code)
assert.EqualValues(t, 12, code)
assert.True(t, bytes.Equal([]byte{0xab, 0xcd, 0xef}, val))
code, _ = parseOptionString(" 12 hex abcdef1 ")
assert.Equal(t, uint8(0), code)
assert.EqualValues(t, 0, code)
code, val = parseOptionString("123 ip 1.2.3.4")
assert.Equal(t, uint8(123), code)
assert.EqualValues(t, 123, code)
assert.Equal(t, "1.2.3.4", net.IP(string(val)).String())
code, _ = parseOptionString("256 ip 1.1.1.1")
assert.Equal(t, uint8(0), code)
assert.EqualValues(t, 0, code)
code, _ = parseOptionString("-1 ip 1.1.1.1")
assert.Equal(t, uint8(0), code)
assert.EqualValues(t, 0, code)
code, _ = parseOptionString("12 ip 1.1.1.1x")
assert.Equal(t, uint8(0), code)
assert.EqualValues(t, 0, code)
code, _ = parseOptionString("12 x 1.1.1.1")
assert.Equal(t, uint8(0), code)
assert.EqualValues(t, 0, code)
}

View File

@ -42,10 +42,10 @@ func convertLeases(inputLeases []Lease, includeExpires bool) []map[string]string
}
type v4ServerConfJSON struct {
GatewayIP string `json:"gateway_ip"`
SubnetMask string `json:"subnet_mask"`
RangeStart string `json:"range_start"`
RangeEnd string `json:"range_end"`
GatewayIP net.IP `json:"gateway_ip"`
SubnetMask net.IP `json:"subnet_mask"`
RangeStart net.IP `json:"range_start"`
RangeEnd net.IP `json:"range_end"`
LeaseDuration uint32 `json:"lease_duration"`
}
@ -61,10 +61,10 @@ func v4ServerConfToJSON(c V4ServerConf) v4ServerConfJSON {
func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf {
return V4ServerConf{
GatewayIP: j.GatewayIP,
SubnetMask: j.SubnetMask,
RangeStart: j.RangeStart,
RangeEnd: j.RangeEnd,
GatewayIP: j.GatewayIP.To4(),
SubnetMask: j.SubnetMask.To4(),
RangeStart: j.RangeStart.To4(),
RangeEnd: j.RangeEnd.To4(),
LeaseDuration: j.LeaseDuration,
}
}
@ -117,7 +117,7 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
type staticLeaseJSON struct {
HWAddr string `json:"mac"`
IP string `json:"ip"`
IP net.IP `json:"ip"`
Hostname string `json:"hostname"`
}
@ -225,10 +225,10 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
type netInterfaceJSON struct {
Name string `json:"name"`
GatewayIP string `json:"gateway_ip"`
GatewayIP net.IP `json:"gateway_ip"`
HardwareAddr string `json:"hardware_address"`
Addrs4 []string `json:"ipv4_addresses"`
Addrs6 []string `json:"ipv6_addresses"`
Addrs4 []net.IP `json:"ipv4_addresses"`
Addrs6 []net.IP `json:"ipv6_addresses"`
Flags string `json:"flags"`
}
@ -277,9 +277,9 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
continue
}
if ipnet.IP.To4() != nil {
jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP.String())
jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP)
} else {
jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP.String())
jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP)
}
}
if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 {
@ -375,50 +375,46 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
err := json.NewDecoder(r.Body).Decode(&lj)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
ip := net.ParseIP(lj.IP)
if ip != nil && ip.To4() == nil {
mac, err := net.ParseMAC(lj.HWAddr)
if lj.IP == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP")
return
}
ip4 := lj.IP.To4()
mac, err := net.ParseMAC(lj.HWAddr)
lease := Lease{
HWAddr: mac,
}
if ip4 == nil {
lease.IP = lj.IP.To16()
if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid MAC")
return
}
lease := Lease{
IP: ip,
HWAddr: mac,
return
}
err = s.srv6.AddStaticLease(lease)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
return
}
ip, _ = parseIPv4(lj.IP)
if ip == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP")
return
}
mac, err := net.ParseMAC(lj.HWAddr)
if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid MAC")
return
}
lease := Lease{
IP: ip,
HWAddr: mac,
Hostname: lj.Hostname,
}
lease.IP = ip4
lease.Hostname = lj.Hostname
err = s.srv4.AddStaticLease(lease)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
}
@ -428,46 +424,46 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
err := json.NewDecoder(r.Body).Decode(&lj)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
ip := net.ParseIP(lj.IP)
if ip != nil && ip.To4() == nil {
mac, err := net.ParseMAC(lj.HWAddr)
if lj.IP == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP")
return
}
ip4 := lj.IP.To4()
mac, err := net.ParseMAC(lj.HWAddr)
lease := Lease{
HWAddr: mac,
}
if ip4 == nil {
lease.IP = lj.IP.To16()
if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid MAC")
return
}
lease := Lease{
IP: ip,
HWAddr: mac,
return
}
err = s.srv6.RemoveStaticLease(lease)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
return
}
ip, _ = parseIPv4(lj.IP)
if ip == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP")
return
}
mac, _ := net.ParseMAC(lj.HWAddr)
lease := Lease{
IP: ip,
HWAddr: mac,
Hostname: lj.Hostname,
}
lease.IP = ip4
lease.Hostname = lj.Hostname
err = s.srv4.RemoveStaticLease(lease)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
}

View File

@ -14,15 +14,17 @@ func isTimeout(err error) bool {
return operr.Timeout()
}
func parseIPv4(text string) (net.IP, error) {
result := net.ParseIP(text)
if result == nil {
return nil, fmt.Errorf("%s is not an IP address", text)
func tryTo4(ip net.IP) (ip4 net.IP, err error) {
if ip == nil {
return nil, fmt.Errorf("%v is not an IP address", ip)
}
if result.To4() == nil {
return nil, fmt.Errorf("%s is not an IPv4 address", text)
ip4 = ip.To4()
if ip4 == nil {
return nil, fmt.Errorf("%v is not an IPv4 address", ip)
}
return result.To4(), nil
return ip4, nil
}
// Return TRUE if subnet mask is correct (e.g. 255.255.255.0)

View File

@ -36,13 +36,13 @@ type V4ServerConf struct {
Enabled bool `yaml:"-"`
InterfaceName string `yaml:"-"`
GatewayIP string `yaml:"gateway_ip"`
SubnetMask string `yaml:"subnet_mask"`
GatewayIP net.IP `yaml:"gateway_ip"`
SubnetMask net.IP `yaml:"subnet_mask"`
// The first & the last IP address for dynamic leases
// Bytes [0..2] of the last allowed IP address must match the first IP
RangeStart string `yaml:"range_start"`
RangeEnd string `yaml:"range_end"`
RangeStart net.IP `yaml:"range_start"`
RangeEnd net.IP `yaml:"range_end"`
LeaseDuration uint32 `yaml:"lease_duration"` // in seconds

View File

@ -589,7 +589,7 @@ func (s *v4Server) Start() error {
s.conf.dnsIPAddrs = dnsIPAddrs
laddr := &net.UDPAddr{
IP: net.ParseIP("0.0.0.0"),
IP: net.IP{0, 0, 0, 0},
Port: dhcpv4.ServerPort,
}
s.srv, err = server4.NewServer(iface.Name, laddr, s.packetHandler, server4.WithDebugLogger())
@ -632,19 +632,18 @@ func v4Create(conf V4ServerConf) (DHCPServer, error) {
}
var err error
s.conf.routerIP, err = parseIPv4(s.conf.GatewayIP)
s.conf.routerIP, err = tryTo4(s.conf.GatewayIP)
if err != nil {
return s, fmt.Errorf("dhcpv4: %w", err)
}
subnet, err := parseIPv4(s.conf.SubnetMask)
if err != nil || !isValidSubnetMask(subnet) {
return s, fmt.Errorf("dhcpv4: invalid subnet mask: %s", s.conf.SubnetMask)
if s.conf.SubnetMask == nil {
return s, fmt.Errorf("dhcpv4: invalid subnet mask: %v", s.conf.SubnetMask)
}
s.conf.subnetMask = make([]byte, 4)
copy(s.conf.subnetMask, subnet)
copy(s.conf.subnetMask, s.conf.SubnetMask.To4())
s.conf.ipStart, err = parseIPv4(conf.RangeStart)
s.conf.ipStart, err = tryTo4(conf.RangeStart)
if s.conf.ipStart == nil {
return s, fmt.Errorf("dhcpv4: %w", err)
}
@ -652,7 +651,7 @@ func v4Create(conf V4ServerConf) (DHCPServer, error) {
return s, fmt.Errorf("dhcpv4: invalid range start IP")
}
s.conf.ipEnd, err = parseIPv4(conf.RangeEnd)
s.conf.ipEnd, err = tryTo4(conf.RangeEnd)
if s.conf.ipEnd == nil {
return s, fmt.Errorf("dhcpv4: %w", err)
}

View File

@ -16,119 +16,119 @@ func notify4(flags uint32) {
func TestV4StaticLeaseAddRemove(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
}
s, err := v4Create(conf)
assert.True(t, err == nil)
assert.Nil(t, err)
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls))
assert.Empty(t, ls)
// add static lease
l := Lease{}
l.IP = net.ParseIP("192.168.10.150").To4()
l.IP = net.IP{192, 168, 10, 150}
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
assert.Nil(t, s.AddStaticLease(l))
// try to add the same static lease - fail
assert.True(t, s.AddStaticLease(l) != nil)
assert.NotNil(t, s.AddStaticLease(l))
// check
ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls))
assert.Len(t, ls, 1)
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
// try to remove static lease - fail
l.IP = net.ParseIP("192.168.10.110").To4()
l.IP = net.IP{192, 168, 10, 110}
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) != nil)
assert.NotNil(t, s.RemoveStaticLease(l))
// remove static lease
l.IP = net.ParseIP("192.168.10.150").To4()
l.IP = net.IP{192, 168, 10, 150}
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) == nil)
assert.Nil(t, s.RemoveStaticLease(l))
// check
ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls))
assert.Empty(t, ls)
}
func TestV4StaticLeaseAddReplaceDynamic(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
}
sIface, err := v4Create(conf)
s := sIface.(*v4Server)
assert.True(t, err == nil)
assert.Nil(t, err)
// add dynamic lease
ld := Lease{}
ld.IP = net.ParseIP("192.168.10.150").To4()
ld.IP = net.IP{192, 168, 10, 150}
ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa")
s.addLease(&ld)
// add dynamic lease
{
ld := Lease{}
ld.IP = net.ParseIP("192.168.10.151").To4()
ld.IP = net.IP{192, 168, 10, 151}
ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
s.addLease(&ld)
}
// add static lease with the same IP
l := Lease{}
l.IP = net.ParseIP("192.168.10.150").To4()
l.IP = net.IP{192, 168, 10, 150}
l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
assert.Nil(t, s.AddStaticLease(l))
// add static lease with the same MAC
l = Lease{}
l.IP = net.ParseIP("192.168.10.152").To4()
l.IP = net.IP{192, 168, 10, 152}
l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
assert.Nil(t, s.AddStaticLease(l))
// check
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 2, len(ls))
assert.Len(t, ls, 2)
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
assert.Equal(t, "192.168.10.152", ls[1].IP.String())
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic)
assert.EqualValues(t, leaseExpireStatic, ls[1].Expiry.Unix())
}
func TestV4StaticLeaseGet(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
}
sIface, err := v4Create(conf)
s := sIface.(*v4Server)
assert.True(t, err == nil)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()}
assert.Nil(t, err)
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
l := Lease{}
l.IP = net.ParseIP("192.168.10.150").To4()
l.IP = net.IP{192, 168, 10, 150}
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
assert.Nil(t, s.AddStaticLease(l))
// "Discover"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
@ -160,12 +160,12 @@ func TestV4StaticLeaseGet(t *testing.T) {
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
dnsAddrs := resp.DNS()
assert.Equal(t, 1, len(dnsAddrs))
assert.Len(t, dnsAddrs, 1)
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
// check lease
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls))
assert.Len(t, ls, 1)
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
}
@ -173,10 +173,10 @@ func TestV4StaticLeaseGet(t *testing.T) {
func TestV4DynamicLeaseGet(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
Options: []string{
"81 hex 303132",
@ -185,8 +185,8 @@ func TestV4DynamicLeaseGet(t *testing.T) {
}
sIface, err := v4Create(conf)
s := sIface.(*v4Server)
assert.True(t, err == nil)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()}
assert.Nil(t, err)
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
// "Discover"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
@ -220,19 +220,19 @@ func TestV4DynamicLeaseGet(t *testing.T) {
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
dnsAddrs := resp.DNS()
assert.Equal(t, 1, len(dnsAddrs))
assert.Len(t, dnsAddrs, 1)
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
// check lease
ls := s.GetLeases(LeasesDynamic)
assert.Equal(t, 1, len(ls))
assert.Len(t, ls, 1)
assert.Equal(t, "192.168.10.100", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
start := net.ParseIP("192.168.10.100").To4()
stop := net.ParseIP("192.168.10.200").To4()
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.10.99").To4()))
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.100").To4()))
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.201").To4()))
assert.True(t, ip4InRange(start, stop, net.ParseIP("192.168.10.100").To4()))
start := net.IP{192, 168, 10, 100}
stop := net.IP{192, 168, 10, 200}
assert.False(t, ip4InRange(start, stop, net.IP{192, 168, 10, 99}))
assert.False(t, ip4InRange(start, stop, net.IP{192, 168, 11, 100}))
assert.False(t, ip4InRange(start, stop, net.IP{192, 168, 11, 201}))
assert.True(t, ip4InRange(start, stop, net.IP{192, 168, 10, 100}))
}

View File

@ -21,40 +21,40 @@ func TestV6StaticLeaseAddRemove(t *testing.T) {
notify: notify6,
}
s, err := v6Create(conf)
assert.True(t, err == nil)
assert.Nil(t, err)
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls))
assert.Empty(t, ls)
// add static lease
l := Lease{}
l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
assert.Nil(t, s.AddStaticLease(l))
// try to add static lease - fail
assert.True(t, s.AddStaticLease(l) != nil)
assert.NotNil(t, s.AddStaticLease(l))
// check
ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls))
assert.Len(t, ls, 1)
assert.Equal(t, "2001::1", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
// try to remove static lease - fail
l.IP = net.ParseIP("2001::2")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) != nil)
assert.NotNil(t, s.RemoveStaticLease(l))
// remove static lease
l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) == nil)
assert.Nil(t, s.RemoveStaticLease(l))
// check
ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls))
assert.Empty(t, ls)
}
func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
@ -65,7 +65,7 @@ func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
}
sIface, err := v6Create(conf)
s := sIface.(*v6Server)
assert.True(t, err == nil)
assert.Nil(t, err)
// add dynamic lease
ld := Lease{}
@ -85,25 +85,25 @@ func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
l := Lease{}
l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
assert.Nil(t, s.AddStaticLease(l))
// add static lease with the same MAC
l = Lease{}
l.IP = net.ParseIP("2001::3")
l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
assert.Nil(t, s.AddStaticLease(l))
// check
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 2, len(ls))
assert.Len(t, ls, 2)
assert.Equal(t, "2001::1", ls[0].IP.String())
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
assert.Equal(t, "2001::3", ls[1].IP.String())
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic)
assert.EqualValues(t, leaseExpireStatic, ls[1].Expiry.Unix())
}
func TestV6GetLease(t *testing.T) {
@ -114,7 +114,7 @@ func TestV6GetLease(t *testing.T) {
}
sIface, err := v6Create(conf)
s := sIface.(*v6Server)
assert.True(t, err == nil)
assert.Nil(t, err)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")}
s.sid = dhcpv6.Duid{
Type: dhcpv6.DUID_LLT,
@ -125,7 +125,7 @@ func TestV6GetLease(t *testing.T) {
l := Lease{}
l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
assert.Nil(t, s.AddStaticLease(l))
// "Solicit"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
@ -156,12 +156,12 @@ func TestV6GetLease(t *testing.T) {
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
dnsAddrs := resp.Options.DNS()
assert.Equal(t, 1, len(dnsAddrs))
assert.Len(t, dnsAddrs, 1)
assert.Equal(t, "2000::1", dnsAddrs[0].String())
// check lease
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls))
assert.Len(t, ls, 1)
assert.Equal(t, "2001::1", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
}
@ -174,7 +174,7 @@ func TestV6GetDynamicLease(t *testing.T) {
}
sIface, err := v6Create(conf)
s := sIface.(*v6Server)
assert.True(t, err == nil)
assert.Nil(t, err)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")}
s.sid = dhcpv6.Duid{
Type: dhcpv6.DUID_LLT,
@ -209,17 +209,17 @@ func TestV6GetDynamicLease(t *testing.T) {
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
dnsAddrs := resp.Options.DNS()
assert.Equal(t, 1, len(dnsAddrs))
assert.Len(t, dnsAddrs, 1)
assert.Equal(t, "2000::1", dnsAddrs[0].String())
// check lease
ls := s.GetLeases(LeasesDynamic)
assert.Equal(t, 1, len(ls))
assert.Len(t, ls, 1)
assert.Equal(t, "2001::2", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::1")))
assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2002::2")))
assert.False(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::1")))
assert.False(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2002::2")))
assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::2")))
assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::3")))
}

View File

@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
"net"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
@ -135,7 +134,7 @@ func TestEtcHostsMatching(t *testing.T) {
assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text)
assert.Len(t, res.Rules[0].IP, 0)
assert.Empty(t, res.Rules[0].IP)
}
// IPv6
@ -147,7 +146,7 @@ func TestEtcHostsMatching(t *testing.T) {
assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) {
assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text)
assert.Len(t, res.Rules[0].IP, 0)
assert.Empty(t, res.Rules[0].IP)
}
// 2 IPv4 (return only the first one)
@ -180,7 +179,7 @@ func TestSafeBrowsing(t *testing.T) {
defer d.Close()
d.checkMatch(t, "wmconvirus.narod.ru")
assert.True(t, strings.Contains(logOutput.String(), "SafeBrowsing lookup for wmconvirus.narod.ru"))
assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for wmconvirus.narod.ru")
d.checkMatch(t, "test.wmconvirus.narod.ru")
d.checkMatchEmpty(t, "yandex.ru")
@ -268,7 +267,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
res, err := d.CheckHost(domain, dns.TypeA, &setts)
assert.Nil(t, err)
assert.False(t, res.IsFiltered)
assert.Len(t, res.Rules, 0)
assert.Empty(t, res.Rules)
d = NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Close()
@ -298,7 +297,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
res, err := d.CheckHost(domain, dns.TypeA, &setts)
assert.Nil(t, err)
assert.False(t, res.IsFiltered)
assert.Len(t, res.Rules, 0)
assert.Empty(t, res.Rules)
d = NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Close()
@ -346,7 +345,7 @@ func TestParentalControl(t *testing.T) {
d := NewForTest(&Config{ParentalEnabled: true}, nil)
defer d.Close()
d.checkMatch(t, "pornhub.com")
assert.True(t, strings.Contains(logOutput.String(), "Parental lookup for pornhub.com"))
assert.Contains(t, logOutput.String(), "Parental lookup for pornhub.com")
d.checkMatch(t, "www.pornhub.com")
d.checkMatchEmpty(t, "www.yandex.ru")
d.checkMatchEmpty(t, "yandex.ru")
@ -468,18 +467,20 @@ func TestWhitelist(t *testing.T) {
// matched by white filter
res, err := d.CheckHost("host1", dns.TypeA, &setts)
assert.True(t, err == nil)
assert.True(t, !res.IsFiltered && res.Reason == NotFilteredAllowList)
assert.Nil(t, err)
assert.False(t, res.IsFiltered)
assert.Equal(t, res.Reason, NotFilteredAllowList)
if assert.Len(t, res.Rules, 1) {
assert.True(t, res.Rules[0].Text == "||host1^")
assert.Equal(t, "||host1^", res.Rules[0].Text)
}
// not matched by white filter, but matched by block filter
res, err = d.CheckHost("host2", dns.TypeA, &setts)
assert.True(t, err == nil)
assert.True(t, res.IsFiltered && res.Reason == FilteredBlockList)
assert.Nil(t, err)
assert.True(t, res.IsFiltered)
assert.Equal(t, res.Reason, FilteredBlockList)
if assert.Len(t, res.Rules, 1) {
assert.True(t, res.Rules[0].Text == "||host2^")
assert.Equal(t, "||host2^", res.Rules[0].Text)
}
}
@ -529,7 +530,7 @@ func TestClientSettings(t *testing.T) {
// not blocked
r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts)
assert.True(t, !r.IsFiltered)
assert.False(t, r.IsFiltered)
// override client settings:
applyClientSettings(&setts)
@ -554,7 +555,8 @@ func TestClientSettings(t *testing.T) {
// blocked by additional rules
r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts)
assert.True(t, r.IsFiltered && r.Reason == FilteredBlockedService)
assert.True(t, r.IsFiltered)
assert.Equal(t, r.Reason, FilteredBlockedService)
}
// BENCHMARKS

View File

@ -171,7 +171,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
assert.Equal(t, "", res.CanonName)
assert.Empty(t, res.CanonName)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
@ -197,7 +197,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
assert.Equal(t, "", res.CanonName)
assert.Len(t, res.Rules, 0)
assert.Empty(t, res.CanonName)
assert.Empty(t, res.Rules)
})
}

View File

@ -27,14 +27,14 @@ func TestRewrites(t *testing.T) {
r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.Equal(t, 2, len(r.IPList))
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5")))
assert.Len(t, r.IPList, 2)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
assert.True(t, r.IPList[1].Equal(net.IP{1, 2, 3, 5}))
r = d.processRewrites("www.host.com", dns.TypeAAAA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.Equal(t, 1, len(r.IPList))
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4")))
// wildcard
@ -45,11 +45,11 @@ func TestRewrites(t *testing.T) {
d.prepareRewrites()
r = d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5")))
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 5}))
r = d.processRewrites("www.host2.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
@ -62,8 +62,8 @@ func TestRewrites(t *testing.T) {
d.prepareRewrites()
r = d.processRewrites("a.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// wildcard + CNAME
d.Rewrites = []RewriteEntry{
@ -74,7 +74,7 @@ func TestRewrites(t *testing.T) {
r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// 2 CNAMEs
d.Rewrites = []RewriteEntry{
@ -86,8 +86,8 @@ func TestRewrites(t *testing.T) {
r = d.processRewrites("b.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// 2 CNAMEs + wildcard
d.Rewrites = []RewriteEntry{
@ -99,8 +99,8 @@ func TestRewrites(t *testing.T) {
r = d.processRewrites("b.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "x.somehost.com", r.CanonName)
assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
}
func TestRewritesLevels(t *testing.T) {
@ -116,19 +116,19 @@ func TestRewritesLevels(t *testing.T) {
// match exact
r := d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Len(t, r.IPList, 1)
assert.Equal(t, "1.1.1.1", r.IPList[0].String())
// match L2
r = d.processRewrites("sub.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Len(t, r.IPList, 1)
assert.Equal(t, "2.2.2.2", r.IPList[0].String())
// match L3
r = d.processRewrites("my.sub.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Len(t, r.IPList, 1)
assert.Equal(t, "3.3.3.3", r.IPList[0].String())
}
@ -144,7 +144,7 @@ func TestRewritesExceptionCNAME(t *testing.T) {
// match sub-domain
r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Len(t, r.IPList, 1)
assert.Equal(t, "2.2.2.2", r.IPList[0].String())
// match sub-domain, but handle exception
@ -164,7 +164,7 @@ func TestRewritesExceptionWC(t *testing.T) {
// match sub-domain
r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Len(t, r.IPList, 1)
assert.Equal(t, "2.2.2.2", r.IPList[0].String())
// match sub-domain, but handle exception
@ -187,7 +187,7 @@ func TestRewritesExceptionIP(t *testing.T) {
// match domain
r := d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Len(t, r.IPList, 1)
assert.Equal(t, "1.2.3.4", r.IPList[0].String())
// match exception
@ -201,7 +201,7 @@ func TestRewritesExceptionIP(t *testing.T) {
// match domain
r = d.processRewrites("host2.com", dns.TypeAAAA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Len(t, r.IPList, 1)
assert.Equal(t, "::1", r.IPList[0].String())
// match exception
@ -211,5 +211,5 @@ func TestRewritesExceptionIP(t *testing.T) {
// match domain
r = d.processRewrites("host3.com", dns.TypeAAAA)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 0, len(r.IPList))
assert.Empty(t, r.IPList)
}

View File

@ -37,8 +37,8 @@ func (d *DNSFilter) initSecurityServices() error {
opts := upstream.Options{
Timeout: dnsTimeout,
ServerIPAddrs: []net.IP{
net.ParseIP("94.140.14.15"),
net.ParseIP("94.140.15.16"),
{94, 140, 14, 15},
{94, 140, 15, 16},
net.ParseIP("2a10:50c0::bad1:ff"),
net.ParseIP("2a10:50c0::bad2:ff"),
},

View File

@ -14,7 +14,7 @@ import (
func TestSafeBrowsingHash(t *testing.T) {
// test hostnameToHashes()
hashes := hostnameToHashes("1.2.3.sub.host.com")
assert.Equal(t, 3, len(hashes))
assert.Len(t, hashes, 3)
_, ok := hashes[sha256.Sum256([]byte("3.sub.host.com"))]
assert.True(t, ok)
_, ok = hashes[sha256.Sum256([]byte("sub.host.com"))]
@ -31,9 +31,9 @@ func TestSafeBrowsingHash(t *testing.T) {
q := c.getQuestion()
assert.True(t, strings.Contains(q, "7a1b."))
assert.True(t, strings.Contains(q, "af5a."))
assert.True(t, strings.Contains(q, "eb11."))
assert.Contains(t, q, "7a1b.")
assert.Contains(t, q, "af5a.")
assert.Contains(t, q, "eb11.")
assert.True(t, strings.HasSuffix(q, "sb.dns.adguard.com."))
}
@ -81,7 +81,7 @@ func TestSafeBrowsingCache(t *testing.T) {
c.hashToHost[hash] = "sub.host.com"
hash = sha256.Sum256([]byte("nonexisting.com"))
c.hashToHost[hash] = "nonexisting.com"
assert.Equal(t, 0, c.getCached())
assert.Empty(t, c.getCached())
hash = sha256.Sum256([]byte("sub.host.com"))
_, ok := c.hashToHost[hash]
@ -103,7 +103,7 @@ func TestSafeBrowsingCache(t *testing.T) {
c.hashToHost[hash] = "sub.host.com"
c.cache.Set(hash[0:2], make([]byte, 32))
assert.Equal(t, 0, c.getCached())
assert.Empty(t, c.getCached())
}
// testErrUpstream implements upstream.Upstream interface for replacing real

View File

@ -8,28 +8,28 @@ import (
func TestIsBlockedIPAllowed(t *testing.T) {
a := &accessCtx{}
assert.True(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil) == nil)
assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil))
disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1")
assert.False(t, disallowed)
assert.Equal(t, "", disallowedRule)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2")
assert.True(t, disallowed)
assert.Equal(t, "", disallowedRule)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1")
assert.False(t, disallowed)
assert.Equal(t, "", disallowedRule)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1")
assert.True(t, disallowed)
assert.Equal(t, "", disallowedRule)
assert.Empty(t, disallowedRule)
}
func TestIsBlockedIPDisallowed(t *testing.T) {
a := &accessCtx{}
assert.True(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil) == nil)
assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil))
disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1")
assert.True(t, disallowed)
@ -37,7 +37,7 @@ func TestIsBlockedIPDisallowed(t *testing.T) {
disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2")
assert.False(t, disallowed)
assert.Equal(t, "", disallowedRule)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1")
assert.True(t, disallowed)
@ -45,7 +45,7 @@ func TestIsBlockedIPDisallowed(t *testing.T) {
disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1")
assert.False(t, disallowed)
assert.Equal(t, "", disallowedRule)
assert.Empty(t, disallowedRule)
}
func TestIsBlockedIPBlockedDomain(t *testing.T) {
@ -60,13 +60,13 @@ func TestIsBlockedIPBlockedDomain(t *testing.T) {
// match by "host2.com"
assert.True(t, a.IsBlockedDomain("host1"))
assert.True(t, a.IsBlockedDomain("host2"))
assert.True(t, !a.IsBlockedDomain("host3"))
assert.False(t, a.IsBlockedDomain("host3"))
// match by wildcard "*.host.com"
assert.True(t, !a.IsBlockedDomain("host.com"))
assert.False(t, a.IsBlockedDomain("host.com"))
assert.True(t, a.IsBlockedDomain("asdf.host.com"))
assert.True(t, a.IsBlockedDomain("qwer.asdf.host.com"))
assert.True(t, !a.IsBlockedDomain("asdf.zhost.com"))
assert.False(t, a.IsBlockedDomain("asdf.zhost.com"))
// match by wildcard "||host3.com^"
assert.True(t, a.IsBlockedDomain("host3.com"))

View File

@ -29,17 +29,16 @@ type FilteringConfig struct {
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
// based on the client IP address. Returns nil if there are no custom upstreams for the client
// TODO(e.burkov): replace argument type with net.IP.
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
// Protection configuration
// --
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests
BlockingIPv4 string `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request
BlockingIPv6 string `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request
BlockingIPAddrv4 net.IP `yaml:"-"`
BlockingIPAddrv6 net.IP `yaml:"-"`
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests
BlockingIPv4 net.IP `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request
BlockingIPv6 net.IP `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
// IP (or domain name) which is used to respond to DNS requests blocked by parental control or safe-browsing

View File

@ -182,7 +182,7 @@ func processInternalHosts(ctx *dnsContext) int {
return resultDone
}
log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip.String())
log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip)
resp := s.makeResponse(req)
@ -278,7 +278,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) int {
return resultDone
}
// Pass request to upstream servers; process the response
// processUpstream passes request to upstream servers and handles the response.
func processUpstream(ctx *dnsContext) int {
s := ctx.srv
d := ctx.proxyCtx
@ -287,7 +287,7 @@ func processUpstream(ctx *dnsContext) int {
}
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
clientIP := ipFromAddr(d.Addr)
clientIP := IPStringFromAddr(d.Addr)
upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP)
if upstreamsConf != nil {
log.Debug("Using custom upstreams for %s", clientIP)

View File

@ -178,9 +178,7 @@ func (s *Server) Prepare(config *ServerConfig) error {
if config != nil {
s.conf = *config
if s.conf.BlockingMode == "custom_ip" {
s.conf.BlockingIPAddrv4 = net.ParseIP(s.conf.BlockingIPv4)
s.conf.BlockingIPAddrv6 = net.ParseIP(s.conf.BlockingIPv6)
if s.conf.BlockingIPAddrv4 == nil || s.conf.BlockingIPAddrv6 == nil {
if s.conf.BlockingIPv4 == nil || s.conf.BlockingIPv6 == nil {
return fmt.Errorf("dns: invalid custom blocking IP address specified")
}
}

View File

@ -286,7 +286,7 @@ func TestBlockedRequest(t *testing.T) {
t.Fatalf("Couldn't talk to server %s: %s", addr, err)
}
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.ParseIP("0.0.0.0")))
assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.IP{0, 0, 0, 0}))
err = s.Stop()
if err != nil {
@ -300,7 +300,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
uc := &proxy.UpstreamConfig{}
u := &testUpstream{}
u.ipv4 = map[string][]net.IP{}
u.ipv4["host."] = []net.IP{net.ParseIP("192.168.0.1")}
u.ipv4["host."] = []net.IP{{192, 168, 0, 1}}
uc.Upstreams = append(uc.Upstreams, u)
return uc
}
@ -425,7 +425,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
s.conf.ProtectionEnabled = false
err := s.startWithUpstream(testUpstm)
assert.True(t, err == nil)
assert.Nil(t, err)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
@ -440,16 +440,16 @@ func TestBlockCNAME(t *testing.T) {
s := createTestServer(t)
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
err := s.startWithUpstream(testUpstm)
assert.True(t, err == nil)
assert.Nil(t, err)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
// response is blocked
req := createTestMessage("badhost.")
reply, err := dns.Exchange(req, addr.String())
assert.Nil(t, err, nil)
assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.ParseIP("0.0.0.0")))
assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.IP{0, 0, 0, 0}))
// 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters
// but 'whitelist.example.org' is in a whitelist:
@ -465,7 +465,7 @@ func TestBlockCNAME(t *testing.T) {
reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.ParseIP("0.0.0.0")))
assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.IP{0, 0, 0, 0}))
_ = s.Stop()
}
@ -548,13 +548,13 @@ func TestBlockedCustomIP(t *testing.T) {
conf.TCPListenAddr = &net.TCPAddr{Port: 0}
conf.ProtectionEnabled = true
conf.BlockingMode = "custom_ip"
conf.BlockingIPv4 = "bad IP"
conf.BlockingIPv4 = nil
conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"}
err := s.Prepare(&conf)
assert.True(t, err != nil) // invalid BlockingIPv4
assert.NotNil(t, err) // invalid BlockingIPv4
conf.BlockingIPv4 = "0.0.0.1"
conf.BlockingIPv6 = "::1"
conf.BlockingIPv4 = net.IP{0, 0, 0, 1}
conf.BlockingIPv6 = net.ParseIP("::1")
err = s.Prepare(&conf)
assert.Nil(t, err)
err = s.Start()
@ -565,7 +565,7 @@ func TestBlockedCustomIP(t *testing.T) {
req := createTestMessageWithType("null.example.org.", dns.TypeA)
reply, err := dns.Exchange(req, addr.String())
assert.Nil(t, err)
assert.Equal(t, 1, len(reply.Answer))
assert.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "0.0.0.1", a.A.String())
@ -573,7 +573,7 @@ func TestBlockedCustomIP(t *testing.T) {
req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err)
assert.Equal(t, 1, len(reply.Answer))
assert.Len(t, reply.Answer, 1)
a6, ok := reply.Answer[0].(*dns.AAAA)
assert.True(t, ok)
assert.Equal(t, "::1", a6.AAAA.String())
@ -710,7 +710,7 @@ func TestRewrite(t *testing.T) {
req := createTestMessageWithType("test.com.", dns.TypeA)
reply, err := dns.Exchange(req, addr.String())
assert.Nil(t, err)
assert.Equal(t, 1, len(reply.Answer))
assert.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "1.2.3.4", a.A.String())
@ -718,12 +718,12 @@ func TestRewrite(t *testing.T) {
req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err)
assert.Equal(t, 0, len(reply.Answer))
assert.Empty(t, reply.Answer)
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err)
assert.Equal(t, 2, len(reply.Answer))
assert.Len(t, reply.Answer, 2)
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, "1.2.3.4", reply.Answer[1].(*dns.A).A.String())
@ -731,7 +731,7 @@ func TestRewrite(t *testing.T) {
reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err)
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) // the original question is restored
assert.Equal(t, 2, len(reply.Answer))
assert.Len(t, reply.Answer, 2)
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
@ -765,7 +765,7 @@ func createTestServer(t *testing.T) *Server {
s.conf.ConfigModified = func() {}
err := s.Prepare(nil)
assert.True(t, err == nil)
assert.Nil(t, err)
return s
}
@ -1011,16 +1011,14 @@ func TestValidateUpstreamsSet(t *testing.T) {
assert.NotNil(t, err, "there is an invalid upstream in set, but it pass through validation")
}
func TestIpFromAddr(t *testing.T) {
func TestIPStringFromAddr(t *testing.T) {
addr := net.UDPAddr{}
addr.IP = net.ParseIP("1:2:3::4")
addr.Port = 12345
addr.Zone = "eth0"
a := ipFromAddr(&addr)
assert.True(t, a == "1:2:3::4")
assert.Equal(t, IPStringFromAddr(&addr), net.ParseIP("1:2:3::4").String())
a = ipFromAddr(nil)
assert.True(t, a == "")
assert.Empty(t, IPStringFromAddr(nil))
}
func TestMatchDNSName(t *testing.T) {
@ -1030,9 +1028,9 @@ func TestMatchDNSName(t *testing.T) {
assert.True(t, matchDNSName(dnsNames, "a.host2"))
assert.True(t, matchDNSName(dnsNames, "b.a.host2"))
assert.True(t, matchDNSName(dnsNames, "1.2.3.4"))
assert.True(t, !matchDNSName(dnsNames, "host2"))
assert.True(t, !matchDNSName(dnsNames, ""))
assert.True(t, !matchDNSName(dnsNames, "*.host2"))
assert.False(t, matchDNSName(dnsNames, "host2"))
assert.False(t, matchDNSName(dnsNames, ""))
assert.False(t, matchDNSName(dnsNames, "*.host2"))
}
type testDHCP struct {
@ -1040,7 +1038,7 @@ type testDHCP struct {
func (d *testDHCP) Leases(flags int) []dhcpd.Lease {
l := dhcpd.Lease{}
l.IP = net.ParseIP("127.0.0.1").To4()
l.IP = net.IP{127, 0, 0, 1}
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
l.Hostname = "localhost"
return []dhcpd.Lease{l}
@ -1058,7 +1056,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true
err := s.Prepare(nil)
assert.True(t, err == nil)
assert.Nil(t, err)
assert.Nil(t, s.Start())
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -1067,7 +1065,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
resp, err := dns.Exchange(req, addr.String())
assert.Nil(t, err)
assert.Equal(t, 1, len(resp.Answer))
assert.Len(t, resp.Answer, 1)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
ptr := resp.Answer[0].(*dns.PTR)
@ -1100,7 +1098,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true
err := s.Prepare(nil)
assert.True(t, err == nil)
assert.Nil(t, err)
assert.Nil(t, s.Start())
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -1109,7 +1107,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
resp, err := dns.Exchange(req, addr.String())
assert.Nil(t, err)
assert.Equal(t, 1, len(resp.Answer))
assert.Len(t, resp.Answer, 1)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
ptr := resp.Answer[0].(*dns.PTR)

View File

@ -12,7 +12,7 @@ import (
)
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
ip := ipFromAddr(d.Addr)
ip := IPStringFromAddr(d.Addr)
disallowed, _ := s.access.IsBlockedIP(ip)
if disallowed {
log.Tracef("Client IP %s is blocked by settings", ip)
@ -36,7 +36,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
setts := s.dnsFilter.GetConfig()
setts.FilteringEnabled = true
if s.conf.FilterHandler != nil {
clientAddr := ipFromAddr(d.Addr)
clientAddr := IPStringFromAddr(d.Addr)
s.conf.FilterHandler(clientAddr, &setts)
}
return &setts

View File

@ -28,8 +28,8 @@ type dnsConfig struct {
ProtectionEnabled *bool `json:"protection_enabled"`
RateLimit *uint32 `json:"ratelimit"`
BlockingMode *string `json:"blocking_mode"`
BlockingIPv4 *string `json:"blocking_ipv4"`
BlockingIPv6 *string `json:"blocking_ipv6"`
BlockingIPv4 net.IP `json:"blocking_ipv4"`
BlockingIPv6 net.IP `json:"blocking_ipv6"`
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
DNSSECEnabled *bool `json:"dnssec_enabled"`
DisableIPv6 *bool `json:"disable_ipv6"`
@ -68,8 +68,8 @@ func (s *Server) getDNSConfig() dnsConfig {
Bootstraps: &bootstraps,
ProtectionEnabled: &protectionEnabled,
BlockingMode: &blockingMode,
BlockingIPv4: &BlockingIPv4,
BlockingIPv6: &BlockingIPv6,
BlockingIPv4: BlockingIPv4,
BlockingIPv6: BlockingIPv6,
RateLimit: &Ratelimit,
EDNSCSEnabled: &EnableEDNSClientSubnet,
DNSSECEnabled: &EnableDNSSEC,
@ -100,17 +100,11 @@ func (req *dnsConfig) checkBlockingMode() bool {
bm := *req.BlockingMode
if bm == "custom_ip" {
if req.BlockingIPv4 == nil || req.BlockingIPv6 == nil {
if req.BlockingIPv4.To4() == nil {
return false
}
ip4 := net.ParseIP(*req.BlockingIPv4)
if ip4 == nil || ip4.To4() == nil {
return false
}
ip6 := net.ParseIP(*req.BlockingIPv6)
return ip6 != nil
return req.BlockingIPv6 != nil
}
for _, valid := range []string{
@ -247,10 +241,8 @@ func (s *Server) setConfig(dc dnsConfig) (restart bool) {
if dc.BlockingMode != nil {
s.conf.BlockingMode = *dc.BlockingMode
if *dc.BlockingMode == "custom_ip" {
s.conf.BlockingIPv4 = *dc.BlockingIPv4
s.conf.BlockingIPAddrv4 = net.ParseIP(*dc.BlockingIPv4)
s.conf.BlockingIPv6 = *dc.BlockingIPv6
s.conf.BlockingIPAddrv6 = net.ParseIP(*dc.BlockingIPv6)
s.conf.BlockingIPv4 = dc.BlockingIPv4.To4()
s.conf.BlockingIPv6 = dc.BlockingIPv6.To16()
}
}

View File

@ -60,9 +60,9 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu
switch m.Question[0].Qtype {
case dns.TypeA:
return s.genARecord(m, s.conf.BlockingIPAddrv4)
return s.genARecord(m, s.conf.BlockingIPv4)
case dns.TypeAAAA:
return s.genAAAARecord(m, s.conf.BlockingIPAddrv6)
return s.genAAAARecord(m, s.conf.BlockingIPv6)
}
} else if s.conf.BlockingMode == "nxdomain" {
// means that we should return NXDOMAIN for any blocked request

View File

@ -36,7 +36,7 @@ func processQueryLogsAndStats(ctx *dnsContext) int {
OrigAnswer: ctx.origResp,
Result: ctx.result,
Elapsed: elapsed,
ClientIP: getIP(d.Addr),
ClientIP: ipFromAddr(d.Addr),
}
switch d.Proto {

View File

@ -8,38 +8,8 @@ import (
"github.com/AdguardTeam/golibs/utils"
)
// GetIPString is a helper function that extracts IP address from net.Addr
func GetIPString(addr net.Addr) string {
switch addr := addr.(type) {
case *net.UDPAddr:
return addr.IP.String()
case *net.TCPAddr:
return addr.IP.String()
}
return ""
}
func stringArrayDup(a []string) []string {
a2 := make([]string, len(a))
copy(a2, a)
return a2
}
// Get IP address from net.Addr object
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
func ipFromAddr(a net.Addr) string {
switch addr := a.(type) {
case *net.UDPAddr:
return addr.IP.String()
case *net.TCPAddr:
return addr.IP.String()
}
return ""
}
// Get IP address from net.Addr
func getIP(addr net.Addr) net.IP {
// ipFromAddr gets IP address from addr.
func ipFromAddr(addr net.Addr) (ip net.IP) {
switch addr := addr.(type) {
case *net.UDPAddr:
return addr.IP
@ -49,6 +19,23 @@ func getIP(addr net.Addr) net.IP {
return nil
}
// IPStringFromAddr extracts IP address from net.Addr.
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
func IPStringFromAddr(addr net.Addr) (ipstr string) {
if ip := ipFromAddr(addr); ip != nil {
return ip.String()
}
return ""
}
func stringArrayDup(a []string) []string {
a2 := make([]string, len(a))
copy(a2, a)
return a2
}
// Find value in a sorted array
func findSorted(ar []string, val string) int {
i := sort.SearchStrings(ar, val)

View File

@ -70,7 +70,7 @@ func TestAuth(t *testing.T) {
a.Close()
u := a.UserFind("name", "password")
assert.True(t, len(u.Name) != 0)
assert.NotEmpty(t, u.Name)
time.Sleep(3 * time.Second)
@ -125,9 +125,9 @@ func TestAuthHTTP(t *testing.T) {
r.URL = &url.URL{Path: "/"}
handlerCalled = false
handler2(&w, &r)
assert.True(t, w.statusCode == http.StatusFound)
assert.True(t, w.hdr.Get("Location") != "")
assert.True(t, !handlerCalled)
assert.Equal(t, http.StatusFound, w.statusCode)
assert.NotEmpty(t, w.hdr.Get("Location"))
assert.False(t, handlerCalled)
// go to login page
loginURL := w.hdr.Get("Location")
@ -139,7 +139,7 @@ func TestAuthHTTP(t *testing.T) {
// perform login
cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"})
assert.Nil(t, err)
assert.True(t, cookie != "")
assert.NotEmpty(t, cookie)
// get /
handler2 = optionalAuth(handler)
@ -168,8 +168,8 @@ func TestAuthHTTP(t *testing.T) {
r.URL = &url.URL{Path: loginURL}
handlerCalled = false
handler2(&w, &r)
assert.True(t, w.hdr.Get("Location") != "")
assert.True(t, !handlerCalled)
assert.NotEmpty(t, w.hdr.Get("Location"))
assert.False(t, handlerCalled)
r.Header.Del("Cookie")
// get login page with an invalid cookie

View File

@ -37,15 +37,18 @@ func TestClients(t *testing.T) {
assert.Nil(t, err)
c, b = clients.Find("1.1.1.1")
assert.True(t, b && c.Name == "client1")
assert.True(t, b)
assert.Equal(t, c.Name, "client1")
c, b = clients.Find("1:2:3::4")
assert.True(t, b && c.Name == "client1")
assert.True(t, b)
assert.Equal(t, c.Name, "client1")
c, b = clients.Find("2.2.2.2")
assert.True(t, b && c.Name == "client2")
assert.True(t, b)
assert.Equal(t, c.Name, "client2")
assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile))
assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
})
@ -109,7 +112,7 @@ func TestClients(t *testing.T) {
err := clients.Update("client1", c)
assert.Nil(t, err)
assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
c = Client{
@ -123,8 +126,8 @@ func TestClients(t *testing.T) {
c, b := clients.Find("1.1.1.2")
assert.True(t, b)
assert.True(t, c.Name == "client1-renamed")
assert.True(t, c.IDs[0] == "1.1.1.2")
assert.Equal(t, "client1-renamed", c.Name)
assert.Equal(t, "1.1.1.2", c.IDs[0])
assert.True(t, c.UseOwnSettings)
assert.Nil(t, clients.list["client1"])
})
@ -172,12 +175,12 @@ func TestClientsWhois(t *testing.T) {
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
// set whois info on new client
clients.SetWhoisInfo("1.1.1.255", whois)
assert.True(t, clients.ipHost["1.1.1.255"].WhoisInfo[0][1] == "orgname-val")
assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.255"].WhoisInfo[0][1])
// set whois info on existing auto-client
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
clients.SetWhoisInfo("1.1.1.1", whois)
assert.True(t, clients.ipHost["1.1.1.1"].WhoisInfo[0][1] == "orgname-val")
assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.1"].WhoisInfo[0][1])
// Check that we cannot set whois info on a manually-added client
c = Client{
@ -186,7 +189,7 @@ func TestClientsWhois(t *testing.T) {
}
_, _ = clients.Add(c)
clients.SetWhoisInfo("1.1.1.2", whois)
assert.True(t, clients.ipHost["1.1.1.2"] == nil)
assert.Nil(t, clients.ipHost["1.1.1.2"])
_ = clients.Del("client1")
}
@ -272,6 +275,6 @@ func TestClientsCustomUpstream(t *testing.T) {
config = clients.FindUpstreams("1.1.1.1")
assert.NotNil(t, config)
assert.Equal(t, 1, len(config.Upstreams))
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
assert.Len(t, config.Upstreams, 1)
assert.Len(t, config.DomainReservedUpstreams, 1)
}

View File

@ -98,7 +98,7 @@ func isRunning() bool {
}
func onDNSRequest(d *proxy.DNSContext) {
ip := dnsforward.GetIPString(d.Addr)
ip := dnsforward.IPStringFromAddr(d.Addr)
if ip == "" {
// This would be quite weird if we get here
return

View File

@ -50,16 +50,17 @@ func TestFilters(t *testing.T) {
// download
ok, err := Context.filters.update(&f)
assert.Equal(t, nil, err)
assert.Nil(t, err)
assert.True(t, ok)
assert.Equal(t, 3, f.RulesCount)
// refresh
ok, err = Context.filters.update(&f)
assert.True(t, !ok && err == nil)
assert.False(t, ok)
assert.Nil(t, err)
err = Context.filters.load(&f)
assert.True(t, err == nil)
assert.Nil(t, err)
f.unload()
_ = os.Remove(f.Path())

View File

@ -119,7 +119,7 @@ func TestHome(t *testing.T) {
fn := filepath.Join(dir, "AdGuardHome.yaml")
// Prepare the test config
assert.True(t, ioutil.WriteFile(fn, []byte(yamlConf), 0o644) == nil)
assert.Nil(t, ioutil.WriteFile(fn, []byte(yamlConf), 0o644))
fn, _ = filepath.Abs(fn)
config = configuration{} // the global variable is dirty because of the previous tests run
@ -138,11 +138,11 @@ func TestHome(t *testing.T) {
}
time.Sleep(100 * time.Millisecond)
}
assert.Truef(t, err == nil, "%s", err)
assert.Nilf(t, err, "%s", err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
resp, err = h.Get("http://127.0.0.1:3000/control/status")
assert.Truef(t, err == nil, "%s", err)
assert.Nilf(t, err, "%s", err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// test DNS over UDP
@ -159,16 +159,16 @@ func TestHome(t *testing.T) {
req.RecursionDesired = true
req.Question = []dns.Question{{Name: "static.adguard.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}
buf, err := req.Pack()
assert.True(t, err == nil, "%s", err)
assert.Nil(t, err)
requestURL := "http://127.0.0.1:3000/dns-query?dns=" + base64.RawURLEncoding.EncodeToString(buf)
resp, err = http.DefaultClient.Get(requestURL)
assert.True(t, err == nil, "%s", err)
assert.Nil(t, err)
body, err := ioutil.ReadAll(resp.Body)
assert.True(t, err == nil, "%s", err)
assert.True(t, resp.StatusCode == http.StatusOK)
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
response := dns.Msg{}
err = response.Unpack(body)
assert.True(t, err == nil, "%s", err)
assert.Nil(t, err)
addrs = nil
proxyutil.AppendIPAddrs(&addrs, response.Answer)
haveIP = len(addrs) != 0

View File

@ -23,7 +23,7 @@ func TestHandleMobileConfigDOH(t *testing.T) {
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) {
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
@ -51,7 +51,7 @@ func TestHandleMobileConfigDOH(t *testing.T) {
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) {
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
@ -89,7 +89,7 @@ func TestHandleMobileConfigDOT(t *testing.T) {
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) {
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
@ -116,7 +116,7 @@ func TestHandleMobileConfigDOT(t *testing.T) {
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) {
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)

View File

@ -12,10 +12,10 @@ func TestResolveRDNS(t *testing.T) {
conf := &dnsforward.ServerConfig{}
conf.UpstreamDNS = []string{"8.8.8.8"}
err := dns.Prepare(conf)
assert.True(t, err == nil, "%s", err)
assert.Nil(t, err)
clients := &clientsContainer{}
rdns := InitRDNS(dns, clients)
r := rdns.resolve("1.1.1.1")
assert.True(t, r == "one.one.one.one", "%s", r)
assert.Equal(t, "one.one.one.one", r, r)
}

View File

@ -84,7 +84,7 @@ func TestDecodeLogEntry(t *testing.T) {
decodeLogEntry(got, data)
s := logOutput.String()
assert.Equal(t, "", s)
assert.Empty(t, s)
// Correct for time zones.
got.Time = got.Time.UTC()
@ -172,7 +172,7 @@ func TestDecodeLogEntry(t *testing.T) {
s := logOutput.String()
if tc.want == "" {
assert.Equal(t, "", s)
assert.Empty(t, s)
} else {
assert.True(t, strings.HasSuffix(s, tc.want),
"got %q", s)

View File

@ -56,7 +56,7 @@ func TestQueryLog(t *testing.T) {
// get all entries
params := newSearchParams()
entries, _ := l.search(params)
assert.Equal(t, 4, len(entries))
assert.Len(t, entries, 4)
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
@ -70,7 +70,7 @@ func TestQueryLog(t *testing.T) {
value: "TEST.example.org",
})
entries, _ = l.search(params)
assert.Equal(t, 1, len(entries))
assert.Len(t, entries, 1)
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
// search by domain (not strict)
@ -81,7 +81,7 @@ func TestQueryLog(t *testing.T) {
value: "example.ORG",
})
entries, _ = l.search(params)
assert.Equal(t, 3, len(entries))
assert.Len(t, entries, 3)
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1")
@ -94,7 +94,7 @@ func TestQueryLog(t *testing.T) {
value: "2.2.2.2",
})
entries, _ = l.search(params)
assert.Equal(t, 1, len(entries))
assert.Len(t, entries, 1)
assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2")
// search by client IP (part of)
@ -105,7 +105,7 @@ func TestQueryLog(t *testing.T) {
value: "2.2.2",
})
entries, _ = l.search(params)
assert.Equal(t, 4, len(entries))
assert.Len(t, entries, 4)
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
@ -138,7 +138,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
params.offset = 0
params.limit = 10
entries, _ := l.search(params)
assert.Equal(t, 10, len(entries))
assert.Len(t, entries, 10)
assert.Equal(t, entries[0].QHost, "first.example.org")
assert.Equal(t, entries[9].QHost, "first.example.org")
@ -146,7 +146,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
params.offset = 10
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 10, len(entries))
assert.Len(t, entries, 10)
assert.Equal(t, entries[0].QHost, "second.example.org")
assert.Equal(t, entries[9].QHost, "second.example.org")
@ -154,7 +154,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
params.offset = 15
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 5, len(entries))
assert.Len(t, entries, 5)
assert.Equal(t, entries[0].QHost, "second.example.org")
assert.Equal(t, entries[4].QHost, "second.example.org")
@ -162,7 +162,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
params.offset = 20
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 0, len(entries))
assert.Empty(t, entries)
}
func TestQueryLogMaxFileScanEntries(t *testing.T) {
@ -186,11 +186,11 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
params := newSearchParams()
params.maxFileScanEntries = 5 // do not scan more than 5 records
entries, _ := l.search(params)
assert.Equal(t, 5, len(entries))
assert.Len(t, entries, 5)
params.maxFileScanEntries = 0 // disable the limit
entries, _ = l.search(params)
assert.Equal(t, 10, len(entries))
assert.Len(t, entries, 10)
}
func TestQueryLogFileDisabled(t *testing.T) {
@ -211,7 +211,7 @@ func TestQueryLogFileDisabled(t *testing.T) {
params := newSearchParams()
ll, _ := l.search(params)
assert.Equal(t, 2, len(ll))
assert.Len(t, ll, 2)
assert.Equal(t, "example3.org", ll[0].QHost)
assert.Equal(t, "example2.org", ll[1].QHost)
}
@ -262,7 +262,7 @@ func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string)
msg := new(dns.Msg)
assert.Nil(t, msg.Unpack(entry.Answer))
assert.Equal(t, 1, len(msg.Answer))
assert.Len(t, msg.Answer, 1)
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0])
assert.NotNil(t, ip)
assert.Equal(t, answer, ip.String())

View File

@ -28,12 +28,12 @@ func TestQLogFileEmpty(t *testing.T) {
// seek to the start
pos, err := q.SeekStart()
assert.Nil(t, err)
assert.Equal(t, int64(0), pos)
assert.EqualValues(t, 0, pos)
// try reading anyway
line, err := q.ReadNext()
assert.Equal(t, io.EOF, err)
assert.Equal(t, "", line)
assert.Empty(t, line)
}
func TestQLogFileLarge(t *testing.T) {
@ -53,14 +53,14 @@ func TestQLogFileLarge(t *testing.T) {
// seek to the start
pos, err := q.SeekStart()
assert.Nil(t, err)
assert.NotEqual(t, int64(0), pos)
assert.NotEqualValues(t, 0, pos)
read := 0
var line string
for err == nil {
line, err = q.ReadNext()
if err == nil {
assert.True(t, len(line) > 0)
assert.NotZero(t, len(line))
read++
}
}
@ -109,10 +109,10 @@ func TestQLogFileSeekLargeFile(t *testing.T) {
assert.Nil(t, err)
// ALMOST the record we need
timestamp := readQLogTimestamp(line) - 1
assert.NotEqual(t, uint64(0), timestamp)
assert.NotEqualValues(t, 0, timestamp)
_, depth, err := q.SeekTS(timestamp)
assert.NotNil(t, err)
assert.True(t, depth <= int(math.Log2(float64(count))+3))
assert.LessOrEqual(t, depth, int(math.Log2(float64(count))+3))
}
func TestQLogFileSeekSmallFile(t *testing.T) {
@ -155,22 +155,22 @@ func TestQLogFileSeekSmallFile(t *testing.T) {
assert.Nil(t, err)
// ALMOST the record we need
timestamp := readQLogTimestamp(line) - 1
assert.NotEqual(t, uint64(0), timestamp)
assert.NotEqualValues(t, 0, timestamp)
_, depth, err := q.SeekTS(timestamp)
assert.NotNil(t, err)
assert.True(t, depth <= int(math.Log2(float64(count))+3))
assert.LessOrEqual(t, depth, int(math.Log2(float64(count))+3))
}
func testSeekLineQLogFile(t *testing.T, q *QLogFile, lineNumber int) {
line, err := getQLogFileLine(q, lineNumber)
assert.Nil(t, err)
ts := readQLogTimestamp(line)
assert.NotEqual(t, uint64(0), ts)
assert.NotEqualValues(t, 0, ts)
// try seeking to that line now
pos, _, err := q.SeekTS(ts)
assert.Nil(t, err)
assert.NotEqual(t, int64(0), pos)
assert.NotEqualValues(t, 0, pos)
testLine, err := q.ReadNext()
assert.Nil(t, err)
@ -207,27 +207,27 @@ func TestQLogFile(t *testing.T) {
// seek to the start
pos, err := q.SeekStart()
assert.Nil(t, err)
assert.True(t, pos > 0)
assert.Greater(t, pos, int64(0))
// read first line
line, err := q.ReadNext()
assert.Nil(t, err)
assert.True(t, strings.Contains(line, "0.0.0.2"), line)
assert.Contains(t, line, "0.0.0.2")
assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line)
// read second line
line, err = q.ReadNext()
assert.Nil(t, err)
assert.Equal(t, int64(0), q.position)
assert.True(t, strings.Contains(line, "0.0.0.1"), line)
assert.EqualValues(t, 0, q.position)
assert.Contains(t, line, "0.0.0.1")
assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line)
// try reading again (there's nothing to read anymore)
line, err = q.ReadNext()
assert.Equal(t, io.EOF, err)
assert.Equal(t, "", line)
assert.Empty(t, line)
}
// prepareTestFile - prepares a test query log file with the specified number of lines

View File

@ -21,7 +21,7 @@ func TestQLogReaderEmpty(t *testing.T) {
assert.Nil(t, err)
line, err := r.ReadNext()
assert.Equal(t, "", line)
assert.Empty(t, line)
assert.Equal(t, io.EOF, err)
}
@ -241,7 +241,7 @@ func testSeekLineQLogReader(t *testing.T, r *QLogReader, lineNumber int) {
line, err := getQLogReaderLine(r, lineNumber)
assert.Nil(t, err)
ts := readQLogTimestamp(line)
assert.NotEqual(t, uint64(0), ts)
assert.NotEqualValues(t, 0, ts)
// try seeking to that line now
err = r.SeekTS(ts)

View File

@ -39,13 +39,13 @@ func TestStats(t *testing.T) {
e := Entry{}
e.Domain = "domain"
e.Client = net.ParseIP("127.0.0.1")
e.Client = net.IP{127, 0, 0, 1}
e.Result = RFiltered
e.Time = 123456
s.Update(e)
e.Domain = "domain"
e.Client = net.ParseIP("127.0.0.1")
e.Client = net.IP{127, 0, 0, 1}
e.Result = RNotFiltered
e.Time = 123456
s.Update(e)
@ -64,23 +64,23 @@ func TestStats(t *testing.T) {
assert.True(t, UIntArrayEquals(d["replaced_parental"].([]uint64), a))
m := d["top_queried_domains"].([]map[string]uint64)
assert.True(t, m[0]["domain"] == 1)
assert.EqualValues(t, 1, m[0]["domain"])
m = d["top_blocked_domains"].([]map[string]uint64)
assert.True(t, m[0]["domain"] == 1)
assert.EqualValues(t, 1, m[0]["domain"])
m = d["top_clients"].([]map[string]uint64)
assert.True(t, m[0]["127.0.0.1"] == 2)
assert.EqualValues(t, 2, m[0]["127.0.0.1"])
assert.True(t, d["num_dns_queries"].(uint64) == 2)
assert.True(t, d["num_blocked_filtering"].(uint64) == 1)
assert.True(t, d["num_replaced_safebrowsing"].(uint64) == 0)
assert.True(t, d["num_replaced_safesearch"].(uint64) == 0)
assert.True(t, d["num_replaced_parental"].(uint64) == 0)
assert.True(t, d["avg_processing_time"].(float64) == 0.123456)
assert.EqualValues(t, 2, d["num_dns_queries"].(uint64))
assert.EqualValues(t, 1, d["num_blocked_filtering"].(uint64))
assert.EqualValues(t, 0, d["num_replaced_safebrowsing"].(uint64))
assert.EqualValues(t, 0, d["num_replaced_safesearch"].(uint64))
assert.EqualValues(t, 0, d["num_replaced_parental"].(uint64))
assert.EqualValues(t, 0.123456, d["avg_processing_time"].(float64))
topClients := s.GetTopClientsIP(2)
assert.True(t, topClients[0] == "127.0.0.1")
assert.Equal(t, "127.0.0.1", topClients[0])
s.clear()
s.Close()
@ -111,7 +111,7 @@ func TestLargeNumbers(t *testing.T) {
}
for i := 0; i != n; i++ {
e.Domain = fmt.Sprintf("domain%d", i)
e.Client = net.ParseIP("127.0.0.1")
e.Client = net.IP{127, 0, 0, 1}
e.Client[2] = byte((i & 0xff00) >> 8)
e.Client[3] = byte(i & 0xff)
e.Result = RNotFiltered
@ -121,7 +121,7 @@ func TestLargeNumbers(t *testing.T) {
}
d := s.getData()
assert.True(t, d["num_dns_queries"].(uint64) == uint64(int(hour)*n))
assert.EqualValues(t, int(hour)*n, d["num_dns_queries"])
s.Close()
os.Remove(conf.Filename)
@ -152,6 +152,6 @@ func aggregateDataPerDay(firstID uint32) int {
func TestAggregateDataPerTimeUnit(t *testing.T) {
for i := 0; i != 25; i++ {
alen := aggregateDataPerDay(uint32(i))
assert.True(t, alen == 30, "i=%d", i)
assert.Equalf(t, 30, alen, "i=%d", i)
}
}

View File

@ -19,12 +19,12 @@ func IfaceSetStaticIP(ifaceName string) (err error) {
}
// GatewayIP returns IP address of interface's gateway.
func GatewayIP(ifaceName string) string {
func GatewayIP(ifaceName string) net.IP {
cmd := exec.Command("ip", "route", "show", "dev", ifaceName)
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
d, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
return ""
return nil
}
fields := strings.Fields(string(d))
@ -32,13 +32,8 @@ func GatewayIP(ifaceName string) string {
// "default" at first field and default gateway IP address at third
// field.
if len(fields) < 3 || fields[0] != "default" {
return ""
return nil
}
ip := net.ParseIP(fields[2])
if ip == nil {
return ""
}
return fields[2]
return net.ParseIP(fields[2])
}

View File

@ -129,7 +129,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
return err
}
gatewayIP := GatewayIP(ifaceName)
add := updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, ip4.String())
add := updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, ip4)
body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
if err != nil {
@ -147,14 +147,14 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
// updateStaticIPdhcpcdConf sets static IP address for the interface by writing
// into dhcpd.conf.
func updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string {
func updateStaticIPdhcpcdConf(ifaceName, ip string, gatewayIP, dnsIP net.IP) string {
var body []byte
add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n",
ifaceName, ip)
body = append(body, []byte(add)...)
if len(gatewayIP) != 0 {
if gatewayIP != nil {
add = fmt.Sprintf("static routers=%s\n",
gatewayIP)
body = append(body, []byte(add)...)

View File

@ -4,6 +4,7 @@ package sysutil
import (
"bytes"
"net"
"testing"
"github.com/stretchr/testify/assert"
@ -96,7 +97,7 @@ func TestSetStaticIPdhcpcdConf(t *testing.T) {
`static routers=192.168.0.1` + nl +
`static domain_name_servers=192.168.0.2` + nl + nl
s := updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", "192.168.0.1", "192.168.0.2")
s := updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", net.IP{192, 168, 0, 1}, net.IP{192, 168, 0, 2})
assert.Equal(t, dhcpcdConf, s)
// without gateway
@ -104,6 +105,6 @@ func TestSetStaticIPdhcpcdConf(t *testing.T) {
`static ip_address=192.168.0.2/24` + nl +
`static domain_name_servers=192.168.0.2` + nl + nl
s = updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", "", "192.168.0.2")
s = updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", nil, net.IP{192, 168, 0, 2})
assert.Equal(t, dhcpcdConf, s)
}

View File

@ -42,7 +42,7 @@ func TestAutoHostsResolution(t *testing.T) {
// Existing host
ips := ah.Process("localhost", dns.TypeA)
assert.NotNil(t, ips)
assert.Equal(t, 1, len(ips))
assert.Len(t, ips, 1)
assert.Equal(t, net.ParseIP("127.0.0.1"), ips[0])
// Unknown host
@ -107,7 +107,7 @@ func TestAutoHostsFSNotify(t *testing.T) {
// Check if we are notified about changes
ips = ah.Process("newhost", dns.TypeA)
assert.NotNil(t, ips)
assert.Equal(t, 1, len(ips))
assert.Len(t, ips, 1)
assert.Equal(t, "127.0.0.2", ips[0].String())
}

View File

@ -8,7 +8,8 @@ import (
func TestSplitNext(t *testing.T) {
s := " a,b , c "
assert.True(t, SplitNext(&s, ',') == "a")
assert.True(t, SplitNext(&s, ',') == "b")
assert.True(t, SplitNext(&s, ',') == "c" && len(s) == 0)
assert.Equal(t, "a", SplitNext(&s, ','))
assert.Equal(t, "b", SplitNext(&s, ','))
assert.Equal(t, "c", SplitNext(&s, ','))
assert.Empty(t, s)
}