Pull request: all: imp cyclo in new code

Updates #2646,

Squashed commit of the following:

commit af6a6fa2b7229bc0f1c7c9083b0391a6bec7ae70
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon May 31 20:00:36 2021 +0300

    all: imp code, docs

commit 1cd4781b13e635a9e1bccb758104c1b76c78d34e
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon May 31 18:51:23 2021 +0300

    all: imp cyclo in new code
This commit is contained in:
Ainar Garipov 2021-05-31 20:11:06 +03:00
parent c95acf73ab
commit e17e1f20fb
10 changed files with 211 additions and 190 deletions

View File

@ -2,7 +2,6 @@ package aghnet
import ( import (
"bufio" "bufio"
"io"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
@ -231,14 +230,41 @@ func (ehc *EtcHostsContainer) updateTableRev(tableRev map[string][]string, newHo
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost) log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost)
} }
// Read IP-hostname pairs from file // parseHostsLine parses hosts from the fields.
// Multiple hostnames per line (per one IP) is supported. func parseHostsLine(fields []string) (hosts []string) {
func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[string][]string, fn string) { for _, f := range fields {
hashIdx := strings.IndexByte(f, '#')
if hashIdx == 0 {
// The rest of the fields are a part of the comment.
// Skip immediately.
return
} else if hashIdx > 0 {
// Only a part of the field is a comment.
hosts = append(hosts, f[:hashIdx])
return hosts
}
hosts = append(hosts, f)
}
return hosts
}
// load reads IP-hostname pairs from the hosts file. Multiple hostnames per
// line for one IP are supported.
func (ehc *EtcHostsContainer) load(
table map[string][]net.IP,
tableRev map[string][]string,
fn string,
) {
f, err := os.Open(fn) f, err := os.Open(fn)
if err != nil { if err != nil {
log.Error("etchostscontainer: %s", err) log.Error("etchostscontainer: %s", err)
return return
} }
defer func() { defer func() {
derr := f.Close() derr := f.Close()
if derr != nil { if derr != nil {
@ -246,25 +272,11 @@ func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[strin
} }
}() }()
r := bufio.NewReader(f)
log.Debug("etchostscontainer: loading hosts from file %s", fn) log.Debug("etchostscontainer: loading hosts from file %s", fn)
for done := false; !done; { s := bufio.NewScanner(f)
var line string for s.Scan() {
line, err = r.ReadString('\n') line := strings.TrimSpace(s.Text())
if err == io.EOF {
done = true
} else if err != nil {
log.Error("etchostscontainer: %s", err)
return
}
line = strings.TrimSpace(line)
if len(line) == 0 || line[0] == '#' {
continue
}
fields := strings.Fields(line) fields := strings.Fields(line)
if len(fields) < 2 { if len(fields) < 2 {
continue continue
@ -275,28 +287,17 @@ func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[strin
continue continue
} }
for i := 1; i != len(fields); i++ { hosts := parseHostsLine(fields[1:])
host := fields[i] for _, host := range hosts {
if len(host) == 0 {
break
}
sharp := strings.IndexByte(host, '#')
if sharp == 0 {
// Skip the comments.
break
} else if sharp > 0 {
host = host[:sharp]
}
ehc.updateTable(table, host, ip) ehc.updateTable(table, host, ip)
ehc.updateTableRev(tableRev, host, ip) ehc.updateTableRev(tableRev, host, ip)
if sharp >= 0 {
// Skip the comments again.
break
}
} }
} }
err = s.Err()
if err != nil {
log.Error("etchostscontainer: %s", err)
}
} }
// onlyWrites is a filter for (*fsnotify.Watcher).Events. // onlyWrites is a filter for (*fsnotify.Watcher).Events.

View File

@ -23,10 +23,11 @@ func prepareTestFile(t *testing.T) (f *os.File) {
dir := t.TempDir() dir := t.TempDir()
f, err := os.CreateTemp(dir, "") f, err := os.CreateTemp(dir, "")
require.Nil(t, err) require.NoError(t, err)
require.NotNil(t, f) require.NotNil(t, f)
t.Cleanup(func() { t.Cleanup(func() {
assert.Nil(t, f.Close()) assert.NoError(t, f.Close())
}) })
return f return f
@ -37,7 +38,7 @@ func assertWriting(t *testing.T, f *os.File, strs ...string) {
for _, str := range strs { for _, str := range strs {
n, err := f.WriteString(str) n, err := f.WriteString(str)
require.Nil(t, err) require.NoError(t, err)
assert.Equal(t, n, len(str)) assert.Equal(t, n, len(str))
} }
} }
@ -77,16 +78,16 @@ func TestEtcHostsContainerResolution(t *testing.T) {
t.Run("ptr", func(t *testing.T) { t.Run("ptr", func(t *testing.T) {
testCases := []struct { testCases := []struct {
wantIP string wantIP string
wantLen int
wantHost string wantHost string
wantLen int
}{ }{
{wantIP: "127.0.0.1", wantLen: 2, wantHost: "host"}, {wantIP: "127.0.0.1", wantHost: "host", wantLen: 2},
{wantIP: "::1", wantLen: 1, wantHost: "localhost"}, {wantIP: "::1", wantHost: "localhost", wantLen: 1},
} }
for _, tc := range testCases { for _, tc := range testCases {
a, err := dns.ReverseAddr(tc.wantIP) a, err := dns.ReverseAddr(tc.wantIP)
require.Nil(t, err) require.NoError(t, err)
a = strings.TrimSuffix(a, ".") a = strings.TrimSuffix(a, ".")
hosts := ehc.ProcessReverse(a, dns.TypePTR) hosts := ehc.ProcessReverse(a, dns.TypePTR)
@ -114,7 +115,7 @@ func TestEtcHostsContainerFSNotify(t *testing.T) {
t.Cleanup(ehc.Close) t.Cleanup(ehc.Close)
assertWriting(t, f, "127.0.0.2 newhost\n") assertWriting(t, f, "127.0.0.2 newhost\n")
require.Nil(t, f.Sync()) require.NoError(t, f.Sync())
// Wait until fsnotify has triggerred and processed the // Wait until fsnotify has triggerred and processed the
// file-modification event. // file-modification event.

View File

@ -68,40 +68,41 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
return false, ErrNoStaticIPInfo return false, ErrNoStaticIPInfo
} }
// findIfaceLine scans s until it finds the line that declares an interface with
// the given name. If findIfaceLine can't find the line, it returns false.
func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
for s.Scan() {
line := strings.TrimSpace(s.Text())
fields := strings.Fields(line)
if len(fields) == 2 && fields[0] == "interface" && fields[1] == name {
return true
}
}
return false
}
// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to // dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
// have a static IP. // have a static IP.
func dhcpcdStaticConfig(r io.Reader, ifaceName string) (has bool, err error) { func dhcpcdStaticConfig(r io.Reader, ifaceName string) (has bool, err error) {
s := bufio.NewScanner(r) s := bufio.NewScanner(r)
var withinInterfaceCtx bool ifaceFound := findIfaceLine(s, ifaceName)
if !ifaceFound {
return false, s.Err()
}
for s.Scan() { for s.Scan() {
line := strings.TrimSpace(s.Text()) line := strings.TrimSpace(s.Text())
if withinInterfaceCtx && len(line) == 0 {
// An empty line resets our state.
withinInterfaceCtx = false
}
if len(line) == 0 || line[0] == '#' {
continue
}
fields := strings.Fields(line) fields := strings.Fields(line)
if len(fields) >= 2 &&
if withinInterfaceCtx { fields[0] == "static" &&
if len(fields) >= 2 && fields[0] == "static" && strings.HasPrefix(fields[1], "ip_address=") { strings.HasPrefix(fields[1], "ip_address=") {
return true, nil return true, s.Err()
}
if len(fields) > 0 && fields[0] == "interface" {
// Another interface found.
withinInterfaceCtx = false
}
continue
} }
if len(fields) == 2 && fields[0] == "interface" && fields[1] == ifaceName { if len(fields) > 0 && fields[0] == "interface" {
// The interface found. // Another interface found.
withinInterfaceCtx = true break
} }
} }

View File

@ -3,7 +3,6 @@ package aghnet
import ( import (
"time" "time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
@ -25,18 +24,6 @@ type SystemResolvers interface {
refresh() (err error) refresh() (err error)
} }
const (
// errBadAddrPassed is returned when dialFunc can't parse an IP address.
errBadAddrPassed errors.Error = "the passed string is not a valid IP address"
// errFakeDial is an error which dialFunc is expected to return.
errFakeDial errors.Error = "this error signals the successful dialFunc work"
// errUnexpectedHostFormat is returned by validateDialedHost when the host has
// more than one percent sign.
errUnexpectedHostFormat errors.Error = "unexpected host format"
)
// refreshWithTicker refreshes the cache of sr after each tick form tickCh. // refreshWithTicker refreshes the cache of sr after each tick form tickCh.
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) { func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
defer log.OnPanic("systemResolvers") defer log.OnPanic("systemResolvers")

View File

@ -32,6 +32,18 @@ type systemResolvers struct {
addrsLock sync.RWMutex addrsLock sync.RWMutex
} }
const (
// errBadAddrPassed is returned when dialFunc can't parse an IP address.
errBadAddrPassed errors.Error = "the passed string is not a valid IP address"
// errFakeDial is an error which dialFunc is expected to return.
errFakeDial errors.Error = "this error signals the successful dialFunc work"
// errUnexpectedHostFormat is returned by validateDialedHost when the host has
// more than one percent sign.
errUnexpectedHostFormat errors.Error = "unexpected host format"
)
func (sr *systemResolvers) refresh() (err error) { func (sr *systemResolvers) refresh() (err error) {
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }() defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()

View File

@ -45,6 +45,57 @@ func (sr *systemResolvers) Get() (rs []string) {
return rs return rs
} }
// writeExit writes "exit" to w and closes it. It is supposed to be run in
// a goroutine.
func writeExit(w io.WriteCloser) {
defer log.OnPanic("systemResolvers: writeExit")
defer func() {
derr := w.Close()
if derr != nil {
log.Error("systemResolvers: writeExit: closing: %s", derr)
}
}()
_, err := io.WriteString(w, "exit")
if err != nil {
log.Error("systemResolvers: writeExit: writing: %s", err)
}
}
// scanAddrs scans the DNS addresses from nslookup's output. The expected
// output of nslookup looks like this:
//
// Default Server: 192-168-1-1.qualified.domain.ru
// Address: 192.168.1.1
//
func scanAddrs(s *bufio.Scanner) (addrs []string) {
for s.Scan() {
line := strings.TrimSpace(s.Text())
fields := strings.Fields(line)
if len(fields) != 2 || fields[0] != "Address:" {
continue
}
// If the address contains port then it is separated with '#'.
ipPort := strings.Split(fields[1], "#")
if len(ipPort) == 0 {
continue
}
addr := ipPort[0]
if net.ParseIP(addr) == nil {
log.Debug("systemResolvers: %q is not a valid ip", addr)
continue
}
addrs = append(addrs, addr)
}
return addrs
}
// getAddrs gets local resolvers' addresses from OS in a special Windows way. // getAddrs gets local resolvers' addresses from OS in a special Windows way.
// //
// TODO(e.burkov): This whole function needs more detailed research on getting // TODO(e.burkov): This whole function needs more detailed research on getting
@ -71,73 +122,30 @@ func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
return nil, fmt.Errorf("limiting stdout reader: %w", err) return nil, fmt.Errorf("limiting stdout reader: %w", err)
} }
go func() { go writeExit(stdin)
defer log.OnPanic("systemResolvers")
defer func() {
derr := stdin.Close()
if derr != nil {
log.Error("systemResolvers: closing stdin pipe: %s", derr)
}
}()
_, werr := io.WriteString(stdin, "exit")
if werr != nil {
log.Error("systemResolvers: writing to command pipe: %s", werr)
}
}()
err = cmd.Start() err = cmd.Start()
if err != nil { if err != nil {
return nil, fmt.Errorf("start command executing: %w", err) return nil, fmt.Errorf("start command executing: %w", err)
} }
// The output of nslookup looks like this:
//
// Default Server: 192-168-1-1.qualified.domain.ru
// Address: 192.168.1.1
var possibleIPs []string
s := bufio.NewScanner(stdoutLimited) s := bufio.NewScanner(stdoutLimited)
for s.Scan() { addrs = scanAddrs(s)
line := s.Text()
if len(line) == 0 {
continue
}
fields := strings.Fields(line)
if len(fields) != 2 || fields[0] != "Address:" {
continue
}
// If the address contains port then it is separated with '#'.
ipStrs := strings.Split(fields[1], "#")
if len(ipStrs) == 0 {
continue
}
possibleIPs = append(possibleIPs, ipStrs[0])
}
err = cmd.Wait() err = cmd.Wait()
if err != nil { if err != nil {
return nil, fmt.Errorf("executing the command: %w", err) return nil, fmt.Errorf("executing the command: %w", err)
} }
err = s.Err()
if err != nil {
return nil, fmt.Errorf("scanning output: %w", err)
}
// Don't close StdoutPipe since Wait do it for us in ¿most? cases. // Don't close StdoutPipe since Wait do it for us in ¿most? cases.
// //
// See go doc os/exec.Cmd.StdoutPipe. // See go doc os/exec.Cmd.StdoutPipe.
for _, addr := range possibleIPs {
if net.ParseIP(addr) == nil {
log.Debug("systemResolvers: %q is not a valid ip", addr)
continue
}
addrs = append(addrs, addr)
}
return addrs, nil return addrs, nil
} }

View File

@ -13,8 +13,6 @@ import (
// TestUpstream is a mock of real upstream. // TestUpstream is a mock of real upstream.
type TestUpstream struct { type TestUpstream struct {
// Addr is the address for Address method.
Addr string
// CName is a map of hostname to canonical name. // CName is a map of hostname to canonical name.
CName map[string]string CName map[string]string
// IPv4 is a map of hostname to IPv4. // IPv4 is a map of hostname to IPv4.
@ -23,9 +21,13 @@ type TestUpstream struct {
IPv6 map[string][]net.IP IPv6 map[string][]net.IP
// Reverse is a map of address to domain name. // Reverse is a map of address to domain name.
Reverse map[string][]string Reverse map[string][]string
// Addr is the address for Address method.
Addr string
} }
// Exchange implements upstream.Upstream interface for *TestUpstream. // Exchange implements upstream.Upstream interface for *TestUpstream.
//
// TODO(a.garipov): Split further into handlers.
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp = &dns.Msg{} resp = &dns.Msg{}
resp.SetReply(m) resp.SetReply(m)
@ -33,70 +35,69 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
if len(m.Question) == 0 { if len(m.Question) == 0 {
return nil, fmt.Errorf("question should not be empty") return nil, fmt.Errorf("question should not be empty")
} }
name := m.Question[0].Name name := m.Question[0].Name
if cname, ok := u.CName[name]; ok { if cname, ok := u.CName[name]; ok {
resp.Answer = append(resp.Answer, &dns.CNAME{ ans := &dns.CNAME{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: name, Name: name,
Rrtype: dns.TypeCNAME, Rrtype: dns.TypeCNAME,
}, },
Target: cname, Target: cname,
}) }
resp.Answer = append(resp.Answer, ans)
} }
var hasRec bool rrType := m.Question[0].Qtype
var rrType uint16 hdr := dns.RR_Header{
Name: name,
Rrtype: rrType,
}
var names []string
var ips []net.IP var ips []net.IP
switch m.Question[0].Qtype { switch m.Question[0].Qtype {
case dns.TypeA: case dns.TypeA:
rrType = dns.TypeA ips = u.IPv4[name]
if ipv4addr, ok := u.IPv4[name]; ok {
hasRec = true
ips = ipv4addr
}
case dns.TypeAAAA: case dns.TypeAAAA:
rrType = dns.TypeAAAA ips = u.IPv6[name]
if ipv6addr, ok := u.IPv6[name]; ok {
hasRec = true
ips = ipv6addr
}
case dns.TypePTR: case dns.TypePTR:
names, ok := u.Reverse[name] names = u.Reverse[name]
if !ok {
break
}
for _, n := range names {
resp.Answer = append(resp.Answer, &dns.PTR{
Hdr: dns.RR_Header{
Name: n,
Rrtype: rrType,
},
Ptr: n,
})
}
} }
for _, ip := range ips { for _, ip := range ips {
resp.Answer = append(resp.Answer, &dns.A{ var ans dns.RR
Hdr: dns.RR_Header{ if rrType == dns.TypeA {
Name: name, ans = &dns.A{
Rrtype: rrType, Hdr: hdr,
}, A: ip,
A: ip, }
})
resp.Answer = append(resp.Answer, ans)
continue
}
ans = &dns.AAAA{
Hdr: hdr,
AAAA: ip,
}
resp.Answer = append(resp.Answer, ans)
}
for _, n := range names {
ans := &dns.PTR{
Hdr: hdr,
Ptr: n,
}
resp.Answer = append(resp.Answer, ans)
} }
if len(resp.Answer) == 0 { if len(resp.Answer) == 0 {
if hasRec {
// Set no error RCode if there are some records for
// given Qname but we didn't apply them.
resp.SetRcode(m, dns.RcodeSuccess)
return resp, nil
}
// Set NXDomain RCode otherwise.
resp.SetRcode(m, dns.RcodeNameError) resp.SetRcode(m, dns.RcodeNameError)
} }
@ -111,10 +112,13 @@ func (u *TestUpstream) Address() string {
// TestBlockUpstream implements upstream.Upstream interface for replacing real // TestBlockUpstream implements upstream.Upstream interface for replacing real
// upstream in tests. // upstream in tests.
type TestBlockUpstream struct { type TestBlockUpstream struct {
Hostname string Hostname string
Block bool
requestsCount int // lock protects reqNum.
lock sync.RWMutex lock sync.RWMutex
reqNum int
Block bool
} }
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block // Exchange returns a message unique for TestBlockUpstream's Hostname-Block
@ -122,7 +126,7 @@ type TestBlockUpstream struct {
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) { func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
u.lock.Lock() u.lock.Lock()
defer u.lock.Unlock() defer u.lock.Unlock()
u.requestsCount++ u.reqNum++
hash := sha256.Sum256([]byte(u.Hostname)) hash := sha256.Sum256([]byte(u.Hostname))
hashToReturn := hex.EncodeToString(hash[:]) hashToReturn := hex.EncodeToString(hash[:])
@ -156,7 +160,7 @@ func (u *TestBlockUpstream) RequestsCount() int {
u.lock.Lock() u.lock.Lock()
defer u.lock.Unlock() defer u.lock.Unlock()
return u.requestsCount return u.reqNum
} }
// TestErrUpstream implements upstream.Upstream interface for replacing real // TestErrUpstream implements upstream.Upstream interface for replacing real

View File

@ -326,7 +326,7 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
require.Len(t, pctx.Res.Answer, tc.wantLen) require.Len(t, pctx.Res.Answer, tc.wantLen)
if tc.wantLen > 0 { if tc.wantLen > 0 {
assert.Equal(t, tc.want, pctx.Res.Answer[0].Header().Name) assert.Equal(t, tc.want, pctx.Res.Answer[0].(*dns.PTR).Ptr)
} }
}) })
} }
@ -368,7 +368,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
require.Equal(t, resultCodeSuccess, rc) require.Equal(t, resultCodeSuccess, rc)
require.NotEmpty(t, proxyCtx.Res.Answer) require.NotEmpty(t, proxyCtx.Res.Answer)
assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].Header().Name) assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].(*dns.PTR).Ptr)
}) })
t.Run("disabled", func(t *testing.T) { t.Run("disabled", func(t *testing.T) {

View File

@ -284,7 +284,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
StartTime: time.Now(), StartTime: time.Now(),
} }
var resolver *proxy.Proxy = s.internalProxy resolver := s.internalProxy
if s.subnetDetector.IsLocallyServedNetwork(ip) { if s.subnetDetector.IsLocallyServedNetwork(ip) {
if !s.conf.UsePrivateRDNS { if !s.conf.UsePrivateRDNS {
return "", nil return "", nil

View File

@ -175,8 +175,15 @@ golint --set_exit_status ./...
"$GO" vet ./... "$GO" vet ./...
# Here and below, don't use quotes to get word splitting. # Apply more lax standards to the code we haven't properly refactored yet.
gocyclo --over 17 $go_files gocyclo --over 17 ./internal/dhcpd/ ./internal/dnsforward/\
./internal/filtering/ ./internal/home/ ./internal/querylog/\
./internal/stats/ ./internal/updater/
# Apply stricter standards to new or vetted code
gocyclo --over 10 ./internal/aghio/ ./internal/aghnet/ ./internal/aghos/\
./internal/aghstrings/ ./internal/aghtest/ ./internal/tools/\
./internal/version/ ./main.go
gosec --quiet $go_files gosec --quiet $go_files