From a3dddd72c1cd864749355c49ef5313a447709638 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Tue, 9 Feb 2021 19:38:31 +0300 Subject: [PATCH] Pull request: 2639 use testify require vol.2 Merge in DNS/adguard-home from 2639-testify-require-2 to master Updates #2639. Squashed commit of the following: commit 31cc29a166e2e48a73956853cbc6d6dd681ab6da Author: Eugene Burkov Date: Tue Feb 9 18:48:31 2021 +0300 all: deal with t.Run commit 484f477fbfedd03aca4d322bc1cc9e131f30e1ce Author: Eugene Burkov Date: Tue Feb 9 17:44:02 2021 +0300 all: fix readability, imp tests commit 1231a825b353c16e43eae1b660dbb4c87805f564 Author: Eugene Burkov Date: Tue Feb 9 16:06:29 2021 +0300 all: imp tests --- internal/aghtest/os.go | 45 ++++++ internal/querylog/qlog_test.go | 43 +----- internal/querylog/qlogfile_test.go | 17 ++- internal/querylog/qlogreader_test.go | 2 +- internal/stats/stats_test.go | 130 +++++++++-------- internal/stats/unit.go | 113 ++++++++------- internal/sysutil/net_linux_test.go | 43 +++--- internal/util/autohosts_test.go | 200 +++++++++++++++++---------- internal/util/helpers_test.go | 4 +- internal/util/network_test.go | 17 +-- 10 files changed, 345 insertions(+), 269 deletions(-) create mode 100644 internal/aghtest/os.go diff --git a/internal/aghtest/os.go b/internal/aghtest/os.go new file mode 100644 index 00000000..a9885c03 --- /dev/null +++ b/internal/aghtest/os.go @@ -0,0 +1,45 @@ +package aghtest + +import ( + "io/ioutil" + "os" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// PrepareTestDir returns the full path to temporary created directory and +// registers the appropriate cleanup for *t. +func PrepareTestDir(t *testing.T) (dir string) { + t.Helper() + + wd, err := os.Getwd() + require.Nil(t, err) + + dir, err = ioutil.TempDir(wd, "agh-test") + require.Nil(t, err) + require.NotEmpty(t, dir) + + t.Cleanup(func() { + // TODO(e.burkov): Replace with t.TempDir methods after updating + // go version to 1.15. + start := time.Now() + for { + err := os.RemoveAll(dir) + if err == nil { + break + } + + if runtime.GOOS != "windows" || time.Since(start) >= 500*time.Millisecond { + break + } + time.Sleep(5 * time.Millisecond) + } + assert.Nil(t, err) + }) + + return dir +} diff --git a/internal/querylog/qlog_test.go b/internal/querylog/qlog_test.go index ffe12c4f..c8b34eb5 100644 --- a/internal/querylog/qlog_test.go +++ b/internal/querylog/qlog_test.go @@ -2,11 +2,8 @@ package querylog import ( "fmt" - "io/ioutil" "math/rand" "net" - "os" - "runtime" "sort" "testing" "time" @@ -24,38 +21,6 @@ func TestMain(m *testing.M) { aghtest.DiscardLogOutput(m) } -func prepareTestDir(t *testing.T) string { - t.Helper() - - wd, err := os.Getwd() - require.Nil(t, err) - - dir, err := ioutil.TempDir(wd, "agh-tests") - require.Nil(t, err) - require.NotEmpty(t, dir) - - t.Cleanup(func() { - // TODO(e.burkov): Replace with t.TempDir methods after updating - // go version to 1.15. - start := time.Now() - for { - err := os.RemoveAll(dir) - if err == nil { - break - } - - if runtime.GOOS != "windows" || time.Since(start) >= 500*time.Millisecond { - break - } - time.Sleep(5 * time.Millisecond) - } - - assert.Nil(t, err) - }) - - return dir -} - // TestQueryLog tests adding and loading (with filtering) entries from disk and // memory. func TestQueryLog(t *testing.T) { @@ -64,7 +29,7 @@ func TestQueryLog(t *testing.T) { FileEnabled: true, Interval: 1, MemSize: 100, - BaseDir: prepareTestDir(t), + BaseDir: aghtest.PrepareTestDir(t), }) // Add disk entries. @@ -166,7 +131,7 @@ func TestQueryLogOffsetLimit(t *testing.T) { Enabled: true, Interval: 1, MemSize: 100, - BaseDir: prepareTestDir(t), + BaseDir: aghtest.PrepareTestDir(t), }) const ( @@ -240,7 +205,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) { FileEnabled: true, Interval: 1, MemSize: 100, - BaseDir: prepareTestDir(t), + BaseDir: aghtest.PrepareTestDir(t), }) const entNum = 10 @@ -268,7 +233,7 @@ func TestQueryLogFileDisabled(t *testing.T) { FileEnabled: false, Interval: 1, MemSize: 2, - BaseDir: prepareTestDir(t), + BaseDir: aghtest.PrepareTestDir(t), }) addEntry(l, "example1.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) diff --git a/internal/querylog/qlogfile_test.go b/internal/querylog/qlogfile_test.go index b583d8da..b74111fc 100644 --- a/internal/querylog/qlogfile_test.go +++ b/internal/querylog/qlogfile_test.go @@ -12,15 +12,20 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // prepareTestFiles prepares several test query log files, each with the // specified lines count. -func prepareTestFiles(t *testing.T, dir string, filesNum, linesNum int) []string { +func prepareTestFiles(t *testing.T, filesNum, linesNum int) []string { t.Helper() + if filesNum == 0 { + return []string{} + } + const strV = "\"%s\"" const nl = "\n" const format = `{"IP":` + strV + `,"T":` + strV + `,` + @@ -31,6 +36,8 @@ func prepareTestFiles(t *testing.T, dir string, filesNum, linesNum int) []string lineTime, _ := time.Parse(time.RFC3339Nano, "2020-02-18T22:36:35.920973+03:00") lineIP := uint32(0) + dir := aghtest.PrepareTestDir(t) + files := make([]string, filesNum) for j := range files { f, err := ioutil.TempFile(dir, "*.txt") @@ -56,10 +63,10 @@ func prepareTestFiles(t *testing.T, dir string, filesNum, linesNum int) []string // prepareTestFile prepares a test query log file with the specified number of // lines. -func prepareTestFile(t *testing.T, dir string, linesCount int) string { +func prepareTestFile(t *testing.T, linesCount int) string { t.Helper() - return prepareTestFiles(t, dir, 1, linesCount)[0] + return prepareTestFiles(t, 1, linesCount)[0] } // newTestQLogFile creates new *QLogFile for tests and registers the required @@ -67,7 +74,7 @@ func prepareTestFile(t *testing.T, dir string, linesCount int) string { func newTestQLogFile(t *testing.T, linesNum int) (file *QLogFile) { t.Helper() - testFile := prepareTestFile(t, prepareTestDir(t), linesNum) + testFile := prepareTestFile(t, linesNum) // Create the new QLogFile instance. file, err := NewQLogFile(testFile) @@ -275,7 +282,7 @@ func TestQLogFile(t *testing.T) { } func NewTestQLogFileData(t *testing.T, data string) (file *QLogFile) { - f, err := ioutil.TempFile(prepareTestDir(t), "*.txt") + f, err := ioutil.TempFile(aghtest.PrepareTestDir(t), "*.txt") require.Nil(t, err) t.Cleanup(func() { assert.Nil(t, f.Close()) diff --git a/internal/querylog/qlogreader_test.go b/internal/querylog/qlogreader_test.go index 0333be36..07fd6fd3 100644 --- a/internal/querylog/qlogreader_test.go +++ b/internal/querylog/qlogreader_test.go @@ -15,7 +15,7 @@ import ( func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *QLogReader) { t.Helper() - testFiles := prepareTestFiles(t, prepareTestDir(t), filesNum, linesNum) + testFiles := prepareTestFiles(t, filesNum, linesNum) // Create the new QLogReader instance. reader, err := NewQLogReader(testFiles) diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index 7643b31a..e58f198c 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -9,6 +9,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { @@ -34,28 +35,30 @@ func TestStats(t *testing.T) { Filename: "./stats.db", LimitDays: 1, } + + s, err := createObject(conf) + require.Nil(t, err) t.Cleanup(func() { + s.clear() + s.Close() assert.Nil(t, os.Remove(conf.Filename)) }) - s, _ := createObject(conf) - - e := Entry{} - - e.Domain = "domain" - e.Client = "127.0.0.1" - e.Result = RFiltered - e.Time = 123456 - s.Update(e) - - e.Domain = "domain" - e.Client = "127.0.0.1" - e.Result = RNotFiltered - e.Time = 123456 - s.Update(e) + s.Update(Entry{ + Domain: "domain", + Client: "127.0.0.1", + Result: RFiltered, + Time: 123456, + }) + s.Update(Entry{ + Domain: "domain", + Client: "127.0.0.1", + Result: RNotFiltered, + Time: 123456, + }) d, ok := s.getData() - assert.True(t, ok) + require.True(t, ok) a := []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2} assert.True(t, UIntArrayEquals(d.DNSQueries, a)) @@ -70,12 +73,15 @@ func TestStats(t *testing.T) { assert.True(t, UIntArrayEquals(d.ReplacedParental, a)) m := d.TopQueried + require.NotEmpty(t, m) assert.EqualValues(t, 1, m[0]["domain"]) m = d.TopBlocked + require.NotEmpty(t, m) assert.EqualValues(t, 1, m[0]["domain"]) m = d.TopClients + require.NotEmpty(t, m) assert.EqualValues(t, 2, m[0]["127.0.0.1"]) assert.EqualValues(t, 2, d.NumDNSQueries) @@ -86,81 +92,69 @@ func TestStats(t *testing.T) { assert.EqualValues(t, 0.123456, d.AvgProcessingTime) topClients := s.GetTopClientsIP(2) + require.NotEmpty(t, topClients) assert.True(t, net.IP{127, 0, 0, 1}.Equal(topClients[0])) - - s.clear() - s.Close() } func TestLargeNumbers(t *testing.T) { - var hour int32 = 1 + var hour int32 = 0 newID := func() uint32 { - // use "atomic" to make Go race detector happy + // Use "atomic" to make go race detector happy. return uint32(atomic.LoadInt32(&hour)) } - // log.SetLevel(log.DEBUG) conf := Config{ Filename: "./stats.db", LimitDays: 1, UnitID: newID, } + s, err := createObject(conf) + require.Nil(t, err) t.Cleanup(func() { + s.Close() assert.Nil(t, os.Remove(conf.Filename)) }) - s, _ := createObject(conf) - e := Entry{} + // Number of distinct clients and domains every hour. + const n = 1000 - n := 1000 // number of distinct clients and domains every hour - for h := 0; h != 12; h++ { - if h != 0 { - atomic.AddInt32(&hour, 1) - } - for i := 0; i != n; i++ { - e.Domain = fmt.Sprintf("domain%d", i) - ip := net.IP{127, 0, 0, 1} - ip[2] = byte((i & 0xff00) >> 8) - ip[3] = byte(i & 0xff) - e.Client = ip.String() - e.Result = RNotFiltered - e.Time = 123456 - s.Update(e) + for h := 0; h < 12; h++ { + atomic.AddInt32(&hour, 1) + for i := 0; i < n; i++ { + s.Update(Entry{ + Domain: fmt.Sprintf("domain%d", i), + Client: net.IP{ + 127, + 0, + byte((i & 0xff00) >> 8), + byte(i & 0xff), + }.String(), + Result: RNotFiltered, + Time: 123456, + }) } } d, ok := s.getData() - assert.True(t, ok) - assert.EqualValues(t, int(hour)*n, d.NumDNSQueries) - - s.Close() + require.True(t, ok) + assert.EqualValues(t, hour*n, d.NumDNSQueries) } -// this code is a chunk copied from getData() that generates aggregate data per day -func aggregateDataPerDay(firstID uint32) int { - firstDayID := (firstID + 24 - 1) / 24 * 24 // align_ceil(24) - a := []uint64{} - var sum uint64 - id := firstDayID - nextDayID := firstDayID + 24 - for i := firstDayID - firstID; int(i) != 720; i++ { - sum++ - if id == nextDayID { - a = append(a, sum) - sum = 0 - nextDayID += 24 +func TestStatsCollector(t *testing.T) { + ng := func(_ *unitDB) uint64 { + return 0 + } + units := make([]*unitDB, 720) + + t.Run("hours", func(t *testing.T) { + statsData := statsCollector(units, 0, Hours, ng) + assert.Len(t, statsData, 720) + }) + + t.Run("days", func(t *testing.T) { + for i := 0; i != 25; i++ { + statsData := statsCollector(units, uint32(i), Days, ng) + require.Lenf(t, statsData, 30, "i=%d", i) } - id++ - } - if id <= nextDayID { - a = append(a, sum) - } - return len(a) -} - -func TestAggregateDataPerTimeUnit(t *testing.T) { - for i := 0; i != 25; i++ { - alen := aggregateDataPerDay(uint32(i)) - assert.Equalf(t, 30, alen, "i=%d", i) - } + }) } diff --git a/internal/stats/unit.go b/internal/stats/unit.go index 6f31cd5e..d955d04f 100644 --- a/internal/stats/unit.go +++ b/internal/stats/unit.go @@ -528,6 +528,57 @@ func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) { return units, firstID } +// numsGetter is a signature for statsCollector argument. +type numsGetter func(u *unitDB) (num uint64) + +// statsCollector collects statisctics for the given *unitDB slice by specified +// timeUnit using ng to retrieve data. +func statsCollector(units []*unitDB, firstID uint32, timeUnit TimeUnit, ng numsGetter) (nums []uint64) { + if timeUnit == Hours { + for _, u := range units { + nums = append(nums, ng(u)) + } + } else { + // Per time unit counters: 720 hours may span 31 days, so we + // skip data for the first day in this case. + // align_ceil(24) + firstDayID := (firstID + 24 - 1) / 24 * 24 + + var sum uint64 + id := firstDayID + nextDayID := firstDayID + 24 + for i := int(firstDayID - firstID); i != len(units); i++ { + sum += ng(units[i]) + if id == nextDayID { + nums = append(nums, sum) + sum = 0 + nextDayID += 24 + } + id++ + } + if id <= nextDayID { + nums = append(nums, sum) + } + } + return nums +} + +// pairsGetter is a signature for topsCollector argument. +type pairsGetter func(u *unitDB) (pairs []countPair) + +// topsCollector collects statistics about highest values fro the given *unitDB +// slice using pg to retrieve data. +func topsCollector(units []*unitDB, max int, pg pairsGetter) []map[string]uint64 { + m := map[string]uint64{} + for _, u := range units { + for _, it := range pg(u) { + m[it.Name] += it.Count + } + } + a2 := convertMapToSlice(m, max) + return convertTopSlice(a2) +} + /* Algorithm: . Prepare array of N units, where N is the value of "limit" configuration setting . Load data for the most recent units from file @@ -568,65 +619,25 @@ func (s *statsCtx) getData() (statsResponse, bool) { return statsResponse{}, false } - // per time unit counters: - // 720 hours may span 31 days, so we skip data for the first day in this case - firstDayID := (firstID + 24 - 1) / 24 * 24 // align_ceil(24) - - statsCollector := func(numsGetter func(u *unitDB) (num uint64)) (nums []uint64) { - if timeUnit == Hours { - for _, u := range units { - nums = append(nums, numsGetter(u)) - } - } else { - var sum uint64 - id := firstDayID - nextDayID := firstDayID + 24 - for i := int(firstDayID - firstID); i != len(units); i++ { - sum += numsGetter(units[i]) - if id == nextDayID { - nums = append(nums, sum) - sum = 0 - nextDayID += 24 - } - id++ - } - if id <= nextDayID { - nums = append(nums, sum) - } - } - return nums - } - - topsCollector := func(max int, pairsGetter func(u *unitDB) (pairs []countPair)) []map[string]uint64 { - m := map[string]uint64{} - for _, u := range units { - for _, it := range pairsGetter(u) { - m[it.Name] += it.Count - } - } - a2 := convertMapToSlice(m, max) - return convertTopSlice(a2) - } - - dnsQueries := statsCollector(func(u *unitDB) (num uint64) { return u.NTotal }) + dnsQueries := statsCollector(units, firstID, timeUnit, func(u *unitDB) (num uint64) { return u.NTotal }) if timeUnit != Hours && len(dnsQueries) != int(limit/24) { log.Fatalf("len(dnsQueries) != limit: %d %d", len(dnsQueries), limit) } data := statsResponse{ DNSQueries: dnsQueries, - BlockedFiltering: statsCollector(func(u *unitDB) (num uint64) { return u.NResult[RFiltered] }), - ReplacedSafebrowsing: statsCollector(func(u *unitDB) (num uint64) { return u.NResult[RSafeBrowsing] }), - ReplacedParental: statsCollector(func(u *unitDB) (num uint64) { return u.NResult[RParental] }), - TopQueried: topsCollector(maxDomains, func(u *unitDB) (pairs []countPair) { return u.Domains }), - TopBlocked: topsCollector(maxDomains, func(u *unitDB) (pairs []countPair) { return u.BlockedDomains }), - TopClients: topsCollector(maxClients, func(u *unitDB) (pairs []countPair) { return u.Clients }), + BlockedFiltering: statsCollector(units, firstID, timeUnit, func(u *unitDB) (num uint64) { return u.NResult[RFiltered] }), + ReplacedSafebrowsing: statsCollector(units, firstID, timeUnit, func(u *unitDB) (num uint64) { return u.NResult[RSafeBrowsing] }), + ReplacedParental: statsCollector(units, firstID, timeUnit, func(u *unitDB) (num uint64) { return u.NResult[RParental] }), + TopQueried: topsCollector(units, maxDomains, func(u *unitDB) (pairs []countPair) { return u.Domains }), + TopBlocked: topsCollector(units, maxDomains, func(u *unitDB) (pairs []countPair) { return u.BlockedDomains }), + TopClients: topsCollector(units, maxClients, func(u *unitDB) (pairs []countPair) { return u.Clients }), } - // total counters: - - sum := unitDB{} - sum.NResult = make([]uint64, rLast) + // Total counters: + sum := unitDB{ + NResult: make([]uint64, rLast), + } timeN := 0 for _, u := range units { sum.NTotal += u.NTotal diff --git a/internal/sysutil/net_linux_test.go b/internal/sysutil/net_linux_test.go index a9851cb2..3fbfc547 100644 --- a/internal/sysutil/net_linux_test.go +++ b/internal/sysutil/net_linux_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const nl = "\n" @@ -48,7 +49,7 @@ func TestDHCPCDStaticConfig(t *testing.T) { t.Run(tc.name, func(t *testing.T) { r := bytes.NewReader(tc.data) has, err := dhcpcdStaticConfig(r, "wlan0") - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, tc.want, has) }) } @@ -85,26 +86,36 @@ func TestIfacesStaticConfig(t *testing.T) { t.Run(tc.name, func(t *testing.T) { r := bytes.NewReader(tc.data) has, err := ifacesStaticConfig(r, "enp0s3") - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, tc.want, has) }) } } func TestSetStaticIPdhcpcdConf(t *testing.T) { - dhcpcdConf := nl + `interface wlan0` + nl + - `static ip_address=192.168.0.2/24` + nl + - `static routers=192.168.0.1` + nl + - `static domain_name_servers=192.168.0.2` + nl + nl + testCases := []struct { + name string + dhcpcdConf string + routers net.IP + }{{ + name: "with_gateway", + dhcpcdConf: nl + `interface wlan0` + nl + + `static ip_address=192.168.0.2/24` + nl + + `static routers=192.168.0.1` + nl + + `static domain_name_servers=192.168.0.2` + nl + nl, + routers: net.IP{192, 168, 0, 1}, + }, { + name: "without_gateway", + dhcpcdConf: nl + `interface wlan0` + nl + + `static ip_address=192.168.0.2/24` + nl + + `static domain_name_servers=192.168.0.2` + nl + nl, + routers: nil, + }} - s := updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", net.IP{192, 168, 0, 1}, net.IP{192, 168, 0, 2}) - assert.Equal(t, dhcpcdConf, s) - - // without gateway - dhcpcdConf = nl + `interface wlan0` + nl + - `static ip_address=192.168.0.2/24` + nl + - `static domain_name_servers=192.168.0.2` + nl + nl - - s = updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", nil, net.IP{192, 168, 0, 2}) - assert.Equal(t, dhcpcdConf, s) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", tc.routers, net.IP{192, 168, 0, 2}) + assert.Equal(t, tc.dhcpcdConf, s) + }) + } } diff --git a/internal/util/autohosts_test.go b/internal/util/autohosts_test.go index c5f26934..82e16da9 100644 --- a/internal/util/autohosts_test.go +++ b/internal/util/autohosts_test.go @@ -11,114 +11,162 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { aghtest.DiscardLogOutput(m) } -func prepareTestDir() string { - const dir = "./agh-test" - _ = os.RemoveAll(dir) - _ = os.MkdirAll(dir, 0o755) - return dir +func prepareTestFile(t *testing.T) (f *os.File) { + t.Helper() + + dir := aghtest.PrepareTestDir(t) + + f, err := ioutil.TempFile(dir, "") + require.Nil(t, err) + require.NotNil(t, f) + t.Cleanup(func() { + assert.Nil(t, f.Close()) + }) + + return f +} + +func assertWriting(t *testing.T, f *os.File, strs ...string) { + t.Helper() + + for _, str := range strs { + n, err := f.WriteString(str) + require.Nil(t, err) + assert.Equal(t, n, len(str)) + } } func TestAutoHostsResolution(t *testing.T) { - ah := AutoHosts{} + ah := &AutoHosts{} - dir := prepareTestDir() - defer func() { _ = os.RemoveAll(dir) }() - - f, _ := ioutil.TempFile(dir, "") - defer func() { _ = os.Remove(f.Name()) }() - defer f.Close() - - _, _ = f.WriteString(" 127.0.0.1 host localhost # comment \n") - _, _ = f.WriteString(" ::1 localhost#comment \n") + f := prepareTestFile(t) + assertWriting(t, f, + " 127.0.0.1 host localhost # comment \n", + " ::1 localhost#comment \n", + ) ah.Init(f.Name()) - // Existing host - ips := ah.Process("localhost", dns.TypeA) - assert.NotNil(t, ips) - assert.Len(t, ips, 1) - assert.Equal(t, net.ParseIP("127.0.0.1"), ips[0]) + t.Run("existing_host", func(t *testing.T) { + ips := ah.Process("localhost", dns.TypeA) + require.Len(t, ips, 1) + assert.Equal(t, net.IPv4(127, 0, 0, 1), ips[0]) + }) - // Unknown host - ips = ah.Process("newhost", dns.TypeA) - assert.Nil(t, ips) + t.Run("unknown_host", func(t *testing.T) { + ips := ah.Process("newhost", dns.TypeA) + assert.Nil(t, ips) - // Unknown host (comment) - ips = ah.Process("comment", dns.TypeA) - assert.Nil(t, ips) + // Comment. + ips = ah.Process("comment", dns.TypeA) + assert.Nil(t, ips) + }) - // Test hosts file - table := ah.List() - names, ok := table["127.0.0.1"] - assert.True(t, ok) - assert.Equal(t, []string{"host", "localhost"}, names) + t.Run("hosts_file", func(t *testing.T) { + names, ok := ah.List()["127.0.0.1"] + require.True(t, ok) + assert.Equal(t, []string{"host", "localhost"}, names) + }) - // Test PTR - a, _ := dns.ReverseAddr("127.0.0.1") - a = strings.TrimSuffix(a, ".") - hosts := ah.ProcessReverse(a, dns.TypePTR) - if assert.Len(t, hosts, 2) { - assert.Equal(t, hosts[0], "host") - } + t.Run("ptr", func(t *testing.T) { + testCases := []struct { + wantIP string + wantLen int + wantHost string + }{ + {wantIP: "127.0.0.1", wantLen: 2, wantHost: "host"}, + {wantIP: "::1", wantLen: 1, wantHost: "localhost"}, + } - a, _ = dns.ReverseAddr("::1") - a = strings.TrimSuffix(a, ".") - hosts = ah.ProcessReverse(a, dns.TypePTR) - if assert.Len(t, hosts, 1) { - assert.Equal(t, hosts[0], "localhost") - } + for _, tc := range testCases { + a, err := dns.ReverseAddr(tc.wantIP) + require.Nil(t, err) + + a = strings.TrimSuffix(a, ".") + hosts := ah.ProcessReverse(a, dns.TypePTR) + require.Len(t, hosts, tc.wantLen) + assert.Equal(t, tc.wantHost, hosts[0]) + } + }) } func TestAutoHostsFSNotify(t *testing.T) { - ah := AutoHosts{} + ah := &AutoHosts{} - dir := prepareTestDir() - defer func() { _ = os.RemoveAll(dir) }() + f := prepareTestFile(t) - f, _ := ioutil.TempFile(dir, "") - defer func() { _ = os.Remove(f.Name()) }() - defer f.Close() - - // Init - _, _ = f.WriteString(" 127.0.0.1 host localhost \n") + assertWriting(t, f, " 127.0.0.1 host localhost \n") ah.Init(f.Name()) - // Unknown host - ips := ah.Process("newhost", dns.TypeA) - assert.Nil(t, ips) + t.Run("unknown_host", func(t *testing.T) { + ips := ah.Process("newhost", dns.TypeA) + assert.Nil(t, ips) + }) - // Stat monitoring for changes + // Start monitoring for changes. ah.Start() - defer ah.Close() + t.Cleanup(ah.Close) - // Update file - _, _ = f.WriteString("127.0.0.2 newhost\n") - _ = f.Sync() + assertWriting(t, f, "127.0.0.2 newhost\n") + require.Nil(t, f.Sync()) - // wait until fsnotify has triggerred and processed the file-modification event + // Wait until fsnotify has triggerred and processed the + // file-modification event. time.Sleep(50 * time.Millisecond) - // Check if we are notified about changes - ips = ah.Process("newhost", dns.TypeA) - assert.NotNil(t, ips) - assert.Len(t, ips, 1) - assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0])) + t.Run("notified", func(t *testing.T) { + ips := ah.Process("newhost", dns.TypeA) + assert.NotNil(t, ips) + require.Len(t, ips, 1) + assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0])) + }) } -func TestIP(t *testing.T) { - assert.True(t, net.IP{127, 0, 0, 1}.Equal(DNSUnreverseAddr("1.0.0.127.in-addr.arpa"))) - assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String()) - assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String()) +func TestDNSReverseAddr(t *testing.T) { + testCases := []struct { + name string + have string + want net.IP + }{{ + name: "good_ipv4", + have: "1.0.0.127.in-addr.arpa", + want: net.IP{127, 0, 0, 1}, + }, { + name: "good_ipv6", + have: "4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", + want: net.ParseIP("::abcd:1234"), + }, { + name: "good_ipv6_case", + have: "4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", + want: net.ParseIP("::abcd:1234"), + }, { + name: "bad_ipv4_dot", + have: "1.0.0.127.in-addr.arpa.", + }, { + name: "wrong_ipv4", + have: ".0.0.127.in-addr.arpa", + }, { + name: "wrong_ipv6", + have: ".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", + }, { + name: "bad_ipv6_dot", + have: "4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa", + }, { + name: "bad_ipv6_space", + have: "4.3.2.1.d.c.b. .0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", + }} - assert.Nil(t, DNSUnreverseAddr("1.0.0.127.in-addr.arpa.")) - assert.Nil(t, DNSUnreverseAddr(".0.0.127.in-addr.arpa")) - assert.Nil(t, DNSUnreverseAddr(".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa")) - assert.Nil(t, DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa")) - assert.Nil(t, DNSUnreverseAddr("4.3.2.1.d.c.b. .0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa")) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ip := DNSUnreverseAddr(tc.have) + assert.True(t, tc.want.Equal(ip)) + }) + } } diff --git a/internal/util/helpers_test.go b/internal/util/helpers_test.go index 68ebbabd..a09d97e6 100644 --- a/internal/util/helpers_test.go +++ b/internal/util/helpers_test.go @@ -4,12 +4,14 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSplitNext(t *testing.T) { s := " a,b , c " + assert.Equal(t, "a", SplitNext(&s, ',')) assert.Equal(t, "b", SplitNext(&s, ',')) assert.Equal(t, "c", SplitNext(&s, ',')) - assert.Empty(t, s) + require.Empty(t, s) } diff --git a/internal/util/network_test.go b/internal/util/network_test.go index 9b2a9554..bcc1da4c 100644 --- a/internal/util/network_test.go +++ b/internal/util/network_test.go @@ -2,22 +2,15 @@ package util import ( "testing" + + "github.com/stretchr/testify/require" ) func TestGetValidNetInterfacesForWeb(t *testing.T) { ifaces, err := GetValidNetInterfacesForWeb() - if err != nil { - t.Fatalf("Cannot get net interfaces: %s", err) - } - if len(ifaces) == 0 { - t.Fatalf("No net interfaces found") - } - + require.Nilf(t, err, "Cannot get net interfaces: %s", err) + require.NotEmpty(t, ifaces, "No net interfaces found") for _, iface := range ifaces { - if len(iface.Addresses) == 0 { - t.Fatalf("No addresses found for %s", iface.Name) - } - - t.Logf("%v", iface) + require.NotEmptyf(t, iface.Addresses, "No addresses found for %s", iface.Name) } }