Pull request: 2846 cover aghnet vol.4

Merge in DNS/adguard-home from 2846-cover-aghnet-vol.4 to master

Updates #2846.

Squashed commit of the following:

commit 576ef857628a403ce1478c10a4aad23985c09613
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Mar 31 19:38:57 2022 +0300

    aghnet: imp code

commit 5b4b17ff52867aaab2c9d30a0fc7fc2fe31ff4d5
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Mar 31 14:58:34 2022 +0300

    aghnet: imp coverage
This commit is contained in:
Eugene Burkov 2022-03-31 19:56:50 +03:00
parent a79b61aac3
commit c70f941bf8
15 changed files with 266 additions and 207 deletions

View File

@ -181,6 +181,16 @@ func TestCmdARPDB_arpa(t *testing.T) {
err := a.Refresh() err := a.Refresh()
testutil.AssertErrorMsg(t, "cmd arpdb: running command: unexpected exit code 1", err) testutil.AssertErrorMsg(t, "cmd arpdb: running command: unexpected exit code 1", err)
}) })
t.Run("empty", func(t *testing.T) {
sh := theOnlyCmd("cmd", 0, "", nil)
substShell(t, sh.RunCmd)
err := a.Refresh()
require.NoError(t, err)
assert.Empty(t, a.Neighbors())
})
} }
func TestEmptyARPDB(t *testing.T) { func TestEmptyARPDB(t *testing.T) {

View File

@ -11,7 +11,7 @@ const (
ipv6HostnameMaxLen = len("ff80-f076-0000-0000-0000-0000-0000-0010") ipv6HostnameMaxLen = len("ff80-f076-0000-0000-0000-0000-0000-0010")
) )
// generateIPv4Hostname generates the hostname for specific IP version. // generateIPv4Hostname generates the hostname by IP address version 4.
func generateIPv4Hostname(ipv4 net.IP) (hostname string) { func generateIPv4Hostname(ipv4 net.IP) (hostname string) {
hnData := make([]byte, 0, ipv4HostnameMaxLen) hnData := make([]byte, 0, ipv4HostnameMaxLen)
for i, part := range ipv4 { for i, part := range ipv4 {
@ -24,7 +24,7 @@ func generateIPv4Hostname(ipv4 net.IP) (hostname string) {
return string(hnData) return string(hnData)
} }
// generateIPv6Hostname generates the hostname for specific IP version. // generateIPv6Hostname generates the hostname by IP address version 6.
func generateIPv6Hostname(ipv6 net.IP) (hostname string) { func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
hnData := make([]byte, 0, ipv6HostnameMaxLen) hnData := make([]byte, 0, ipv6HostnameMaxLen)
for i, partsNum := 0, net.IPv6len/2; i < partsNum; i++ { for i, partsNum := 0, net.IPv6len/2; i < partsNum; i++ {
@ -51,12 +51,11 @@ func generateIPv6Hostname(ipv6 net.IP) (hostname string) {
// //
// ff80-f076-0000-0000-0000-0000-0000-0010 // ff80-f076-0000-0000-0000-0000-0000-0010
// //
// ip must be either an IPv4 or an IPv6.
func GenerateHostname(ip net.IP) (hostname string) { func GenerateHostname(ip net.IP) (hostname string) {
if ipv4 := ip.To4(); ipv4 != nil { if ipv4 := ip.To4(); ipv4 != nil {
return generateIPv4Hostname(ipv4) return generateIPv4Hostname(ipv4)
} else if ipv6 := ip.To16(); ipv6 != nil {
return generateIPv6Hostname(ipv6)
} }
return "" return generateIPv6Hostname(ip)
} }

View File

@ -8,41 +8,57 @@ import (
) )
func TestGenerateHostName(t *testing.T) { func TestGenerateHostName(t *testing.T) {
testCases := []struct { t.Run("valid", func(t *testing.T) {
name string testCases := []struct {
want string name string
ip net.IP want string
}{{ ip net.IP
name: "good_ipv4", }{{
want: "127-0-0-1", name: "good_ipv4",
ip: net.IP{127, 0, 0, 1}, want: "127-0-0-1",
}, { ip: net.IP{127, 0, 0, 1},
name: "bad_ipv4", }, {
want: "", name: "good_ipv6",
ip: net.IP{127, 0, 0, 1, 0}, want: "fe00-0000-0000-0000-0000-0000-0000-0001",
}, { ip: net.ParseIP("fe00::1"),
name: "good_ipv6", }, {
want: "fe00-0000-0000-0000-0000-0000-0000-0001", name: "4to6",
ip: net.ParseIP("fe00::1"), want: "1-2-3-4",
}, { ip: net.ParseIP("::ffff:1.2.3.4"),
name: "bad_ipv6", }}
want: "",
ip: net.IP{
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff,
},
}, {
name: "nil",
want: "",
ip: nil,
}}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
hostname := GenerateHostname(tc.ip) hostname := GenerateHostname(tc.ip)
assert.Equal(t, tc.want, hostname) assert.Equal(t, tc.want, hostname)
}) })
} }
})
t.Run("invalid", func(t *testing.T) {
testCases := []struct {
name string
ip net.IP
}{{
name: "bad_ipv4",
ip: net.IP{127, 0, 0, 1, 0},
}, {
name: "bad_ipv6",
ip: net.IP{
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff,
},
}, {
name: "nil",
ip: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Panics(t, func() { GenerateHostname(tc.ip) })
})
}
})
} }

