diff --git a/CHANGELOG.md b/CHANGELOG.md index df32b0f8..3b243cae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,10 +34,6 @@ and this project adheres to ### Changed -- Reverse DNS now has a greater priority as the source of runtime clients' - information than ARP neighborhood. -- Improved detection of runtime clients through more resilient ARP processing - ([#3597]). - The TTL of responses served from the optimistic cache is now lowered to 10 seconds. - Domain-specific private reverse DNS upstream servers are now validated to @@ -114,7 +110,6 @@ In this release, the schema version has changed from 12 to 14. [#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367 [#3381]: https://github.com/AdguardTeam/AdGuardHome/issues/3381 [#3503]: https://github.com/AdguardTeam/AdGuardHome/issues/3503 -[#3597]: https://github.com/AdguardTeam/AdGuardHome/issues/3597 [#4238]: https://github.com/AdguardTeam/AdGuardHome/issues/4238 [ddr-draft-06]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html @@ -150,6 +145,10 @@ See also the [v0.107.7 GitHub milestone][ms-v0.107.7]. ### Changed +- Reverse DNS now has a greater priority as the source of runtime clients' + information than ARP neighborhood. +- Improved detection of runtime clients through more resilient ARP processing + ([#3597]). - On OpenBSD, the daemon script now uses the recommended `/bin/ksh` shell instead of the `/bin/sh` one ([#4533]). To apply this change, backup your data and run `AdGuardHome -s uninstall && AdGuardHome -s install`. @@ -169,6 +168,7 @@ See also the [v0.107.7 GitHub milestone][ms-v0.107.7]. [#1730]: https://github.com/AdguardTeam/AdGuardHome/issues/1730 [#3157]: https://github.com/AdguardTeam/AdGuardHome/issues/3157 +[#3597]: https://github.com/AdguardTeam/AdGuardHome/issues/3597 [#3978]: https://github.com/AdguardTeam/AdGuardHome/issues/3978 [#4166]: https://github.com/AdguardTeam/AdGuardHome/issues/4166 [#4213]: https://github.com/AdguardTeam/AdGuardHome/issues/4213 diff --git a/internal/aghnet/arpdb.go b/internal/aghnet/arpdb.go new file mode 100644 index 00000000..4909af5f --- /dev/null +++ b/internal/aghnet/arpdb.go @@ -0,0 +1,211 @@ +package aghnet + +import ( + "bufio" + "bytes" + "fmt" + "net" + "sync" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" +) + +// ARPDB: The Network Neighborhood Database + +// ARPDB stores and refreshes the network neighborhood reported by ARP (Address +// Resolution Protocol). +type ARPDB interface { + // Refresh updates the stored data. It must be safe for concurrent use. + Refresh() (err error) + + // Neighbors returnes the last set of data reported by ARP. Both the method + // and it's result must be safe for concurrent use. + Neighbors() (ns []Neighbor) +} + +// NewARPDB returns the ARPDB properly initialized for the OS. +func NewARPDB() (arp ARPDB) { + return newARPDB() +} + +// Empty ARPDB implementation + +// EmptyARPDB is the ARPDB implementation that does nothing. +type EmptyARPDB struct{} + +// type check +var _ ARPDB = EmptyARPDB{} + +// Refresh implements the ARPDB interface for EmptyARPContainer. It does +// nothing and always returns nil error. +func (EmptyARPDB) Refresh() (err error) { return nil } + +// Neighbors implements the ARPDB interface for EmptyARPContainer. It always +// returns nil. +func (EmptyARPDB) Neighbors() (ns []Neighbor) { return nil } + +// ARPDB Helper Types + +// Neighbor is the pair of IP address and MAC address reported by ARP. +type Neighbor struct { + // Name is the hostname of the neighbor. Empty name is valid since not each + // implementation of ARP is able to retrieve that. + Name string + + // IP contains either IPv4 or IPv6. + IP net.IP + + // MAC contains the hardware address. + MAC net.HardwareAddr +} + +// Clone returns the deep copy of n. +func (n Neighbor) Clone() (clone Neighbor) { + return Neighbor{ + Name: n.Name, + IP: netutil.CloneIP(n.IP), + MAC: netutil.CloneMAC(n.MAC), + } +} + +// neighs is the helper type that stores neighbors to avoid copying its methods +// among all the ARPDB implementations. +type neighs struct { + mu *sync.RWMutex + ns []Neighbor +} + +// len returns the length of the neighbors slice. It's safe for concurrent use. +func (ns *neighs) len() (l int) { + ns.mu.RLock() + defer ns.mu.RUnlock() + + return len(ns.ns) +} + +// clone returns a deep copy of the underlying neighbors slice. It's safe for +// concurrent use. +func (ns *neighs) clone() (cloned []Neighbor) { + ns.mu.RLock() + defer ns.mu.RUnlock() + + cloned = make([]Neighbor, len(ns.ns)) + for i, n := range ns.ns { + cloned[i] = n.Clone() + } + + return cloned +} + +// reset replaces the underlying slice with the new one. It's safe for +// concurrent use. +func (ns *neighs) reset(with []Neighbor) { + ns.mu.Lock() + defer ns.mu.Unlock() + + ns.ns = with +} + +// Command ARPDB + +// parseNeighsFunc parses the text from sc as if it'd be an output of some +// ARP-related command. lenHint is a hint for the size of the allocated slice +// of Neighbors. +type parseNeighsFunc func(sc *bufio.Scanner, lenHint int) (ns []Neighbor) + +// cmdARPDB is the implementation of the ARPDB that uses command line to +// retrieve data. +type cmdARPDB struct { + parse parseNeighsFunc + ns *neighs + cmd string + args []string +} + +// type check +var _ ARPDB = (*cmdARPDB)(nil) + +// Refresh implements the ARPDB interface for *cmdARPDB. +func (arp *cmdARPDB) Refresh() (err error) { + defer func() { err = errors.Annotate(err, "cmd arpdb: %w") }() + + code, out, err := aghosRunCommand(arp.cmd, arp.args...) + if err != nil { + return fmt.Errorf("running command: %w", err) + } else if code != 0 { + return fmt.Errorf("running command: unexpected exit code %d", code) + } + + sc := bufio.NewScanner(bytes.NewReader(out)) + ns := arp.parse(sc, arp.ns.len()) + if err = sc.Err(); err != nil { + // TODO(e.burkov): This error seems unreachable. Investigate. + return fmt.Errorf("scanning the output: %w", err) + } + + arp.ns.reset(ns) + + return nil +} + +// Neighbors implements the ARPDB interface for *cmdARPDB. +func (arp *cmdARPDB) Neighbors() (ns []Neighbor) { + return arp.ns.clone() +} + +// Composite ARPDB + +// arpdbs is the ARPDB that combines several ARPDB implementations and +// consequently switches between those. +type arpdbs struct { + // arps is the set of ARPDB implementations to range through. + arps []ARPDB + neighs +} + +// newARPDBs returns a properly initialized *arpdbs. It begins refreshing from +// the first of arps. +func newARPDBs(arps ...ARPDB) (arp *arpdbs) { + return &arpdbs{ + arps: arps, + neighs: neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + }, + } +} + +// type check +var _ ARPDB = (*arpdbs)(nil) + +// Refresh implements the ARPDB interface for *arpdbs. +func (arp *arpdbs) Refresh() (err error) { + var errs []error + + for _, a := range arp.arps { + err = a.Refresh() + if err != nil { + errs = append(errs, err) + + continue + } + + arp.reset(a.Neighbors()) + + return nil + } + + if len(errs) > 0 { + err = errors.List("each arpdb failed", errs...) + } + + return err +} + +// Neighbors implements the ARPDB interface for *arpdbs. +// +// TODO(e.burkov): Think of a way to avoid cloning the slice twice. +func (arp *arpdbs) Neighbors() (ns []Neighbor) { + return arp.clone() +} diff --git a/internal/aghnet/arpdb_bsd.go b/internal/aghnet/arpdb_bsd.go new file mode 100644 index 00000000..26ac7758 --- /dev/null +++ b/internal/aghnet/arpdb_bsd.go @@ -0,0 +1,76 @@ +//go:build darwin || freebsd +// +build darwin freebsd + +package aghnet + +import ( + "bufio" + "net" + "strings" + "sync" + + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" +) + +func newARPDB() (arp *cmdARPDB) { + return &cmdARPDB{ + parse: parseArpA, + ns: &neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + }, + cmd: "arp", + // Use -n flag to avoid resolving the hostnames of the neighbors. By + // default ARP attempts to resolve the hostnames via DNS. See man 8 + // arp. + // + // See also https://github.com/AdguardTeam/AdGuardHome/issues/3157. + args: []string{"-a", "-n"}, + } +} + +// parseArpA parses the output of the "arp -a -n" command on macOS and FreeBSD. +// The expected input format: +// +// host.name (192.168.0.1) at ff:ff:ff:ff:ff:ff on en0 ifscope [ethernet] +// +func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { + ns = make([]Neighbor, 0, lenHint) + for sc.Scan() { + ln := sc.Text() + + fields := strings.Fields(ln) + if len(fields) < 4 { + continue + } + + n := Neighbor{} + + if ipStr := fields[1]; len(ipStr) < 2 { + continue + } else if ip := net.ParseIP(ipStr[1 : len(ipStr)-1]); ip == nil { + continue + } else { + n.IP = ip + } + + hwStr := fields[3] + if mac, err := net.ParseMAC(hwStr); err != nil { + continue + } else { + n.MAC = mac + } + + host := fields[0] + if err := netutil.ValidateDomainName(host); err != nil { + log.Debug("parsing arp output: %s", err) + } else { + n.Name = host + } + + ns = append(ns, n) + } + + return ns +} diff --git a/internal/aghnet/arpdb_bsd_test.go b/internal/aghnet/arpdb_bsd_test.go new file mode 100644 index 00000000..3404af69 --- /dev/null +++ b/internal/aghnet/arpdb_bsd_test.go @@ -0,0 +1,31 @@ +//go:build darwin || freebsd +// +build darwin freebsd + +package aghnet + +import ( + "net" +) + +const arpAOutput = ` +invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet] +invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet] +invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet] +hostname.one (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet] +hostname.two (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 1198 seconds [ethernet] +? (::1234) at aa:bb:cc:dd:ee:ff on ej0 expires in 1918 seconds [ethernet] +` + +var wantNeighs = []Neighbor{{ + Name: "hostname.one", + IP: net.IPv4(192, 168, 1, 2), + MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF}, +}, { + Name: "hostname.two", + IP: net.ParseIP("::ffff:ffff"), + MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB}, +}, { + Name: "", + IP: net.ParseIP("::1234"), + MAC: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF}, +}} diff --git a/internal/aghnet/arpdb_linux.go b/internal/aghnet/arpdb_linux.go new file mode 100644 index 00000000..9cc38906 --- /dev/null +++ b/internal/aghnet/arpdb_linux.go @@ -0,0 +1,243 @@ +//go:build linux +// +build linux + +package aghnet + +import ( + "bufio" + "fmt" + "io/fs" + "net" + "strings" + "sync" + + "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/stringutil" +) + +func newARPDB() (arp *arpdbs) { + // Use the common storage among the implementations. + ns := &neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + } + + var parseF parseNeighsFunc + if aghos.IsOpenWrt() { + parseF = parseArpAWrt + } else { + parseF = parseArpA + } + + return newARPDBs( + // Try /proc/net/arp first. + &fsysARPDB{ + ns: ns, + fsys: rootDirFS, + filename: "proc/net/arp", + }, + // Then, try "arp -a -n". + &cmdARPDB{ + parse: parseF, + ns: ns, + cmd: "arp", + // Use -n flag to avoid resolving the hostnames of the neighbors. + // By default ARP attempts to resolve the hostnames via DNS. See + // man 8 arp. + // + // See also https://github.com/AdguardTeam/AdGuardHome/issues/3157. + args: []string{"-a", "-n"}, + }, + // Finally, try "ip neigh". + &cmdARPDB{ + parse: parseIPNeigh, + ns: ns, + cmd: "ip", + args: []string{"neigh"}, + }, + ) +} + +// fsysARPDB accesses the ARP cache file to update the database. +type fsysARPDB struct { + ns *neighs + fsys fs.FS + filename string +} + +// type check +var _ ARPDB = (*fsysARPDB)(nil) + +// Refresh implements the ARPDB interface for *fsysARPDB. +func (arp *fsysARPDB) Refresh() (err error) { + var f fs.File + f, err = arp.fsys.Open(arp.filename) + if err != nil { + return fmt.Errorf("opening %q: %w", arp.filename, err) + } + + sc := bufio.NewScanner(f) + // Skip the header. + if !sc.Scan() { + return nil + } else if err = sc.Err(); err != nil { + return err + } + + ns := make([]Neighbor, 0, arp.ns.len()) + for sc.Scan() { + ln := sc.Text() + fields := stringutil.SplitTrimmed(ln, " ") + if len(fields) != 6 { + continue + } + + n := Neighbor{} + if n.IP = net.ParseIP(fields[0]); n.IP == nil || n.IP.IsUnspecified() { + continue + } else if n.MAC, err = net.ParseMAC(fields[3]); err != nil { + continue + } + + ns = append(ns, n) + } + + arp.ns.reset(ns) + + return nil +} + +// Neighbors implements the ARPDB interface for *fsysARPDB. +func (arp *fsysARPDB) Neighbors() (ns []Neighbor) { + return arp.ns.clone() +} + +// parseArpAWrt parses the output of the "arp -a -n" command on OpenWrt. The +// expected input format: +// +// IP address HW type Flags HW address Mask Device +// 192.168.11.98 0x1 0x2 5a:92:df:a9:7e:28 * wan +// +func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { + if !sc.Scan() { + // Skip the header. + return + } + + ns = make([]Neighbor, 0, lenHint) + for sc.Scan() { + ln := sc.Text() + + fields := strings.Fields(ln) + if len(fields) < 4 { + continue + } + + n := Neighbor{} + + if ip := net.ParseIP(fields[0]); ip == nil || n.IP.IsUnspecified() { + continue + } else { + n.IP = ip + } + + hwStr := fields[3] + if mac, err := net.ParseMAC(hwStr); err != nil { + log.Debug("parsing arp output: %s", err) + + continue + } else { + n.MAC = mac + } + + ns = append(ns, n) + } + + return ns +} + +// parseArpA parses the output of the "arp -a -n" command on Linux. The +// expected input format: +// +// hostname (192.168.1.1) at ab:cd:ef:ab:cd:ef [ether] on enp0s3 +// +func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { + ns = make([]Neighbor, 0, lenHint) + for sc.Scan() { + ln := sc.Text() + + fields := strings.Fields(ln) + if len(fields) < 4 { + continue + } + + n := Neighbor{} + + if ipStr := fields[1]; len(ipStr) < 2 { + continue + } else if ip := net.ParseIP(ipStr[1 : len(ipStr)-1]); ip == nil { + continue + } else { + n.IP = ip + } + + hwStr := fields[3] + if mac, err := net.ParseMAC(hwStr); err != nil { + log.Debug("parsing arp output: %s", err) + + continue + } else { + n.MAC = mac + } + + host := fields[0] + if verr := netutil.ValidateDomainName(host); verr != nil { + log.Debug("parsing arp output: %s", verr) + } else { + n.Name = host + } + + ns = append(ns, n) + } + + return ns +} + +// parseIPNeigh parses the output of the "ip neigh" command on Linux. The +// expected input format: +// +// 192.168.1.1 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef REACHABLE +// +func parseIPNeigh(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { + ns = make([]Neighbor, 0, lenHint) + for sc.Scan() { + ln := sc.Text() + + fields := strings.Fields(ln) + if len(fields) < 5 { + continue + } + + n := Neighbor{} + + if ip := net.ParseIP(fields[0]); ip == nil { + continue + } else { + n.IP = ip + } + + if mac, err := net.ParseMAC(fields[4]); err != nil { + log.Debug("parsing arp output: %s", err) + + continue + } else { + n.MAC = mac + } + + ns = append(ns, n) + } + + return ns +} diff --git a/internal/aghnet/arpdb_linux_test.go b/internal/aghnet/arpdb_linux_test.go new file mode 100644 index 00000000..46d87150 --- /dev/null +++ b/internal/aghnet/arpdb_linux_test.go @@ -0,0 +1,102 @@ +//go:build linux +// +build linux + +package aghnet + +import ( + "net" + "sync" + "testing" + "testing/fstest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const arpAOutputWrt = ` +IP address HW type Flags HW address Mask Device +1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan +1.2.3.4 0x1 0x2 12:34:56:78:910 * wan +192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan +::ffff:ffff 0x1 0x2 ef:cd:ab:ef:cd:ab * wan` + +const arpAOutput = ` +invalid.mac (1.2.3.4) at 12:34:56:78:910 on el0 ifscope [ethernet] +invalid.ip (1.2.3.4.5) at ab:cd:ef:ab:cd:12 on ek0 ifscope [ethernet] +invalid.fmt 1 at 12:cd:ef:ab:cd:ef on er0 ifscope [ethernet] +? (192.168.1.2) at ab:cd:ef:ab:cd:ef on en0 ifscope [ethernet] +? (::ffff:ffff) at ef:cd:ab:ef:cd:ab on em0 expires in 100 seconds [ethernet]` + +const ipNeighOutput = ` +1.2.3.4.5 dev enp0s3 lladdr aa:bb:cc:dd:ee:ff DELAY +1.2.3.4 dev enp0s3 lladdr 12:34:56:78:910 DELAY +192.168.1.2 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef DELAY +::ffff:ffff dev enp0s3 lladdr ef:cd:ab:ef:cd:ab router STALE` + +var wantNeighs = []Neighbor{{ + IP: net.IPv4(192, 168, 1, 2), + MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF}, +}, { + IP: net.ParseIP("::ffff:ffff"), + MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB}, +}} + +func TestFSysARPDB(t *testing.T) { + require.NoError(t, fstest.TestFS(testdata, "proc_net_arp")) + + a := &fsysARPDB{ + ns: &neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + }, + fsys: testdata, + filename: "proc_net_arp", + } + + err := a.Refresh() + require.NoError(t, err) + + ns := a.Neighbors() + assert.Equal(t, wantNeighs, ns) +} + +func TestCmdARPDB_linux(t *testing.T) { + sh := mapShell{ + "arp -a": {err: nil, out: arpAOutputWrt, code: 0}, + "ip neigh": {err: nil, out: ipNeighOutput, code: 0}, + } + substShell(t, sh.RunCmd) + + t.Run("wrt", func(t *testing.T) { + a := &cmdARPDB{ + parse: parseArpAWrt, + cmd: "arp", + args: []string{"-a"}, + ns: &neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + }, + } + + err := a.Refresh() + require.NoError(t, err) + + assert.Equal(t, wantNeighs, a.Neighbors()) + }) + + t.Run("ip_neigh", func(t *testing.T) { + a := &cmdARPDB{ + parse: parseIPNeigh, + cmd: "ip", + args: []string{"neigh"}, + ns: &neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + }, + } + err := a.Refresh() + require.NoError(t, err) + + assert.Equal(t, wantNeighs, a.Neighbors()) + }) +} diff --git a/internal/aghnet/arpdb_openbsd.go b/internal/aghnet/arpdb_openbsd.go new file mode 100644 index 00000000..d5ec5fea --- /dev/null +++ b/internal/aghnet/arpdb_openbsd.go @@ -0,0 +1,73 @@ +//go:build openbsd +// +build openbsd + +package aghnet + +import ( + "bufio" + "net" + "strings" + "sync" + + "github.com/AdguardTeam/golibs/log" +) + +func newARPDB() (arp *cmdARPDB) { + return &cmdARPDB{ + parse: parseArpA, + ns: &neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + }, + cmd: "arp", + // Use -n flag to avoid resolving the hostnames of the neighbors. By + // default ARP attempts to resolve the hostnames via DNS. See man 8 + // arp. + // + // See also https://github.com/AdguardTeam/AdGuardHome/issues/3157. + args: []string{"-a", "-n"}, + } +} + +// parseArpA parses the output of the "arp -a -n" command on OpenBSD. The +// expected input format: +// +// Host Ethernet Address Netif Expire Flags +// 192.168.1.1 ab:cd:ef:ab:cd:ef em0 19m59s +// +func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { + // Skip the header. + if !sc.Scan() { + return nil + } + + ns = make([]Neighbor, 0, lenHint) + for sc.Scan() { + ln := sc.Text() + + fields := strings.Fields(ln) + if len(fields) < 2 { + continue + } + + n := Neighbor{} + + if ip := net.ParseIP(fields[0]); ip == nil { + continue + } else { + n.IP = ip + } + + if mac, err := net.ParseMAC(fields[1]); err != nil { + log.Debug("parsing arp output: %s", err) + + continue + } else { + n.MAC = mac + } + + ns = append(ns, n) + } + + return ns +} diff --git a/internal/aghnet/arpdb_openbsd_test.go b/internal/aghnet/arpdb_openbsd_test.go new file mode 100644 index 00000000..915c17ff --- /dev/null +++ b/internal/aghnet/arpdb_openbsd_test.go @@ -0,0 +1,24 @@ +//go:build openbsd +// +build openbsd + +package aghnet + +import ( + "net" +) + +const arpAOutput = ` +Host Ethernet Address Netif Expire Flags +1.2.3.4.5 aa:bb:cc:dd:ee:ff em0 permanent +1.2.3.4 12:34:56:78:910 em0 permanent +192.168.1.2 ab:cd:ef:ab:cd:ef em0 19m56s +::ffff:ffff ef:cd:ab:ef:cd:ab em0 permanent l +` + +var wantNeighs = []Neighbor{{ + IP: net.IPv4(192, 168, 1, 2), + MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF}, +}, { + IP: net.ParseIP("::ffff:ffff"), + MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB}, +}} diff --git a/internal/aghnet/arpdb_test.go b/internal/aghnet/arpdb_test.go new file mode 100644 index 00000000..d6971448 --- /dev/null +++ b/internal/aghnet/arpdb_test.go @@ -0,0 +1,216 @@ +package aghnet + +import ( + "net" + "sync" + "testing" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewARPDB(t *testing.T) { + var a ARPDB + require.NotPanics(t, func() { a = NewARPDB() }) + + assert.NotNil(t, a) +} + +// TestARPDB is the mock implementation of ARPDB to use in tests. +type TestARPDB struct { + OnRefresh func() (err error) + OnNeighbors func() (ns []Neighbor) +} + +// Refresh implements the ARPDB interface for *TestARPDB. +func (arp *TestARPDB) Refresh() (err error) { + return arp.OnRefresh() +} + +// Neighbors implements the ARPDB interface for *TestARPDB. +func (arp *TestARPDB) Neighbors() (ns []Neighbor) { + return arp.OnNeighbors() +} + +func TestARPDBS(t *testing.T) { + knownIP := net.IP{1, 2, 3, 4} + knownMAC := net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF} + + succRefrCount, failRefrCount := 0, 0 + clnp := func() { + succRefrCount, failRefrCount = 0, 0 + } + + succDB := &TestARPDB{ + OnRefresh: func() (err error) { succRefrCount++; return nil }, + OnNeighbors: func() (ns []Neighbor) { + return []Neighbor{{Name: "abc", IP: knownIP, MAC: knownMAC}} + }, + } + failDB := &TestARPDB{ + OnRefresh: func() (err error) { failRefrCount++; return errors.Error("refresh failed") }, + OnNeighbors: func() (ns []Neighbor) { return nil }, + } + + t.Run("begin_with_success", func(t *testing.T) { + t.Cleanup(clnp) + + a := newARPDBs(succDB, failDB) + err := a.Refresh() + require.NoError(t, err) + + assert.Equal(t, 1, succRefrCount) + assert.Zero(t, failRefrCount) + assert.NotEmpty(t, a.Neighbors()) + }) + + t.Run("begin_with_fail", func(t *testing.T) { + t.Cleanup(clnp) + + a := newARPDBs(failDB, succDB) + err := a.Refresh() + require.NoError(t, err) + + assert.Equal(t, 1, succRefrCount) + assert.Equal(t, 1, failRefrCount) + assert.NotEmpty(t, a.Neighbors()) + }) + + t.Run("fail_only", func(t *testing.T) { + t.Cleanup(clnp) + + wantMsg := `each arpdb failed: 2 errors: "refresh failed", "refresh failed"` + + a := newARPDBs(failDB, failDB) + err := a.Refresh() + require.Error(t, err) + + testutil.AssertErrorMsg(t, wantMsg, err) + + assert.Equal(t, 2, failRefrCount) + assert.Empty(t, a.Neighbors()) + }) + + t.Run("fail_after_success", func(t *testing.T) { + t.Cleanup(clnp) + + shouldFail := false + unstableDB := &TestARPDB{ + OnRefresh: func() (err error) { + if shouldFail { + err = errors.Error("unstable failed") + } + shouldFail = !shouldFail + + return err + }, + OnNeighbors: func() (ns []Neighbor) { + if !shouldFail { + return failDB.OnNeighbors() + } + + return succDB.OnNeighbors() + }, + } + a := newARPDBs(unstableDB, succDB) + + // Unstable ARPDB should refresh successfully. + err := a.Refresh() + require.NoError(t, err) + + assert.Zero(t, succRefrCount) + assert.NotEmpty(t, a.Neighbors()) + + // Unstable ARPDB should fail and the succDB should be used. + err = a.Refresh() + require.NoError(t, err) + + assert.Equal(t, 1, succRefrCount) + assert.NotEmpty(t, a.Neighbors()) + + // Unstable ARPDB should refresh successfully again. + err = a.Refresh() + require.NoError(t, err) + + assert.Equal(t, 1, succRefrCount) + assert.NotEmpty(t, a.Neighbors()) + }) + + t.Run("empty", func(t *testing.T) { + a := newARPDBs() + require.NoError(t, a.Refresh()) + + assert.Empty(t, a.Neighbors()) + }) +} + +func TestCmdARPDB_arpa(t *testing.T) { + a := &cmdARPDB{ + cmd: "cmd", + parse: parseArpA, + ns: &neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + }, + } + + t.Run("arp_a", func(t *testing.T) { + sh := theOnlyCmd("cmd", 0, arpAOutput, nil) + substShell(t, sh.RunCmd) + + err := a.Refresh() + require.NoError(t, err) + + assert.Equal(t, wantNeighs, a.Neighbors()) + }) + + t.Run("runcmd_error", func(t *testing.T) { + sh := theOnlyCmd("cmd", 0, "", errors.Error("can't run")) + substShell(t, sh.RunCmd) + + err := a.Refresh() + testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err) + }) + + t.Run("bad_code", func(t *testing.T) { + sh := theOnlyCmd("cmd", 1, "", nil) + substShell(t, sh.RunCmd) + + err := a.Refresh() + testutil.AssertErrorMsg(t, "cmd arpdb: running command: unexpected exit code 1", err) + }) + + t.Run("empty", func(t *testing.T) { + sh := theOnlyCmd("cmd", 0, "", nil) + substShell(t, sh.RunCmd) + + err := a.Refresh() + require.NoError(t, err) + + assert.Empty(t, a.Neighbors()) + }) +} + +func TestEmptyARPDB(t *testing.T) { + a := EmptyARPDB{} + + t.Run("refresh", func(t *testing.T) { + var err error + require.NotPanics(t, func() { + err = a.Refresh() + }) + + assert.NoError(t, err) + }) + + t.Run("neighbors", func(t *testing.T) { + var ns []Neighbor + require.NotPanics(t, func() { + ns = a.Neighbors() + }) + + assert.Empty(t, ns) + }) +} diff --git a/internal/aghnet/arpdb_windows.go b/internal/aghnet/arpdb_windows.go new file mode 100644 index 00000000..8d5431eb --- /dev/null +++ b/internal/aghnet/arpdb_windows.go @@ -0,0 +1,65 @@ +//go:build windows +// +build windows + +package aghnet + +import ( + "bufio" + "net" + "strings" + "sync" +) + +func newARPDB() (arp *cmdARPDB) { + return &cmdARPDB{ + parse: parseArpA, + ns: &neighs{ + mu: &sync.RWMutex{}, + ns: make([]Neighbor, 0), + }, + cmd: "arp", + args: []string{"/a"}, + } +} + +// parseArpA parses the output of the "arp /a" command on Windows. The expected +// input format (the first line is empty): +// +// +// Interface: 192.168.56.16 --- 0x7 +// Internet Address Physical Address Type +// 192.168.56.1 0a-00-27-00-00-00 dynamic +// 192.168.56.255 ff-ff-ff-ff-ff-ff static +// +func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) { + ns = make([]Neighbor, 0, lenHint) + for sc.Scan() { + ln := sc.Text() + if ln == "" { + continue + } + + fields := strings.Fields(ln) + if len(fields) != 3 { + continue + } + + n := Neighbor{} + + if ip := net.ParseIP(fields[0]); ip == nil { + continue + } else { + n.IP = ip + } + + if mac, err := net.ParseMAC(fields[1]); err != nil { + continue + } else { + n.MAC = mac + } + + ns = append(ns, n) + } + + return ns +} diff --git a/internal/aghnet/arpdb_windows_test.go b/internal/aghnet/arpdb_windows_test.go new file mode 100644 index 00000000..ad88ff8e --- /dev/null +++ b/internal/aghnet/arpdb_windows_test.go @@ -0,0 +1,23 @@ +//go:build windows +// +build windows + +package aghnet + +import ( + "net" +) + +const arpAOutput = ` + +Interface: 192.168.1.1 --- 0x7 + Internet Address Physical Address Type + 192.168.1.2 ab-cd-ef-ab-cd-ef dynamic + ::ffff:ffff ef-cd-ab-ef-cd-ab static` + +var wantNeighs = []Neighbor{{ + IP: net.IPv4(192, 168, 1, 2), + MAC: net.HardwareAddr{0xAB, 0xCD, 0xEF, 0xAB, 0xCD, 0xEF}, +}, { + IP: net.ParseIP("::ffff:ffff"), + MAC: net.HardwareAddr{0xEF, 0xCD, 0xAB, 0xEF, 0xCD, 0xAB}, +}} diff --git a/internal/aghnet/hostgen.go b/internal/aghnet/hostgen.go index d9278515..683c8d9f 100644 --- a/internal/aghnet/hostgen.go +++ b/internal/aghnet/hostgen.go @@ -11,7 +11,7 @@ const ( ipv6HostnameMaxLen = len("ff80-f076-0000-0000-0000-0000-0000-0010") ) -// generateIPv4Hostname generates the hostname for specific IP version. +// generateIPv4Hostname generates the hostname by IP address version 4. func generateIPv4Hostname(ipv4 net.IP) (hostname string) { hnData := make([]byte, 0, ipv4HostnameMaxLen) for i, part := range ipv4 { @@ -24,7 +24,7 @@ func generateIPv4Hostname(ipv4 net.IP) (hostname string) { return string(hnData) } -// generateIPv6Hostname generates the hostname for specific IP version. +// generateIPv6Hostname generates the hostname by IP address version 6. func generateIPv6Hostname(ipv6 net.IP) (hostname string) { hnData := make([]byte, 0, ipv6HostnameMaxLen) for i, partsNum := 0, net.IPv6len/2; i < partsNum; i++ { @@ -51,12 +51,11 @@ func generateIPv6Hostname(ipv6 net.IP) (hostname string) { // // ff80-f076-0000-0000-0000-0000-0000-0010 // +// ip must be either an IPv4 or an IPv6. func GenerateHostname(ip net.IP) (hostname string) { if ipv4 := ip.To4(); ipv4 != nil { return generateIPv4Hostname(ipv4) - } else if ipv6 := ip.To16(); ipv6 != nil { - return generateIPv6Hostname(ipv6) } - return "" + return generateIPv6Hostname(ip) } diff --git a/internal/aghnet/hostgen_test.go b/internal/aghnet/hostgen_test.go index 37121628..d37e556b 100644 --- a/internal/aghnet/hostgen_test.go +++ b/internal/aghnet/hostgen_test.go @@ -8,41 +8,57 @@ import ( ) func TestGenerateHostName(t *testing.T) { - testCases := []struct { - name string - want string - ip net.IP - }{{ - name: "good_ipv4", - want: "127-0-0-1", - ip: net.IP{127, 0, 0, 1}, - }, { - name: "bad_ipv4", - want: "", - ip: net.IP{127, 0, 0, 1, 0}, - }, { - name: "good_ipv6", - want: "fe00-0000-0000-0000-0000-0000-0000-0001", - ip: net.ParseIP("fe00::1"), - }, { - name: "bad_ipv6", - want: "", - ip: net.IP{ - 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, - }, - }, { - name: "nil", - want: "", - ip: nil, - }} + t.Run("valid", func(t *testing.T) { + testCases := []struct { + name string + want string + ip net.IP + }{{ + name: "good_ipv4", + want: "127-0-0-1", + ip: net.IP{127, 0, 0, 1}, + }, { + name: "good_ipv6", + want: "fe00-0000-0000-0000-0000-0000-0000-0001", + ip: net.ParseIP("fe00::1"), + }, { + name: "4to6", + want: "1-2-3-4", + ip: net.ParseIP("::ffff:1.2.3.4"), + }} - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - hostname := GenerateHostname(tc.ip) - assert.Equal(t, tc.want, hostname) - }) - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + hostname := GenerateHostname(tc.ip) + assert.Equal(t, tc.want, hostname) + }) + } + }) + + t.Run("invalid", func(t *testing.T) { + testCases := []struct { + name string + ip net.IP + }{{ + name: "bad_ipv4", + ip: net.IP{127, 0, 0, 1, 0}, + }, { + name: "bad_ipv6", + ip: net.IP{ + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, + }, + }, { + name: "nil", + ip: nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Panics(t, func() { GenerateHostname(tc.ip) }) + }) + } + }) } diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go index d6cf60cf..65c9d3c4 100644 --- a/internal/aghnet/hostscontainer.go +++ b/internal/aghnet/hostscontainer.go @@ -368,8 +368,8 @@ func (hp *hostsParser) addPairs(ip net.IP, hosts []string) { } } -// writeRules writes the actual rule for the qtype and the PTR for the -// host-ip pair into internal builders. +// writeRules writes the actual rule for the qtype and the PTR for the host-ip +// pair into internal builders. func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) { arpa, err := netutil.IPToReversedAddr(ip) if err != nil { diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index ecb70fa8..268380bd 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -2,19 +2,31 @@ package aghnet import ( + "bytes" "encoding/json" "fmt" "io" "net" - "os/exec" - "strings" "syscall" + "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" ) +// Variables and functions to substitute in tests. +var ( + // aghosRunCommand is the function to run shell commands. + aghosRunCommand = aghos.RunCommand + + // netInterfaces is the function to get the available network interfaces. + netInterfaceAddrs = net.InterfaceAddrs + + // rootDirFS is the filesystem pointing to the root directory. + rootDirFS = aghos.RootDirFS() +) + // ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about // the IP being static is available. const ErrNoStaticIPInfo errors.Error = "no information about static ip" @@ -32,39 +44,29 @@ func IfaceSetStaticIP(ifaceName string) (err error) { } // GatewayIP returns IP address of interface's gateway. -func GatewayIP(ifaceName string) net.IP { - cmd := exec.Command("ip", "route", "show", "dev", ifaceName) - log.Tracef("executing %s %v", cmd.Path, cmd.Args) - d, err := cmd.Output() - if err != nil || cmd.ProcessState.ExitCode() != 0 { +// +// TODO(e.burkov): Investigate if the gateway address may be fetched in another +// way since not every machine has the software installed. +func GatewayIP(ifaceName string) (ip net.IP) { + code, out, err := aghosRunCommand("ip", "route", "show", "dev", ifaceName) + if err != nil { + log.Debug("%s", err) + + return nil + } else if code != 0 { + log.Debug("fetching gateway ip: unexpected exit code: %d", code) + return nil } - fields := strings.Fields(string(d)) + fields := bytes.Fields(out) // The meaningful "ip route" command output should contain the word // "default" at first field and default gateway IP address at third field. - if len(fields) < 3 || fields[0] != "default" { + if len(fields) < 3 || string(fields[0]) != "default" { return nil } - return net.ParseIP(fields[2]) -} - -// CanBindPort checks if we can bind to the given port. -func CanBindPort(port int) (can bool, err error) { - var addr *net.TCPAddr - addr, err = net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port)) - if err != nil { - return false, err - } - - var listener *net.TCPListener - listener, err = net.ListenTCP("tcp", addr) - if err != nil { - return false, err - } - _ = listener.Close() - return true, nil + return net.ParseIP(string(fields[2])) } // CanBindPrivilegedPorts checks if current process can bind to privileged @@ -99,19 +101,19 @@ func (iface NetInterface) MarshalJSON() ([]byte, error) { }) } -// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and WEB only -// we do not return link-local addresses here -func GetValidNetInterfacesForWeb() ([]*NetInterface, error) { +// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and +// WEB only we do not return link-local addresses here. +// +// TODO(e.burkov): Can't properly test the function since it's nontrivial to +// substitute net.Interface.Addrs and the net.InterfaceAddrs can't be used. +func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) { ifaces, err := net.Interfaces() if err != nil { return nil, fmt.Errorf("couldn't get interfaces: %w", err) - } - if len(ifaces) == 0 { + } else if len(ifaces) == 0 { return nil, errors.Error("couldn't find any legible interface") } - var netInterfaces []*NetInterface - for _, iface := range ifaces { var addrs []net.Addr addrs, err = iface.Addrs() @@ -131,26 +133,30 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) { ipNet, ok := addr.(*net.IPNet) if !ok { // Should be net.IPNet, this is weird. - return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr) + return nil, fmt.Errorf("got %s that is not net.IPNet, it is %T", addr, addr) } + // Ignore link-local. if ipNet.IP.IsLinkLocalUnicast() { continue } + netIface.Addresses = append(netIface.Addresses, ipNet.IP) netIface.Subnets = append(netIface.Subnets, ipNet) } // Discard interfaces with no addresses. if len(netIface.Addresses) != 0 { - netInterfaces = append(netInterfaces, netIface) + netIfaces = append(netIfaces, netIface) } } - return netInterfaces, nil + return netIfaces, nil } // GetInterfaceByIP returns the name of interface containing provided ip. +// +// TODO(e.burkov): See TODO on GetValidInterfacesForWeb. func GetInterfaceByIP(ip net.IP) string { ifaces, err := GetValidNetInterfacesForWeb() if err != nil { @@ -170,6 +176,8 @@ func GetInterfaceByIP(ip net.IP) string { // GetSubnet returns pointer to net.IPNet for the specified interface or nil if // the search fails. +// +// TODO(e.burkov): See TODO on GetValidInterfacesForWeb. func GetSubnet(ifaceName string) *net.IPNet { netIfaces, err := GetValidNetInterfacesForWeb() if err != nil { @@ -220,29 +228,21 @@ func IsAddrInUse(err error) (ok bool) { // CollectAllIfacesAddrs returns the slice of all network interfaces IP // addresses without port number. func CollectAllIfacesAddrs() (addrs []string, err error) { - var ifaces []net.Interface - ifaces, err = net.Interfaces() + var ifaceAddrs []net.Addr + ifaceAddrs, err = netInterfaceAddrs() if err != nil { - return nil, fmt.Errorf("getting network interfaces: %w", err) + return nil, fmt.Errorf("getting interfaces addresses: %w", err) } - for _, iface := range ifaces { - var ifaceAddrs []net.Addr - ifaceAddrs, err = iface.Addrs() + for _, addr := range ifaceAddrs { + cidr := addr.String() + var ip net.IP + ip, _, err = net.ParseCIDR(cidr) if err != nil { - return nil, fmt.Errorf("getting addresses for %q: %w", iface.Name, err) + return nil, fmt.Errorf("parsing cidr: %w", err) } - for _, addr := range ifaceAddrs { - cidr := addr.String() - var ip net.IP - ip, _, err = net.ParseCIDR(cidr) - if err != nil { - return nil, fmt.Errorf("parsing cidr: %w", err) - } - - addrs = append(addrs, ip.String()) - } + addrs = append(addrs, ip.String()) } return addrs, nil diff --git a/internal/aghnet/net_darwin.go b/internal/aghnet/net_darwin.go index b0d11c52..296a18b0 100644 --- a/internal/aghnet/net_darwin.go +++ b/internal/aghnet/net_darwin.go @@ -4,10 +4,11 @@ package aghnet import ( + "bufio" + "bytes" "fmt" - "os" + "io" "regexp" - "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" @@ -23,7 +24,7 @@ type hardwarePortInfo struct { static bool } -func ifaceHasStaticIP(ifaceName string) (bool, error) { +func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { portInfo, err := getCurrentHardwarePortInfo(ifaceName) if err != nil { return false, err @@ -32,9 +33,10 @@ func ifaceHasStaticIP(ifaceName string) (bool, error) { return portInfo.static, nil } -// getCurrentHardwarePortInfo gets information for the specified network interface. +// getCurrentHardwarePortInfo gets information for the specified network +// interface. func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) { - // First of all we should find hardware port name + // First of all we should find hardware port name. m := getNetworkSetupHardwareReports() hardwarePort, ok := m[ifaceName] if !ok { @@ -44,6 +46,10 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) { return getHardwarePortInfo(hardwarePort) } +// hardwareReportsReg is the regular expression matching the lines of +// networksetup command output lines containing the interface information. +var hardwareReportsReg = regexp.MustCompile("Hardware Port: (.*?)\nDevice: (.*?)\n") + // getNetworkSetupHardwareReports parses the output of the `networksetup // -listallhardwareports` command it returns a map where the key is the // interface name, and the value is the "hardware port" returns nil if it fails @@ -52,54 +58,44 @@ func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) { // TODO(e.burkov): There should be more proper approach than parsing the // command output. For example, see // https://developer.apple.com/documentation/systemconfiguration. -func getNetworkSetupHardwareReports() map[string]string { - _, out, err := aghos.RunCommand("networksetup", "-listallhardwareports") +func getNetworkSetupHardwareReports() (reports map[string]string) { + _, out, err := aghosRunCommand("networksetup", "-listallhardwareports") if err != nil { return nil } - re, err := regexp.Compile("Hardware Port: (.*?)\nDevice: (.*?)\n") - if err != nil { - return nil + reports = make(map[string]string) + + matches := hardwareReportsReg.FindAllSubmatch(out, -1) + for _, m := range matches { + reports[string(m[2])] = string(m[1]) } - m := make(map[string]string) - - matches := re.FindAllStringSubmatch(out, -1) - for i := range matches { - port := matches[i][1] - device := matches[i][2] - m[device] = port - } - - return m + return reports } -func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) { - h := hardwarePortInfo{} +// hardwarePortReg is the regular expression matching the lines of networksetup +// command output lines containing the port information. +var hardwarePortReg = regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n") - _, out, err := aghos.RunCommand("networksetup", "-getinfo", hardwarePort) +func getHardwarePortInfo(hardwarePort string) (h hardwarePortInfo, err error) { + _, out, err := aghosRunCommand("networksetup", "-getinfo", hardwarePort) if err != nil { return h, err } - re := regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n") - - match := re.FindStringSubmatch(out) - if len(match) == 0 { + match := hardwarePortReg.FindSubmatch(out) + if len(match) != 4 { return h, errors.Error("could not find hardware port info") } - h.name = hardwarePort - h.ip = match[1] - h.subnet = match[2] - h.gatewayIP = match[3] - - if strings.Index(out, "Manual Configuration") == 0 { - h.static = true - } - - return h, nil + return hardwarePortInfo{ + name: hardwarePort, + ip: string(match[1]), + subnet: string(match[2]), + gatewayIP: string(match[3]), + static: bytes.Index(out, []byte("Manual Configuration")) == 0, + }, nil } func ifaceSetStaticIP(ifaceName string) (err error) { @@ -109,7 +105,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) { } if portInfo.static { - return errors.Error("IP address is already static") + return errors.Error("ip address is already static") } dnsAddrs, err := getEtcResolvConfServers() @@ -117,50 +113,62 @@ func ifaceSetStaticIP(ifaceName string) (err error) { return err } - args := make([]string, 0) - args = append(args, "-setdnsservers", portInfo.name) - args = append(args, dnsAddrs...) + args := append([]string{"-setdnsservers", portInfo.name}, dnsAddrs...) // Setting DNS servers is necessary when configuring a static IP - code, _, err := aghos.RunCommand("networksetup", args...) + code, _, err := aghosRunCommand("networksetup", args...) if err != nil { return err - } - if code != 0 { + } else if code != 0 { return fmt.Errorf("failed to set DNS servers, code=%d", code) } // Actually configures hardware port to have static IP - code, _, err = aghos.RunCommand("networksetup", "-setmanual", - portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP) + code, _, err = aghosRunCommand( + "networksetup", + "-setmanual", + portInfo.name, + portInfo.ip, + portInfo.subnet, + portInfo.gatewayIP, + ) if err != nil { return err - } - if code != 0 { + } else if code != 0 { return fmt.Errorf("failed to set DNS servers, code=%d", code) } return nil } +// etcResolvConfReg is the regular expression matching the lines of resolv.conf +// file containing a name server information. +var etcResolvConfReg = regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)") + // getEtcResolvConfServers returns a list of nameservers configured in // /etc/resolv.conf. -func getEtcResolvConfServers() ([]string, error) { - body, err := os.ReadFile("/etc/resolv.conf") +func getEtcResolvConfServers() (addrs []string, err error) { + const filename = "etc/resolv.conf" + + _, err = aghos.FileWalker(func(r io.Reader) (_ []string, _ bool, err error) { + sc := bufio.NewScanner(r) + for sc.Scan() { + matches := etcResolvConfReg.FindAllStringSubmatch(sc.Text(), -1) + if len(matches) == 0 { + continue + } + + for _, m := range matches { + addrs = append(addrs, m[1]) + } + } + + return nil, false, sc.Err() + }).Walk(rootDirFS, filename) if err != nil { - return nil, err - } - - re := regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)") - - matches := re.FindAllStringSubmatch(string(body), -1) - if len(matches) == 0 { - return nil, errors.Error("found no DNS servers in /etc/resolv.conf") - } - - addrs := make([]string, 0) - for i := range matches { - addrs = append(addrs, matches[i][1]) + return nil, fmt.Errorf("parsing etc/resolv.conf file: %w", err) + } else if len(addrs) == 0 { + return nil, fmt.Errorf("found no dns servers in %s", filename) } return addrs, nil diff --git a/internal/aghnet/net_darwin_test.go b/internal/aghnet/net_darwin_test.go new file mode 100644 index 00000000..905600d5 --- /dev/null +++ b/internal/aghnet/net_darwin_test.go @@ -0,0 +1,261 @@ +package aghnet + +import ( + "io/fs" + "testing" + "testing/fstest" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" +) + +func TestIfaceHasStaticIP(t *testing.T) { + testCases := []struct { + name string + shell mapShell + ifaceName string + wantHas assert.BoolAssertionFunc + wantErrMsg string + }{{ + name: "success", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + }, + ifaceName: "en0", + wantHas: assert.False, + wantErrMsg: ``, + }, { + name: "success_static", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "Manual Configuration\nIP address: 1.2.3.4\n" + + "Subnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + }, + ifaceName: "en0", + wantHas: assert.True, + wantErrMsg: ``, + }, { + name: "reports_error", + shell: theOnlyCmd( + "networksetup -listallhardwareports", + 0, + "", + errors.Error("can't list"), + ), + ifaceName: "en0", + wantHas: assert.False, + wantErrMsg: `could not find hardware port for en0`, + }, { + name: "port_error", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: errors.Error("can't get"), + out: ``, + code: 0, + }, + }, + ifaceName: "en0", + wantHas: assert.False, + wantErrMsg: `can't get`, + }, { + name: "port_bad_output", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "nothing meaningful", + code: 0, + }, + }, + ifaceName: "en0", + wantHas: assert.False, + wantErrMsg: `could not find hardware port info`, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + substShell(t, tc.shell.RunCmd) + + has, err := IfaceHasStaticIP(tc.ifaceName) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + + tc.wantHas(t, has) + }) + } +} + +func TestIfaceSetStaticIP(t *testing.T) { + succFsys := fstest.MapFS{ + "etc/resolv.conf": &fstest.MapFile{ + Data: []byte(`nameserver 1.1.1.1`), + }, + } + panicFsys := &aghtest.FS{ + OnOpen: func(name string) (fs.File, error) { panic("not implemented") }, + } + + testCases := []struct { + name string + shell mapShell + fsys fs.FS + wantErrMsg string + }{{ + name: "success", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + "networksetup -setdnsservers hwport 1.1.1.1": { + err: nil, + out: "", + code: 0, + }, + "networksetup -setmanual hwport 1.2.3.4 255.255.255.0 1.2.3.1": { + err: nil, + out: "", + code: 0, + }, + }, + fsys: succFsys, + wantErrMsg: ``, + }, { + name: "static_already", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "Manual Configuration\nIP address: 1.2.3.4\n" + + "Subnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + }, + fsys: panicFsys, + wantErrMsg: `ip address is already static`, + }, { + name: "reports_error", + shell: theOnlyCmd( + "networksetup -listallhardwareports", + 0, + "", + errors.Error("can't list"), + ), + fsys: panicFsys, + wantErrMsg: `could not find hardware port for en0`, + }, { + name: "resolv_conf_error", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + }, + fsys: fstest.MapFS{ + "etc/resolv.conf": &fstest.MapFile{ + Data: []byte("this resolv.conf is invalid"), + }, + }, + wantErrMsg: `found no dns servers in etc/resolv.conf`, + }, { + name: "set_dns_error", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + "networksetup -setdnsservers hwport 1.1.1.1": { + err: errors.Error("can't set"), + out: "", + code: 0, + }, + }, + fsys: succFsys, + wantErrMsg: `can't set`, + }, { + name: "set_manual_error", + shell: mapShell{ + "networksetup -listallhardwareports": { + err: nil, + out: "Hardware Port: hwport\nDevice: en0\n", + code: 0, + }, + "networksetup -getinfo hwport": { + err: nil, + out: "IP address: 1.2.3.4\nSubnet mask: 255.255.255.0\nRouter: 1.2.3.1\n", + code: 0, + }, + "networksetup -setdnsservers hwport 1.1.1.1": { + err: nil, + out: "", + code: 0, + }, + "networksetup -setmanual hwport 1.2.3.4 255.255.255.0 1.2.3.1": { + err: errors.Error("can't set"), + out: "", + code: 0, + }, + }, + fsys: succFsys, + wantErrMsg: `can't set`, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + substShell(t, tc.shell.RunCmd) + substRootDirFS(t, tc.fsys) + + err := IfaceSetStaticIP("en0") + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + }) + } +} diff --git a/internal/aghnet/net_freebsd.go b/internal/aghnet/net_freebsd.go index ea99b6fc..85d40184 100644 --- a/internal/aghnet/net_freebsd.go +++ b/internal/aghnet/net_freebsd.go @@ -18,7 +18,7 @@ func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { walker := aghos.FileWalker(interfaceName(ifaceName).rcConfStaticConfig) - return walker.Walk(aghos.RootDirFS(), rcConfFilename) + return walker.Walk(rootDirFS, rcConfFilename) } // rcConfStaticConfig checks if the interface is configured by /etc/rc.conf to diff --git a/internal/aghnet/net_freebsd_test.go b/internal/aghnet/net_freebsd_test.go index 3781b154..2c758360 100644 --- a/internal/aghnet/net_freebsd_test.go +++ b/internal/aghnet/net_freebsd_test.go @@ -4,56 +4,74 @@ package aghnet import ( - "strings" + "io/fs" "testing" + "testing/fstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestRcConfStaticConfig(t *testing.T) { - const iface interfaceName = `em0` - const nl = "\n" +func TestIfaceHasStaticIP(t *testing.T) { + const ( + ifaceName = `em0` + rcConf = "etc/rc.conf" + ) testCases := []struct { - name string - rcconfData string - wantCont bool + name string + rootFsys fs.FS + wantHas assert.BoolAssertionFunc }{{ - name: "simple", - rcconfData: `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, - wantCont: false, + name: "simple", + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{ + Data: []byte(`ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl), + }}, + wantHas: assert.True, }, { - name: "case_insensitiveness", - rcconfData: `ifconfig_em0="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl, - wantCont: false, + name: "case_insensitiveness", + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{ + Data: []byte(`ifconfig_` + ifaceName + `="InEt 127.0.0.253 NeTmAsK 0xffffffff"` + nl), + }}, + wantHas: assert.True, }, { name: "comments_and_trash", - rcconfData: `# comment 1` + nl + - `` + nl + - `# comment 2` + nl + - `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, - wantCont: false, + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{ + Data: []byte(`# comment 1` + nl + + `` + nl + + `# comment 2` + nl + + `ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl, + ), + }}, + wantHas: assert.True, }, { name: "aliases", - rcconfData: `ifconfig_em0_alias="inet 127.0.0.1/24"` + nl + - `ifconfig_em0="inet 127.0.0.253 netmask 0xffffffff"` + nl, - wantCont: false, + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{ + Data: []byte(`ifconfig_` + ifaceName + `_alias="inet 127.0.0.1/24"` + nl + + `ifconfig_` + ifaceName + `="inet 127.0.0.253 netmask 0xffffffff"` + nl, + ), + }}, + wantHas: assert.True, }, { name: "incorrect_config", - rcconfData: `ifconfig_em0="inet6 127.0.0.253 netmask 0xffffffff"` + nl + - `ifconfig_em0="inet 256.256.256.256 netmask 0xffffffff"` + nl + - `ifconfig_em0=""` + nl, - wantCont: true, + rootFsys: fstest.MapFS{rcConf: &fstest.MapFile{ + Data: []byte( + `ifconfig_` + ifaceName + `="inet6 127.0.0.253 netmask 0xffffffff"` + nl + + `ifconfig_` + ifaceName + `="inet 256.256.256.256 netmask 0xffffffff"` + nl + + `ifconfig_` + ifaceName + `=""` + nl, + ), + }}, + wantHas: assert.False, }} for _, tc := range testCases { - r := strings.NewReader(tc.rcconfData) t.Run(tc.name, func(t *testing.T) { - _, cont, err := iface.rcConfStaticConfig(r) + substRootDirFS(t, tc.rootFsys) + + has, err := IfaceHasStaticIP(ifaceName) require.NoError(t, err) - assert.Equal(t, tc.wantCont, cont) + tc.wantHas(t, has) }) } } diff --git a/internal/aghnet/net_linux.go b/internal/aghnet/net_linux.go index 93414165..148abe1f 100644 --- a/internal/aghnet/net_linux.go +++ b/internal/aghnet/net_linux.go @@ -13,16 +13,33 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/stringutil" "github.com/google/renameio/maybe" "golang.org/x/sys/unix" ) +// dhcpсdConf is the name of /etc/dhcpcd.conf file in the root filesystem. +const dhcpcdConf = "etc/dhcpcd.conf" + +func canBindPrivilegedPorts() (can bool, err error) { + cnbs, err := unix.PrctlRetInt( + unix.PR_CAP_AMBIENT, + unix.PR_CAP_AMBIENT_IS_SET, + unix.CAP_NET_BIND_SERVICE, + 0, + 0, + ) + // Don't check the error because it's always nil on Linux. + adm, _ := aghos.HaveAdminRights() + + return cnbs == 1 || adm, err +} + // dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to // have a static IP. func (n interfaceName) dhcpcdStaticConfig(r io.Reader) (subsources []string, cont bool, err error) { s := bufio.NewScanner(r) - ifaceFound := findIfaceLine(s, string(n)) - if !ifaceFound { + if !findIfaceLine(s, string(n)) { return nil, true, s.Err() } @@ -61,9 +78,9 @@ func (n interfaceName) ifacesStaticConfig(r io.Reader) (sub []string, cont bool, fields := strings.Fields(line) fieldsNum := len(fields) - // Man page interfaces(5) declares that interface definition - // should consist of the key word "iface" followed by interface - // name, and method at fourth field. + // Man page interfaces(5) declares that interface definition should + // consist of the key word "iface" followed by interface name, and + // method at fourth field. if fieldsNum >= 4 && fields[0] == "iface" && fields[1] == string(n) && fields[3] == "static" { return nil, false, nil @@ -78,10 +95,10 @@ func (n interfaceName) ifacesStaticConfig(r io.Reader) (sub []string, cont bool, } func ifaceHasStaticIP(ifaceName string) (has bool, err error) { - // TODO(a.garipov): Currently, this function returns the first - // definitive result. So if /etc/dhcpcd.conf has a static IP while - // /etc/network/interfaces doesn't, it will return true. Perhaps this - // is not the most desirable behavior. + // TODO(a.garipov): Currently, this function returns the first definitive + // result. So if /etc/dhcpcd.conf has and /etc/network/interfaces has no + // static IP configuration, it will return true. Perhaps this is not the + // most desirable behavior. iface := interfaceName(ifaceName) @@ -90,17 +107,15 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) { filename string }{{ FileWalker: iface.dhcpcdStaticConfig, - filename: "etc/dhcpcd.conf", + filename: dhcpcdConf, }, { FileWalker: iface.ifacesStaticConfig, filename: "etc/network/interfaces", }} { - has, err = pair.Walk(aghos.RootDirFS(), pair.filename) + has, err = pair.Walk(rootDirFS, pair.filename) if err != nil { return false, err - } - - if has { + } else if has { return true, nil } } @@ -108,14 +123,6 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) { return false, ErrNoStaticIPInfo } -func canBindPrivilegedPorts() (can bool, err error) { - cnbs, err := unix.PrctlRetInt(unix.PR_CAP_AMBIENT, unix.PR_CAP_AMBIENT_IS_SET, unix.CAP_NET_BIND_SERVICE, 0, 0) - // Don't check the error because it's always nil on Linux. - adm, _ := aghos.HaveAdminRights() - - return cnbs == 1 || adm, err -} - // findIfaceLine scans s until it finds the line that declares an interface with // the given name. If findIfaceLine can't find the line, it returns false. func findIfaceLine(s *bufio.Scanner, name string) (ok bool) { @@ -131,23 +138,23 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) { } // ifaceSetStaticIP configures the system to retain its current IP on the -// interface through dhcpdc.conf. +// interface through dhcpcd.conf. func ifaceSetStaticIP(ifaceName string) (err error) { ipNet := GetSubnet(ifaceName) if ipNet.IP == nil { return errors.Error("can't get IP address") } - gatewayIP := GatewayIP(ifaceName) - add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP, ipNet.IP) - - body, err := os.ReadFile("/etc/dhcpcd.conf") + body, err := os.ReadFile(dhcpcdConf) if err != nil && !errors.Is(err, os.ErrNotExist) { return err } + gatewayIP := GatewayIP(ifaceName) + add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP) + body = append(body, []byte(add)...) - err = maybe.WriteFile("/etc/dhcpcd.conf", body, 0o644) + err = maybe.WriteFile(dhcpcdConf, body, 0o644) if err != nil { return fmt.Errorf("writing conf: %w", err) } @@ -157,22 +164,24 @@ func ifaceSetStaticIP(ifaceName string) (err error) { // dhcpcdConfIface returns configuration lines for the dhcpdc.conf files that // configure the interface to have a static IP. -func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gatewayIP, dnsIP net.IP) (conf string) { - var body []byte - - add := fmt.Sprintf( - "\n# %[1]s added by AdGuard Home.\ninterface %[1]s\nstatic ip_address=%s\n", +func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf string) { + b := &strings.Builder{} + stringutil.WriteToBuilder( + b, + "\n# ", ifaceName, - ipNet) - body = append(body, []byte(add)...) + " added by AdGuard Home.\ninterface ", + ifaceName, + "\nstatic ip_address=", + ipNet.String(), + "\n", + ) - if gatewayIP != nil { - add = fmt.Sprintf("static routers=%s\n", gatewayIP) - body = append(body, []byte(add)...) + if gwIP != nil { + stringutil.WriteToBuilder(b, "static routers=", gwIP.String(), "\n") } - add = fmt.Sprintf("static domain_name_servers=%s\n\n", dnsIP) - body = append(body, []byte(add)...) + stringutil.WriteToBuilder(b, "static domain_name_servers=", ipNet.IP.String(), "\n\n") - return string(body) + return b.String() } diff --git a/internal/aghnet/net_linux_test.go b/internal/aghnet/net_linux_test.go index bf2cecfe..838802ff 100644 --- a/internal/aghnet/net_linux_test.go +++ b/internal/aghnet/net_linux_test.go @@ -4,152 +4,124 @@ package aghnet import ( - "bytes" - "net" + "io/fs" "testing" + "testing/fstest" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestDHCPCDStaticConfig(t *testing.T) { - const iface interfaceName = `wlan0` +func TestHasStaticIP(t *testing.T) { + const ifaceName = "wlan0" + + const ( + dhcpcd = "etc/dhcpcd.conf" + netifaces = "etc/network/interfaces" + ) testCases := []struct { - name string - data []byte - wantCont bool - }{{ - name: "has_not", - data: []byte(`#comment` + nl + - `# comment` + nl + - `interface eth0` + nl + - `static ip_address=192.168.0.1/24` + nl + - `# interface ` + iface + nl + - `static ip_address=192.168.1.1/24` + nl + - `# comment` + nl, - ), - wantCont: true, - }, { - name: "has", - data: []byte(`#comment` + nl + - `# comment` + nl + - `interface eth0` + nl + - `static ip_address=192.168.0.1/24` + nl + - `# interface ` + iface + nl + - `static ip_address=192.168.1.1/24` + nl + - `# comment` + nl + - `interface ` + iface + nl + - `# comment` + nl + - `static ip_address=192.168.2.1/24` + nl, - ), - wantCont: false, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - r := bytes.NewReader(tc.data) - _, cont, err := iface.dhcpcdStaticConfig(r) - require.NoError(t, err) - - assert.Equal(t, tc.wantCont, cont) - }) - } -} - -func TestIfacesStaticConfig(t *testing.T) { - const iface interfaceName = `enp0s3` - - testCases := []struct { - name string - data []byte - wantCont bool - wantPatterns []string - }{{ - name: "has_not", - data: []byte(`allow-hotplug ` + iface + nl + - `#iface enp0s3 inet static` + nl + - `# address 192.168.0.200` + nl + - `# netmask 255.255.255.0` + nl + - `# gateway 192.168.0.1` + nl + - `iface ` + iface + ` inet dhcp` + nl, - ), - wantCont: true, - wantPatterns: []string{}, - }, { - name: "has", - data: []byte(`allow-hotplug ` + iface + nl + - `iface ` + iface + ` inet static` + nl + - ` address 192.168.0.200` + nl + - ` netmask 255.255.255.0` + nl + - ` gateway 192.168.0.1` + nl + - `#iface ` + iface + ` inet dhcp` + nl, - ), - wantCont: false, - wantPatterns: []string{}, - }, { - name: "return_patterns", - data: []byte(`source hello` + nl + - `source world` + nl + - `#iface ` + iface + ` inet static` + nl, - ), - wantCont: true, - wantPatterns: []string{"hello", "world"}, - }, { - // This one tests if the first found valid interface prevents - // checking files under the `source` directive. - name: "ignore_patterns", - data: []byte(`source hello` + nl + - `source world` + nl + - `iface ` + iface + ` inet static` + nl, - ), - wantCont: false, - wantPatterns: []string{}, - }} - - for _, tc := range testCases { - r := bytes.NewReader(tc.data) - t.Run(tc.name, func(t *testing.T) { - patterns, has, err := iface.ifacesStaticConfig(r) - require.NoError(t, err) - - assert.Equal(t, tc.wantCont, has) - assert.ElementsMatch(t, tc.wantPatterns, patterns) - }) - } -} - -func TestSetStaticIPdhcpcdConf(t *testing.T) { - testCases := []struct { + rootFsys fs.FS name string - dhcpcdConf string - routers net.IP + wantHas assert.BoolAssertionFunc + wantErrMsg string }{{ - name: "with_gateway", - dhcpcdConf: nl + `# wlan0 added by AdGuard Home.` + nl + - `interface wlan0` + nl + - `static ip_address=192.168.0.2/24` + nl + - `static routers=192.168.0.1` + nl + - `static domain_name_servers=192.168.0.2` + nl + nl, - routers: net.IP{192, 168, 0, 1}, + rootFsys: fstest.MapFS{ + dhcpcd: &fstest.MapFile{ + Data: []byte(`#comment` + nl + + `# comment` + nl + + `interface eth0` + nl + + `static ip_address=192.168.0.1/24` + nl + + `# interface ` + ifaceName + nl + + `static ip_address=192.168.1.1/24` + nl + + `# comment` + nl, + ), + }, + }, + name: "dhcpcd_has_not", + wantHas: assert.False, + wantErrMsg: `no information about static ip`, }, { - name: "without_gateway", - dhcpcdConf: nl + `# wlan0 added by AdGuard Home.` + nl + - `interface wlan0` + nl + - `static ip_address=192.168.0.2/24` + nl + - `static domain_name_servers=192.168.0.2` + nl + nl, - routers: nil, + rootFsys: fstest.MapFS{ + dhcpcd: &fstest.MapFile{ + Data: []byte(`#comment` + nl + + `# comment` + nl + + `interface ` + ifaceName + nl + + `static ip_address=192.168.0.1/24` + nl + + `# interface ` + ifaceName + nl + + `static ip_address=192.168.1.1/24` + nl + + `# comment` + nl, + ), + }, + }, + name: "dhcpcd_has", + wantHas: assert.True, + wantErrMsg: ``, + }, { + rootFsys: fstest.MapFS{ + netifaces: &fstest.MapFile{ + Data: []byte(`allow-hotplug ` + ifaceName + nl + + `#iface enp0s3 inet static` + nl + + `# address 192.168.0.200` + nl + + `# netmask 255.255.255.0` + nl + + `# gateway 192.168.0.1` + nl + + `iface ` + ifaceName + ` inet dhcp` + nl, + ), + }, + }, + name: "netifaces_has_not", + wantHas: assert.False, + wantErrMsg: `no information about static ip`, + }, { + rootFsys: fstest.MapFS{ + netifaces: &fstest.MapFile{ + Data: []byte(`allow-hotplug ` + ifaceName + nl + + `iface ` + ifaceName + ` inet static` + nl + + ` address 192.168.0.200` + nl + + ` netmask 255.255.255.0` + nl + + ` gateway 192.168.0.1` + nl + + `#iface ` + ifaceName + ` inet dhcp` + nl, + ), + }, + }, + name: "netifaces_has", + wantHas: assert.True, + wantErrMsg: ``, + }, { + rootFsys: fstest.MapFS{ + netifaces: &fstest.MapFile{ + Data: []byte(`source hello` + nl + + `#iface ` + ifaceName + ` inet static` + nl, + ), + }, + "hello": &fstest.MapFile{ + Data: []byte(`iface ` + ifaceName + ` inet static` + nl), + }, + }, + name: "netifaces_another_file", + wantHas: assert.True, + wantErrMsg: ``, + }, { + rootFsys: fstest.MapFS{ + netifaces: &fstest.MapFile{ + Data: []byte(`source hello` + nl + + `iface ` + ifaceName + ` inet static` + nl, + ), + }, + }, + name: "netifaces_ignore_another", + wantHas: assert.True, + wantErrMsg: ``, }} - ipNet := &net.IPNet{ - IP: net.IP{192, 168, 0, 2}, - Mask: net.IPMask{255, 255, 255, 0}, - } - for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - s := dhcpcdConfIface("wlan0", ipNet, tc.routers, net.IP{192, 168, 0, 2}) - assert.Equal(t, tc.dhcpcdConf, s) + substRootDirFS(t, tc.rootFsys) + + has, err := IfaceHasStaticIP(ifaceName) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + + tc.wantHas(t, has) }) } } diff --git a/internal/aghnet/net_openbsd.go b/internal/aghnet/net_openbsd.go index 51ef5d44..cf911105 100644 --- a/internal/aghnet/net_openbsd.go +++ b/internal/aghnet/net_openbsd.go @@ -16,7 +16,7 @@ import ( func ifaceHasStaticIP(ifaceName string) (ok bool, err error) { filename := fmt.Sprintf("etc/hostname.%s", ifaceName) - return aghos.FileWalker(hostnameIfStaticConfig).Walk(aghos.RootDirFS(), filename) + return aghos.FileWalker(hostnameIfStaticConfig).Walk(rootDirFS, filename) } // hostnameIfStaticConfig checks if the interface is configured by diff --git a/internal/aghnet/net_openbsd_test.go b/internal/aghnet/net_openbsd_test.go index e157d93a..356799b7 100644 --- a/internal/aghnet/net_openbsd_test.go +++ b/internal/aghnet/net_openbsd_test.go @@ -4,49 +4,69 @@ package aghnet import ( - "strings" + "fmt" + "io/fs" "testing" + "testing/fstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestHostnameIfStaticConfig(t *testing.T) { - const nl = "\n" +func TestIfaceHasStaticIP(t *testing.T) { + const ifaceName = "em0" + + confFile := fmt.Sprintf("etc/hostname.%s", ifaceName) testCases := []struct { - name string - rcconfData string - wantHas bool + name string + rootFsys fs.FS + wantHas assert.BoolAssertionFunc }{{ - name: "simple", - rcconfData: `inet 127.0.0.253` + nl, - wantHas: true, + name: "simple", + rootFsys: fstest.MapFS{ + confFile: &fstest.MapFile{ + Data: []byte(`inet 127.0.0.253` + nl), + }, + }, + wantHas: assert.True, }, { - name: "case_sensitiveness", - rcconfData: `InEt 127.0.0.253` + nl, - wantHas: false, + name: "case_sensitiveness", + rootFsys: fstest.MapFS{ + confFile: &fstest.MapFile{ + Data: []byte(`InEt 127.0.0.253` + nl), + }, + }, + wantHas: assert.False, }, { name: "comments_and_trash", - rcconfData: `# comment 1` + nl + - `` + nl + - `# inet 127.0.0.253` + nl + - `inet` + nl, - wantHas: false, + rootFsys: fstest.MapFS{ + confFile: &fstest.MapFile{ + Data: []byte(`# comment 1` + nl + nl + + `# inet 127.0.0.253` + nl + + `inet` + nl, + ), + }, + }, + wantHas: assert.False, }, { name: "incorrect_config", - rcconfData: `inet6 127.0.0.253` + nl + - `inet 256.256.256.256` + nl, - wantHas: false, + rootFsys: fstest.MapFS{ + confFile: &fstest.MapFile{ + Data: []byte(`inet6 127.0.0.253` + nl + `inet 256.256.256.256` + nl), + }, + }, + wantHas: assert.False, }} for _, tc := range testCases { - r := strings.NewReader(tc.rcconfData) t.Run(tc.name, func(t *testing.T) { - _, has, err := hostnameIfStaticConfig(r) + substRootDirFS(t, tc.rootFsys) + + has, err := IfaceHasStaticIP(ifaceName) require.NoError(t, err) - assert.Equal(t, tc.wantHas, has) + tc.wantHas(t, has) }) } } diff --git a/internal/aghnet/net_test.go b/internal/aghnet/net_test.go index 72d80b51..40d395ba 100644 --- a/internal/aghnet/net_test.go +++ b/internal/aghnet/net_test.go @@ -1,12 +1,17 @@ package aghnet import ( + "bytes" + "encoding/json" + "fmt" "io/fs" "net" "os" + "strings" "testing" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" @@ -20,6 +25,113 @@ func TestMain(m *testing.M) { // testdata is the filesystem containing data for testing the package. var testdata fs.FS = os.DirFS("./testdata") +// substRootDirFS replaces the aghos.RootDirFS function used throughout the +// package with fsys for tests ran under t. +func substRootDirFS(t testing.TB, fsys fs.FS) { + t.Helper() + + prev := rootDirFS + t.Cleanup(func() { rootDirFS = prev }) + rootDirFS = fsys +} + +// RunCmdFunc is the signature of aghos.RunCommand function. +type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error) + +// substShell replaces the the aghos.RunCommand function used throughout the +// package with rc for tests ran under t. +func substShell(t testing.TB, rc RunCmdFunc) { + t.Helper() + + prev := aghosRunCommand + t.Cleanup(func() { aghosRunCommand = prev }) + aghosRunCommand = rc +} + +// mapShell is a substitution of aghos.RunCommand that maps the command to it's +// execution result. It's only needed to simplify testing. +// +// TODO(e.burkov): Perhaps put all the shell interactions behind an interface. +type mapShell map[string]struct { + err error + out string + code int +} + +// theOnlyCmd returns mapShell that only handles a single command and arguments +// combination from cmd. +func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) { + return mapShell{cmd: {code: code, out: out, err: err}} +} + +// RunCmd is a RunCmdFunc handled by s. +func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) { + key := strings.Join(append([]string{cmd}, args...), " ") + ret, ok := s[key] + if !ok { + return 0, nil, fmt.Errorf("unexpected shell command %q", key) + } + + return ret.code, []byte(ret.out), ret.err +} + +// ifaceAddrsFunc is the signature of net.InterfaceAddrs function. +type ifaceAddrsFunc func() (ifaces []net.Addr, err error) + +// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used +// throughout the package with f for tests ran under t. +func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) { + t.Helper() + + prev := netInterfaceAddrs + t.Cleanup(func() { netInterfaceAddrs = prev }) + netInterfaceAddrs = f +} + +func TestGatewayIP(t *testing.T) { + const ifaceName = "ifaceName" + const cmd = "ip route show dev " + ifaceName + + testCases := []struct { + name string + shell mapShell + want net.IP + }{{ + name: "success_v4", + shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil), + want: net.IP{1, 2, 3, 4}.To16(), + }, { + name: "success_v6", + shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil), + want: net.IP{ + 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0xFF, 0xFF, + }, + }, { + name: "bad_output", + shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil), + want: nil, + }, { + name: "err_runcmd", + shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")), + want: nil, + }, { + name: "bad_code", + shell: theOnlyCmd(cmd, 1, "", nil), + want: nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + substShell(t, tc.shell.RunCmd) + + assert.Equal(t, tc.want, GatewayIP(ifaceName)) + }) + } +} + func TestGetInterfaceByIP(t *testing.T) { ifaces, err := GetValidNetInterfacesForWeb() require.NoError(t, err) @@ -130,3 +242,107 @@ func TestCheckPort(t *testing.T) { assert.NoError(t, err) }) } + +func TestCollectAllIfacesAddrs(t *testing.T) { + testCases := []struct { + name string + wantErrMsg string + addrs []net.Addr + wantAddrs []string + }{{ + name: "success", + wantErrMsg: ``, + addrs: []net.Addr{&net.IPNet{ + IP: net.IP{1, 2, 3, 4}, + Mask: net.CIDRMask(24, netutil.IPv4BitLen), + }, &net.IPNet{ + IP: net.IP{4, 3, 2, 1}, + Mask: net.CIDRMask(16, netutil.IPv4BitLen), + }}, + wantAddrs: []string{"1.2.3.4", "4.3.2.1"}, + }, { + name: "not_cidr", + wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`, + addrs: []net.Addr{&net.IPAddr{ + IP: net.IP{1, 2, 3, 4}, + }}, + wantAddrs: nil, + }, { + name: "empty", + wantErrMsg: ``, + addrs: []net.Addr{}, + wantAddrs: nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil }) + + addrs, err := CollectAllIfacesAddrs() + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + + assert.Equal(t, tc.wantAddrs, addrs) + }) + } + + t.Run("internal_error", func(t *testing.T) { + const errAddrs errors.Error = "can't get addresses" + const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs) + + substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs }) + + _, err := CollectAllIfacesAddrs() + testutil.AssertErrorMsg(t, wantErrMsg, err) + }) +} + +func TestIsAddrInUse(t *testing.T) { + t.Run("addr_in_use", func(t *testing.T) { + l, err := net.Listen("tcp", "0.0.0.0:0") + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, l.Close) + + _, err = net.Listen(l.Addr().Network(), l.Addr().String()) + assert.True(t, IsAddrInUse(err)) + }) + + t.Run("another", func(t *testing.T) { + const anotherErr errors.Error = "not addr in use" + + assert.False(t, IsAddrInUse(anotherErr)) + }) +} + +func TestNetInterface_MarshalJSON(t *testing.T) { + const want = `{` + + `"hardware_address":"aa:bb:cc:dd:ee:ff",` + + `"flags":"up|multicast",` + + `"ip_addresses":["1.2.3.4","aaaa::1"],` + + `"name":"iface0",` + + `"mtu":1500` + + `}` + "\n" + + ip4, ip6 := net.IP{1, 2, 3, 4}, net.IP{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + mask4, mask6 := net.CIDRMask(24, netutil.IPv4BitLen), net.CIDRMask(8, netutil.IPv6BitLen) + + iface := &NetInterface{ + Addresses: []net.IP{ip4, ip6}, + Subnets: []*net.IPNet{{ + IP: ip4.Mask(mask4), + Mask: mask4, + }, { + IP: ip6.Mask(mask6), + Mask: mask6, + }}, + Name: "iface0", + HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF}, + Flags: net.FlagUp | net.FlagMulticast, + MTU: 1500, + } + + b := &bytes.Buffer{} + err := json.NewEncoder(b).Encode(iface) + require.NoError(t, err) + + assert.Equal(t, want, b.String()) +} diff --git a/internal/aghnet/systemresolvers.go b/internal/aghnet/systemresolvers.go index 777127a3..13fbeb32 100644 --- a/internal/aghnet/systemresolvers.go +++ b/internal/aghnet/systemresolvers.go @@ -1,11 +1,5 @@ package aghnet -import ( - "time" - - "github.com/AdguardTeam/golibs/log" -) - // DefaultRefreshIvl is the default period of time between refreshing cached // addresses. // const DefaultRefreshIvl = 5 * time.Minute @@ -16,39 +10,21 @@ type HostGenFunc func() (host string) // SystemResolvers helps to work with local resolvers' addresses provided by OS. type SystemResolvers interface { - // Get returns the slice of local resolvers' addresses. It should be - // safe for concurrent use. + // Get returns the slice of local resolvers' addresses. It must be safe for + // concurrent use. Get() (rs []string) - // refresh refreshes the local resolvers' addresses cache. It should be - // safe for concurrent use. + // refresh refreshes the local resolvers' addresses cache. It must be safe + // for concurrent use. refresh() (err error) } -// refreshWithTicker refreshes the cache of sr after each tick form tickCh. -func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) { - defer log.OnPanic("systemResolvers") - - // TODO(e.burkov): Implement a functionality to stop ticker. - for range tickCh { - err := sr.refresh() - if err != nil { - log.Error("systemResolvers: error in refreshing goroutine: %s", err) - - continue - } - - log.Debug("systemResolvers: local addresses cache is refreshed") - } -} - // NewSystemResolvers returns a SystemResolvers with the cache refresh rate // defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If // nil is passed for hostGenFunc, the default generator will be used. func NewSystemResolvers( - refreshIvl time.Duration, hostGenFunc HostGenFunc, ) (sr SystemResolvers, err error) { - sr = newSystemResolvers(refreshIvl, hostGenFunc) + sr = newSystemResolvers(hostGenFunc) // Fill cache. err = sr.refresh() @@ -56,11 +32,5 @@ func NewSystemResolvers( return nil, err } - if refreshIvl > 0 { - ticker := time.NewTicker(refreshIvl) - - go refreshWithTicker(sr, ticker.C) - } - return sr, nil } diff --git a/internal/aghnet/systemresolvers_others.go b/internal/aghnet/systemresolvers_others.go index 8acdb6c7..f8afa286 100644 --- a/internal/aghnet/systemresolvers_others.go +++ b/internal/aghnet/systemresolvers_others.go @@ -24,12 +24,15 @@ func defaultHostGen() (host string) { // systemResolvers is a default implementation of SystemResolvers interface. type systemResolvers struct { - resolver *net.Resolver - hostGenFunc HostGenFunc - - // addrs is the set that contains cached local resolvers' addresses. - addrs *stringutil.Set + // addrsLock protects addrs. addrsLock sync.RWMutex + // addrs is the set that contains cached local resolvers' addresses. + addrs *stringutil.Set + + // resolver is used to fetch the resolvers' addresses. + resolver *net.Resolver + // hostGenFunc generates hosts to resolve. + hostGenFunc HostGenFunc } const ( @@ -44,6 +47,7 @@ const ( errUnexpectedHostFormat errors.Error = "unexpected host format" ) +// refresh implements the SystemResolvers interface for *systemResolvers. func (sr *systemResolvers) refresh() (err error) { defer func() { err = errors.Annotate(err, "systemResolvers: %w") }() @@ -56,7 +60,7 @@ func (sr *systemResolvers) refresh() (err error) { return err } -func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr SystemResolvers) { +func newSystemResolvers(hostGenFunc HostGenFunc) (sr SystemResolvers) { if hostGenFunc == nil { hostGenFunc = defaultHostGen } @@ -76,19 +80,18 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S func validateDialedHost(host string) (err error) { defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }() - var ipStr string parts := strings.Split(host, "%") switch len(parts) { case 1: - ipStr = host + // host case 2: // Remove the zone and check the IP address part. - ipStr = parts[0] + host = parts[0] default: return errUnexpectedHostFormat } - if net.ParseIP(ipStr) == nil { + if _, err = netutil.ParseIP(host); err != nil { return errBadAddrPassed } diff --git a/internal/aghnet/systemresolvers_others_test.go b/internal/aghnet/systemresolvers_others_test.go index 79abeca2..f7cf9ef0 100644 --- a/internal/aghnet/systemresolvers_others_test.go +++ b/internal/aghnet/systemresolvers_others_test.go @@ -6,37 +6,32 @@ package aghnet import ( "context" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func createTestSystemResolversImp( +func createTestSystemResolversImpl( t *testing.T, - refreshDur time.Duration, hostGenFunc HostGenFunc, ) (imp *systemResolvers) { t.Helper() - sr := createTestSystemResolvers(t, refreshDur, hostGenFunc) + sr := createTestSystemResolvers(t, hostGenFunc) + require.IsType(t, (*systemResolvers)(nil), sr) - var ok bool - imp, ok = sr.(*systemResolvers) - require.True(t, ok) - - return imp + return sr.(*systemResolvers) } func TestSystemResolvers_Refresh(t *testing.T) { t.Run("expected_error", func(t *testing.T) { - sr := createTestSystemResolvers(t, 0, nil) + sr := createTestSystemResolvers(t, nil) assert.NoError(t, sr.refresh()) }) t.Run("unexpected_error", func(t *testing.T) { - _, err := NewSystemResolvers(0, func() string { + _, err := NewSystemResolvers(func() string { return "127.0.0.1::123" }) assert.Error(t, err) @@ -44,7 +39,7 @@ func TestSystemResolvers_Refresh(t *testing.T) { } func TestSystemResolvers_DialFunc(t *testing.T) { - imp := createTestSystemResolversImp(t, 0, nil) + imp := createTestSystemResolversImpl(t, nil) testCases := []struct { want error @@ -52,7 +47,7 @@ func TestSystemResolvers_DialFunc(t *testing.T) { address string }{{ want: errFakeDial, - name: "valid", + name: "valid_ipv4", address: "127.0.0.1", }, { want: errFakeDial, diff --git a/internal/aghnet/systemresolvers_test.go b/internal/aghnet/systemresolvers_test.go index 13145817..0a19490d 100644 --- a/internal/aghnet/systemresolvers_test.go +++ b/internal/aghnet/systemresolvers_test.go @@ -2,7 +2,6 @@ package aghnet import ( "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -10,13 +9,12 @@ import ( func createTestSystemResolvers( t *testing.T, - refreshDur time.Duration, hostGenFunc HostGenFunc, ) (sr SystemResolvers) { t.Helper() var err error - sr, err = NewSystemResolvers(refreshDur, hostGenFunc) + sr, err = NewSystemResolvers(hostGenFunc) require.NoError(t, err) require.NotNil(t, sr) @@ -24,8 +22,14 @@ func createTestSystemResolvers( } func TestSystemResolvers_Get(t *testing.T) { - sr := createTestSystemResolvers(t, 0, nil) - assert.NotEmpty(t, sr.Get()) + sr := createTestSystemResolvers(t, nil) + + var rs []string + require.NotPanics(t, func() { + rs = sr.Get() + }) + + assert.NotEmpty(t, rs) } // TODO(e.burkov): Write tests for refreshWithTicker. diff --git a/internal/aghnet/systemresolvers_windows.go b/internal/aghnet/systemresolvers_windows.go index 5acdfa85..f82d6e7e 100644 --- a/internal/aghnet/systemresolvers_windows.go +++ b/internal/aghnet/systemresolvers_windows.go @@ -11,7 +11,6 @@ import ( "os/exec" "strings" "sync" - "time" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" @@ -27,7 +26,7 @@ type systemResolvers struct { addrsLock sync.RWMutex } -func newSystemResolvers(refreshIvl time.Duration, _ HostGenFunc) (sr SystemResolvers) { +func newSystemResolvers(_ HostGenFunc) (sr SystemResolvers) { return &systemResolvers{} } diff --git a/internal/aghnet/testdata/proc_net_arp b/internal/aghnet/testdata/proc_net_arp new file mode 100644 index 00000000..8460c8bb --- /dev/null +++ b/internal/aghnet/testdata/proc_net_arp @@ -0,0 +1,6 @@ +IP address HW type Flags HW address Mask Device +192.168.1.2 0x1 0x2 ab:cd:ef:ab:cd:ef * wan +::ffff:ffff 0x1 0x0 ef:cd:ab:ef:cd:ab * br-lan +0.0.0.0 0x0 0x0 00:00:00:00:00:00 * unspec +1.2.3.4.5 0x1 0x2 aa:bb:cc:dd:ee:ff * wan +1.2.3.4 0x1 0x2 12:34:56:78:910 * wan \ No newline at end of file diff --git a/internal/aghos/os.go b/internal/aghos/os.go index 29eb1afc..3b688749 100644 --- a/internal/aghos/os.go +++ b/internal/aghos/os.go @@ -52,24 +52,27 @@ func HaveAdminRights() (bool, error) { return haveAdminRights() } -// MaxCmdOutputSize is the maximum length of performed shell command output. -const MaxCmdOutputSize = 2 * 1024 +// MaxCmdOutputSize is the maximum length of performed shell command output in +// bytes. +const MaxCmdOutputSize = 64 * 1024 // RunCommand runs shell command. -func RunCommand(command string, arguments ...string) (int, string, error) { +func RunCommand(command string, arguments ...string) (code int, output []byte, err error) { cmd := exec.Command(command, arguments...) out, err := cmd.Output() if len(out) > MaxCmdOutputSize { out = out[:MaxCmdOutputSize] } - if errors.As(err, new(*exec.ExitError)) { - return cmd.ProcessState.ExitCode(), string(out), nil - } else if err != nil { - return 1, "", fmt.Errorf("exec.Command(%s) failed: %w: %s", command, err, string(out)) + if err != nil { + if eerr := new(exec.ExitError); errors.As(err, &eerr) { + return eerr.ExitCode(), eerr.Stderr, nil + } + + return 1, nil, fmt.Errorf("command %q failed: %w: %s", command, err, out) } - return cmd.ProcessState.ExitCode(), string(out), nil + return cmd.ProcessState.ExitCode(), out, nil } // PIDByCommand searches for process named command and returns its PID ignoring @@ -172,3 +175,13 @@ func RootDirFS() (fsys fs.FS) { // behavior is undocumented but it currently works. return os.DirFS("") } + +// NotifyShutdownSignal notifies c on receiving shutdown signals. +func NotifyShutdownSignal(c chan<- os.Signal) { + notifyShutdownSignal(c) +} + +// IsShutdownSignal returns true if sig is a shutdown signal. +func IsShutdownSignal(sig os.Signal) (ok bool) { + return isShutdownSignal(sig) +} diff --git a/internal/aghos/os_unix.go b/internal/aghos/os_unix.go new file mode 100644 index 00000000..9a3cc308 --- /dev/null +++ b/internal/aghos/os_unix.go @@ -0,0 +1,27 @@ +//go:build darwin || freebsd || linux || openbsd +// +build darwin freebsd linux openbsd + +package aghos + +import ( + "os" + "os/signal" + + "golang.org/x/sys/unix" +) + +func notifyShutdownSignal(c chan<- os.Signal) { + signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM) +} + +func isShutdownSignal(sig os.Signal) (ok bool) { + switch sig { + case + unix.SIGINT, + unix.SIGQUIT, + unix.SIGTERM: + return true + default: + return false + } +} diff --git a/internal/aghos/os_windows.go b/internal/aghos/os_windows.go index bff5a3f0..31fca3ef 100644 --- a/internal/aghos/os_windows.go +++ b/internal/aghos/os_windows.go @@ -4,6 +4,10 @@ package aghos import ( + "os" + "os/signal" + "syscall" + "golang.org/x/sys/windows" ) @@ -35,3 +39,20 @@ func haveAdminRights() (bool, error) { func isOpenWrt() (ok bool) { return false } + +func notifyShutdownSignal(c chan<- os.Signal) { + // syscall.SIGTERM is processed automatically. See go doc os/signal, + // section Windows. + signal.Notify(c, os.Interrupt) +} + +func isShutdownSignal(sig os.Signal) (ok bool) { + switch sig { + case + os.Interrupt, + syscall.SIGTERM: + return true + default: + return false + } +} diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 02816265..5e0b8293 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -135,7 +135,6 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) { pctx.Res = s.genNXDomain(pctx.Req) return resultCodeFinish - } return resultCodeSuccess diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 2d32cfd2..4afdfd34 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -173,7 +173,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { // TODO(e.burkov): Enable the refresher after the actual implementation // passes the public testing. - s.sysResolvers, err = aghnet.NewSystemResolvers(0, nil) + s.sysResolvers, err = aghnet.NewSystemResolvers(nil) if err != nil { return nil, fmt.Errorf("initializing system resolvers: %w", err) } diff --git a/internal/dnsforward/recursiondetector_test.go b/internal/dnsforward/recursiondetector_test.go index 7573b668..4edb3a37 100644 --- a/internal/dnsforward/recursiondetector_test.go +++ b/internal/dnsforward/recursiondetector_test.go @@ -83,7 +83,7 @@ func TestRecursionDetector_Suspect(t *testing.T) { testCases := []struct { name string msg dns.Msg - want bool + want int }{{ name: "simple", msg: dns.Msg{ @@ -95,24 +95,18 @@ func TestRecursionDetector_Suspect(t *testing.T) { Qtype: dns.TypeA, }}, }, - want: true, + want: 1, }, { name: "unencumbered", msg: dns.Msg{}, - want: false, + want: 0, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Cleanup(rd.clear) - rd.add(tc.msg) - - if tc.want { - assert.Equal(t, 1, rd.recentRequests.Stats().Count) - } else { - assert.Zero(t, rd.recentRequests.Stats().Count) - } + assert.Equal(t, tc.want, rd.recentRequests.Stats().Count) }) } } diff --git a/internal/home/home.go b/internal/home/home.go index cc8d7e67..b6d3c223 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -518,44 +518,31 @@ func StartMods() error { func checkPermissions() { log.Info("Checking if AdGuard Home has necessary permissions") - if runtime.GOOS == "windows" { - // On Windows we need to have admin rights to run properly - - admin, _ := aghos.HaveAdminRights() - if admin { - return - } - + if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil { log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.") } // We should check if AdGuard Home is able to bind to port 53 - ok, err := aghnet.CanBindPort(53) - - if ok { - log.Info("AdGuard Home can bind to port 53") - return - } - - if errors.Is(err, os.ErrPermission) { - msg := `Permission check failed. - + err := aghnet.CheckPort("tcp", net.IP{127, 0, 0, 1}, defaultPortDNS) + if err != nil { + if errors.Is(err, os.ErrPermission) { + log.Fatal(`Permission check failed. AdGuard Home is not allowed to bind to privileged ports (for instance, port 53). Please note, that this is crucial for a server to be able to use privileged ports. - You have two options: 1. Run AdGuard Home with root privileges 2. On Linux you can grant the CAP_NET_BIND_SERVICE capability: -https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser` +https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`) + } - log.Fatal(msg) + log.Info( + "AdGuard failed to bind to port 53: %s\n\n"+ + "Please note, that this is crucial for a DNS server to be able to use that port.", + err, + ) } - msg := fmt.Sprintf(`AdGuard failed to bind to port 53 due to %v - -Please note, that this is crucial for a DNS server to be able to use that port.`, err) - - log.Info(msg) + log.Info("AdGuard Home can bind to port 53") } // Write PID to a file diff --git a/internal/home/rdns.go b/internal/home/rdns.go index cba748af..9f577c24 100644 --- a/internal/home/rdns.go +++ b/internal/home/rdns.go @@ -16,18 +16,17 @@ type RDNS struct { exchanger dnsforward.RDNSExchanger clients *clientsContainer - // usePrivate is used to store the state of current private RDNS - // resolving settings and to react to it's changes. + // usePrivate is used to store the state of current private RDNS resolving + // settings and to react to it's changes. usePrivate uint32 // ipCh used to pass client's IP to rDNS workerLoop. ipCh chan net.IP // ipCache caches the IP addresses to be resolved by rDNS. The resolved - // address stays here while it's inside clients. After leaving clients - // the address will be resolved once again. If the address couldn't be - // resolved, cache prevents further attempts to resolve it for some - // time. + // address stays here while it's inside clients. After leaving clients the + // address will be resolved once again. If the address couldn't be + // resolved, cache prevents further attempts to resolve it for some time. ipCache cache.Cache } diff --git a/internal/home/service_openbsd.go b/internal/home/service_openbsd.go index 679a7437..8ad0d212 100644 --- a/internal/home/service_openbsd.go +++ b/internal/home/service_openbsd.go @@ -314,12 +314,13 @@ func (s *openbsdRunComService) runCom(cmd string) (out string, err error) { // TODO(e.burkov): It's possible that os.ErrNotExist is caused by // something different than the service script's non-existence. Keep it // in mind, when replace the aghos.RunCommand. - _, out, err = aghos.RunCommand(scriptPath, cmd) + var outData []byte + _, outData, err = aghos.RunCommand(scriptPath, cmd) if errors.Is(err, os.ErrNotExist) { return "", service.ErrNotInstalled } - return out, err + return string(outData), err } // Status implements service.Service interface for *openbsdRunComService.