View File

@ -15,13 +15,17 @@ import (
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
) )
// aghosRunCommand is the function to run shell commands. It's an unexported // Variables and functions to substitute in tests.
// variable instead of a direct call to make it substitutable in tests. var (
var aghosRunCommand = aghos.RunCommand // aghosRunCommand is the function to run shell commands.
aghosRunCommand = aghos.RunCommand
// rootDirFS is the filesystem pointing to the root directory. It's an // netInterfaces is the function to get the available network interfaces.
// unexported variable instead to make it substitutable in tests. netInterfaceAddrs = net.InterfaceAddrs
var rootDirFS = aghos.RootDirFS()
// rootDirFS is the filesystem pointing to the root directory.
rootDirFS = aghos.RootDirFS()
)
// ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about // ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about
// the IP being static is available. // the IP being static is available.
@ -65,23 +69,6 @@ func GatewayIP(ifaceName string) (ip net.IP) {
return net.ParseIP(string(fields[2])) return net.ParseIP(string(fields[2]))
} }
// CanBindPort checks if we can bind to the given port.
func CanBindPort(port int) (can bool, err error) {
var addr *net.TCPAddr
addr, err = net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return false, err
}
var listener *net.TCPListener
listener, err = net.ListenTCP("tcp", addr)
if err != nil {
return false, err
}
_ = listener.Close()
return true, nil
}
// CanBindPrivilegedPorts checks if current process can bind to privileged // CanBindPrivilegedPorts checks if current process can bind to privileged
// ports. // ports.
func CanBindPrivilegedPorts() (can bool, err error) { func CanBindPrivilegedPorts() (can bool, err error) {
@ -100,8 +87,8 @@ type NetInterface struct {
MTU int `json:"mtu"` MTU int `json:"mtu"`
} }
// MarshalJSON implements the json.Marshaler interface for NetInterface. // MarshalText implements the json.Marshaler interface for NetInterface.
func (iface NetInterface) MarshalJSON() ([]byte, error) { func (iface NetInterface) MarshalText() ([]byte, error) {
type netInterface NetInterface type netInterface NetInterface
return json.Marshal(&struct { return json.Marshal(&struct {
HardwareAddr string `json:"hardware_address"` HardwareAddr string `json:"hardware_address"`
@ -114,9 +101,12 @@ func (iface NetInterface) MarshalJSON() ([]byte, error) {
}) })
} }
// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and WEB only // GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and
// we do not return link-local addresses here // WEB only we do not return link-local addresses here.
func GetValidNetInterfacesForWeb() (netInterfaces []*NetInterface, err error) { //
// TODO(e.burkov): Can't properly test the function since it's nontrivial to
// substitute net.Interface.Addrs and the net.InterfaceAddrs can't be used.
func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) {
ifaces, err := net.Interfaces() ifaces, err := net.Interfaces()
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't get interfaces: %w", err) return nil, fmt.Errorf("couldn't get interfaces: %w", err)
@ -157,14 +147,16 @@ func GetValidNetInterfacesForWeb() (netInterfaces []*NetInterface, err error) {
// Discard interfaces with no addresses. // Discard interfaces with no addresses.
if len(netIface.Addresses) != 0 { if len(netIface.Addresses) != 0 {
netInterfaces = append(netInterfaces, netIface) netIfaces = append(netIfaces, netIface)
} }
} }
return netInterfaces, nil return netIfaces, nil
} }
// GetInterfaceByIP returns the name of interface containing provided ip. // GetInterfaceByIP returns the name of interface containing provided ip.
//
// TODO(e.burkov): See TODO on GetValidInterfacesForWeb.
func GetInterfaceByIP(ip net.IP) string { func GetInterfaceByIP(ip net.IP) string {
ifaces, err := GetValidNetInterfacesForWeb() ifaces, err := GetValidNetInterfacesForWeb()
if err != nil { if err != nil {
@ -184,6 +176,8 @@ func GetInterfaceByIP(ip net.IP) string {
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if // GetSubnet returns pointer to net.IPNet for the specified interface or nil if
// the search fails. // the search fails.
//
// TODO(e.burkov): See TODO on GetValidInterfacesForWeb.
func GetSubnet(ifaceName string) *net.IPNet { func GetSubnet(ifaceName string) *net.IPNet {
netIfaces, err := GetValidNetInterfacesForWeb() netIfaces, err := GetValidNetInterfacesForWeb()
if err != nil { if err != nil {
@ -234,29 +228,21 @@ func IsAddrInUse(err error) (ok bool) {
// CollectAllIfacesAddrs returns the slice of all network interfaces IP // CollectAllIfacesAddrs returns the slice of all network interfaces IP
// addresses without port number. // addresses without port number.
func CollectAllIfacesAddrs() (addrs []string, err error) { func CollectAllIfacesAddrs() (addrs []string, err error) {
var ifaces []net.Interface var ifaceAddrs []net.Addr
ifaces, err = net.Interfaces() ifaceAddrs, err = netInterfaceAddrs()
if err != nil { if err != nil {
return nil, fmt.Errorf("getting network interfaces: %w", err) return nil, fmt.Errorf("getting interfaces addresses: %w", err)
} }
for _, iface := range ifaces { for _, addr := range ifaceAddrs {
var ifaceAddrs []net.Addr cidr := addr.String()
ifaceAddrs, err = iface.Addrs() var ip net.IP
ip, _, err = net.ParseCIDR(cidr)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting addresses for %q: %w", iface.Name, err) return nil, fmt.Errorf("parsing cidr: %w", err)
} }
for _, addr := range ifaceAddrs { addrs = append(addrs, ip.String())
cidr := addr.String()
var ip net.IP
ip, _, err = net.ParseCIDR(cidr)
if err != nil {
return nil, fmt.Errorf("parsing cidr: %w", err)
}
addrs = append(addrs, ip.String())
}
} }
return addrs, nil return addrs, nil

View File

@ -24,10 +24,6 @@ type hardwarePortInfo struct {
static bool static bool
} }
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
portInfo, err := getCurrentHardwarePortInfo(ifaceName) portInfo, err := getCurrentHardwarePortInfo(ifaceName)
if err != nil { if err != nil {

View File

@ -13,10 +13,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
) )
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
const rcConfFilename = "etc/rc.conf" const rcConfFilename = "etc/rc.conf"

View File

@ -12,7 +12,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestRcConfStaticConfig(t *testing.T) { func TestIfaceHasStaticIP(t *testing.T) {
const ( const (
ifaceName = `em0` ifaceName = `em0`
rcConf = "etc/rc.conf" rcConf = "etc/rc.conf"

View File

@ -13,10 +13,28 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/renameio/maybe" "github.com/google/renameio/maybe"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// dhcpсdConf is the name of /etc/dhcpcd.conf file in the root filesystem.
const dhcpcdConf = "etc/dhcpcd.conf"
func canBindPrivilegedPorts() (can bool, err error) {
cnbs, err := unix.PrctlRetInt(
unix.PR_CAP_AMBIENT,
unix.PR_CAP_AMBIENT_IS_SET,
unix.CAP_NET_BIND_SERVICE,
0,
0,
)
// Don't check the error because it's always nil on Linux.
adm, _ := aghos.HaveAdminRights()
return cnbs == 1 || adm, err
}
// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to // dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
// have a static IP. // have a static IP.
func (n interfaceName) dhcpcdStaticConfig(r io.Reader) (subsources []string, cont bool, err error) { func (n interfaceName) dhcpcdStaticConfig(r io.Reader) (subsources []string, cont bool, err error) {
@ -89,7 +107,7 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
filename string filename string
}{{ }{{
FileWalker: iface.dhcpcdStaticConfig, FileWalker: iface.dhcpcdStaticConfig,
filename: "etc/dhcpcd.conf", filename: dhcpcdConf,
}, { }, {
FileWalker: iface.ifacesStaticConfig, FileWalker: iface.ifacesStaticConfig,
filename: "etc/network/interfaces", filename: "etc/network/interfaces",
@ -105,14 +123,6 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
return false, ErrNoStaticIPInfo return false, ErrNoStaticIPInfo
} }
func canBindPrivilegedPorts() (can bool, err error) {
cnbs, err := unix.PrctlRetInt(unix.PR_CAP_AMBIENT, unix.PR_CAP_AMBIENT_IS_SET, unix.CAP_NET_BIND_SERVICE, 0, 0)
// Don't check the error because it's always nil on Linux.
adm, _ := aghos.HaveAdminRights()
return cnbs == 1 || adm, err
}
// findIfaceLine scans s until it finds the line that declares an interface with // findIfaceLine scans s until it finds the line that declares an interface with
// the given name. If findIfaceLine can't find the line, it returns false. // the given name. If findIfaceLine can't find the line, it returns false.
func findIfaceLine(s *bufio.Scanner, name string) (ok bool) { func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
@ -128,25 +138,23 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
} }
// ifaceSetStaticIP configures the system to retain its current IP on the // ifaceSetStaticIP configures the system to retain its current IP on the
// interface through dhcpdc.conf. // interface through dhcpcd.conf.
func ifaceSetStaticIP(ifaceName string) (err error) { func ifaceSetStaticIP(ifaceName string) (err error) {
ipNet := GetSubnet(ifaceName) ipNet := GetSubnet(ifaceName)
if ipNet.IP == nil { if ipNet.IP == nil {
return errors.Error("can't get IP address") return errors.Error("can't get IP address")
} }
gatewayIP := GatewayIP(ifaceName) body, err := os.ReadFile(dhcpcdConf)
add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP, ipNet.IP)
const filename = "/etc/dhcpcd.conf"
body, err := os.ReadFile(filename)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return err return err
} }
gatewayIP := GatewayIP(ifaceName)
add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP)
body = append(body, []byte(add)...) body = append(body, []byte(add)...)
err = maybe.WriteFile(filename, body, 0o644) err = maybe.WriteFile(dhcpcdConf, body, 0o644)
if err != nil { if err != nil {
return fmt.Errorf("writing conf: %w", err) return fmt.Errorf("writing conf: %w", err)
} }
@ -156,22 +164,24 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
// dhcpcdConfIface returns configuration lines for the dhcpdc.conf files that // dhcpcdConfIface returns configuration lines for the dhcpdc.conf files that
// configure the interface to have a static IP. // configure the interface to have a static IP.
func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gatewayIP, dnsIP net.IP) (conf string) { func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf string) {
var body []byte b := &strings.Builder{}
stringutil.WriteToBuilder(
add := fmt.Sprintf( b,
"\n# %[1]s added by AdGuard Home.\ninterface %[1]s\nstatic ip_address=%s\n", "\n# ",
ifaceName, ifaceName,
ipNet) " added by AdGuard Home.\ninterface ",
body = append(body, []byte(add)...) ifaceName,
"\nstatic ip_address=",
ipNet.String(),
"\n",
)
if gatewayIP != nil { if gwIP != nil {
add = fmt.Sprintf("static routers=%s\n", gatewayIP) stringutil.WriteToBuilder(b, "static routers=", gwIP.String(), "\n")
body = append(body, []byte(add)...)
} }
add = fmt.Sprintf("static domain_name_servers=%s\n\n", dnsIP) stringutil.WriteToBuilder(b, "static domain_name_servers=", ipNet.IP.String(), "\n\n")
body = append(body, []byte(add)...)
return string(body) return b.String()
} }

View File

@ -5,7 +5,6 @@ package aghnet
import ( import (
"io/fs" "io/fs"
"net"
"testing" "testing"
"testing/fstest" "testing/fstest"
@ -126,38 +125,3 @@ func TestHasStaticIP(t *testing.T) {
}) })
} }
} }
func TestSetStaticIP_dhcpcdConfIface(t *testing.T) {
testCases := []struct {
name string
dhcpcdConf string
routers net.IP
}{{
name: "with_gateway",
dhcpcdConf: nl + `# wlan0 added by AdGuard Home.` + nl +
`interface wlan0` + nl +
`static ip_address=192.168.0.2/24` + nl +
`static routers=192.168.0.1` + nl +
`static domain_name_servers=192.168.0.2` + nl + nl,
routers: net.IP{192, 168, 0, 1},
}, {
name: "without_gateway",
dhcpcdConf: nl + `# wlan0 added by AdGuard Home.` + nl +
`interface wlan0` + nl +
`static ip_address=192.168.0.2/24` + nl +
`static domain_name_servers=192.168.0.2` + nl + nl,
routers: nil,
}}
ipNet := &net.IPNet{
IP: net.IP{192, 168, 0, 2},
Mask: net.IPMask{255, 255, 255, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := dhcpcdConfIface("wlan0", ipNet, tc.routers, net.IP{192, 168, 0, 2})
assert.Equal(t, tc.dhcpcdConf, s)
})
}
}

View File

@ -0,0 +1,10 @@
//go:build !linux
// +build !linux
package aghnet
import "github.com/AdguardTeam/AdGuardHome/internal/aghos"
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}

View File

@ -13,10 +13,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
) )
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
filename := fmt.Sprintf("etc/hostname.%s", ifaceName) filename := fmt.Sprintf("etc/hostname.%s", ifaceName)

View File

@ -56,7 +56,7 @@ type mapShell map[string]struct {
code int code int
} }
// theOnlyCmd returns s that only handles a single command and arguments // theOnlyCmd returns mapShell that only handles a single command and arguments
// combination from cmd. // combination from cmd.
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) { func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
return mapShell{cmd: {code: code, out: out, err: err}} return mapShell{cmd: {code: code, out: out, err: err}}
@ -73,18 +73,34 @@ func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err
return ret.code, []byte(ret.out), ret.err return ret.code, []byte(ret.out), ret.err
} }
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
// throughout the package with f for tests ran under t.
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
t.Helper()
prev := netInterfaceAddrs
t.Cleanup(func() { netInterfaceAddrs = prev })
netInterfaceAddrs = f
}
func TestGatewayIP(t *testing.T) { func TestGatewayIP(t *testing.T) {
const ifaceName = "ifaceName"
const cmd = "ip route show dev " + ifaceName
testCases := []struct { testCases := []struct {
name string name string
shell mapShell shell mapShell
want net.IP want net.IP
}{{ }{{
name: "success_v4", name: "success_v4",
shell: theOnlyCmd("ip route show dev ifaceName", 0, `default via 1.2.3.4 onlink`, nil), shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
want: net.IP{1, 2, 3, 4}.To16(), want: net.IP{1, 2, 3, 4}.To16(),
}, { }, {
name: "success_v6", name: "success_v6",
shell: theOnlyCmd("ip route show dev ifaceName", 0, `default via ::ffff onlink`, nil), shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
want: net.IP{ want: net.IP{
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
@ -93,15 +109,15 @@ func TestGatewayIP(t *testing.T) {
}, },
}, { }, {
name: "bad_output", name: "bad_output",
shell: theOnlyCmd("ip route show dev ifaceName", 0, `non-default via 1.2.3.4 onlink`, nil), shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
want: nil, want: nil,
}, { }, {
name: "err_runcmd", name: "err_runcmd",
shell: theOnlyCmd("ip route show dev ifaceName", 0, "", errors.Error("can't run command")), shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
want: nil, want: nil,
}, { }, {
name: "bad_code", name: "bad_code",
shell: theOnlyCmd("ip route show dev ifaceName", 1, "", nil), shell: theOnlyCmd(cmd, 1, "", nil),
want: nil, want: nil,
}} }}
@ -109,7 +125,7 @@ func TestGatewayIP(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd) substShell(t, tc.shell.RunCmd)
assert.Equal(t, tc.want, GatewayIP("ifaceName")) assert.Equal(t, tc.want, GatewayIP(ifaceName))
}) })
} }
} }
@ -226,12 +242,56 @@ func TestCheckPort(t *testing.T) {
} }
func TestCollectAllIfacesAddrs(t *testing.T) { func TestCollectAllIfacesAddrs(t *testing.T) {
t.Skip("TODO(e.burkov): Substitute the net.Interfaces.") testCases := []struct {
name string
wantErrMsg string
addrs []net.Addr
wantAddrs []string
}{{
name: "success",
wantErrMsg: ``,
addrs: []net.Addr{&net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
}, &net.IPNet{
IP: net.IP{4, 3, 2, 1},
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
}},
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
}, {
name: "not_cidr",
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
addrs: []net.Addr{&net.IPAddr{
IP: net.IP{1, 2, 3, 4},
}},
wantAddrs: nil,
}, {
name: "empty",
wantErrMsg: ``,
addrs: []net.Addr{},
wantAddrs: nil,
}}
addrs, err := CollectAllIfacesAddrs() for _, tc := range testCases {
require.NoError(t, err) t.Run(tc.name, func(t *testing.T) {
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
assert.NotEmpty(t, addrs) addrs, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.wantAddrs, addrs)
})
}
t.Run("internal_error", func(t *testing.T) {
const errAddrs errors.Error = "can't get addresses"
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
_, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
} }
func TestIsAddrInUse(t *testing.T) { func TestIsAddrInUse(t *testing.T) {
@ -250,3 +310,33 @@ func TestIsAddrInUse(t *testing.T) {
assert.False(t, IsAddrInUse(anotherErr)) assert.False(t, IsAddrInUse(anotherErr))
}) })
} }
func TestNetInterface_MarshalText(t *testing.T) {
const want = `{` +
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
`"flags":"up|multicast",` +
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
`"name":"iface0",` +
`"mtu":1500` +
`}`
ip4, ip6 := net.IP{1, 2, 3, 4}, net.IP{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
mask4, mask6 := net.CIDRMask(24, netutil.IPv4BitLen), net.CIDRMask(8, netutil.IPv6BitLen)
iface := &NetInterface{
Addresses: []net.IP{ip4, ip6},
Subnets: []*net.IPNet{{
IP: ip4.Mask(mask4),
Mask: mask4,
}, {
IP: ip6.Mask(mask6),
Mask: mask6,
}},
Name: "iface0",
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
Flags: net.FlagUp | net.FlagMulticast,
MTU: 1500,
}
testutil.AssertMarshalText(t, want, iface)
}

View File

@ -13,10 +13,6 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
func canBindPrivilegedPorts() (can bool, err error) {
return aghos.HaveAdminRights()
}
func ifaceHasStaticIP(string) (ok bool, err error) { func ifaceHasStaticIP(string) (ok bool, err error) {
return false, aghos.Unsupported("checking static ip") return false, aghos.Unsupported("checking static ip")
} }

View File

@ -517,27 +517,15 @@ func StartMods() error {
func checkPermissions() { func checkPermissions() {
log.Info("Checking if AdGuard Home has necessary permissions") log.Info("Checking if AdGuard Home has necessary permissions")
if runtime.GOOS == "windows" { if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil {
// On Windows we need to have admin rights to run properly
admin, _ := aghos.HaveAdminRights()
if admin {
return
}
log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.") log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
} }
// We should check if AdGuard Home is able to bind to port 53 // We should check if AdGuard Home is able to bind to port 53
ok, err := aghnet.CanBindPort(53) err := aghnet.CheckPort("tcp", net.IP{127, 0, 0, 1}, defaultPortDNS)
if err != nil {
if ok { if errors.Is(err, os.ErrPermission) {
log.Info("AdGuard Home can bind to port 53") log.Fatal(`Permission check failed.
return
}
if errors.Is(err, os.ErrPermission) {
msg := `Permission check failed.
AdGuard Home is not allowed to bind to privileged ports (for instance, port 53). AdGuard Home is not allowed to bind to privileged ports (for instance, port 53).
Please note, that this is crucial for a server to be able to use privileged ports. Please note, that this is crucial for a server to be able to use privileged ports.
@ -545,16 +533,17 @@ Please note, that this is crucial for a server to be able to use privileged port
You have two options: You have two options:
1. Run AdGuard Home with root privileges 1. Run AdGuard Home with root privileges
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability: 2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser` https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`)
}
log.Fatal(msg) log.Info(
"AdGuard failed to bind to port 53: %s\n\n"+
"Please note, that this is crucial for a DNS server to be able to use that port.",
err,
)
} }
msg := fmt.Sprintf(`AdGuard failed to bind to port 53 due to %v log.Info("AdGuard Home can bind to port 53")
Please note, that this is crucial for a DNS server to be able to use that port.`, err)
log.Info(msg)
} }
// Write PID to a file // Write PID to a file

View File

@ -134,9 +134,10 @@ underscores() {
-e '_bsd.go'\ -e '_bsd.go'\
-e '_darwin.go'\ -e '_darwin.go'\
-e '_freebsd.go'\ -e '_freebsd.go'\
-e '_openbsd.go'\
-e '_linux.go'\ -e '_linux.go'\
-e '_little.go'\ -e '_little.go'\
-e '_nolinux.go'\
-e '_openbsd.go'\
-e '_others.go'\ -e '_others.go'\
-e '_test.go'\ -e '_test.go'\
-e '_unix.go'\ -e '_unix.go'\