package querylog

import (
	"fmt"
	"net"
	"testing"

	"github.com/AdguardTeam/AdGuardHome/internal/filtering"
	"github.com/AdguardTeam/golibs/stringutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/AdguardTeam/golibs/timeutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestMain hands the test run over to testutil.DiscardLogOutput.
func TestMain(m *testing.M) {
	testutil.DiscardLogOutput(m)
}

// TestQueryLog tests adding and loading (with filtering) entries from disk and
// memory.
//
// Five entries are added in total: one flushed before a rotation, one flushed
// after it, and three that are never flushed.  Every expectation lists the
// most recently added entry first.
func TestQueryLog(t *testing.T) {
	l, err := newQueryLog(Config{
		Enabled:     true,
		FileEnabled: true,
		RotationIvl: timeutil.Day,
		MemSize:     100,
		BaseDir:     t.TempDir(),
	})
	require.NoError(t, err)

	// Add disk entries.
	addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
	// Write to disk (first file).
	require.NoError(t, l.flushLogBuffer())
	// Start writing to the second file.
	require.NoError(t, l.rotate())
	// Add disk entries.
	addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
	// Write to disk.
	require.NoError(t, l.flushLogBuffer())
	// Add memory entries.
	addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
	addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
	// An empty host makes addEntry ask for the name ".".
	addEntry(l, "", net.IPv4(1, 1, 1, 5), net.IPv4(2, 2, 2, 5))

	// tcAssertion describes the entry expected at index num of the search
	// result.
	type tcAssertion struct {
		host   string
		answer net.IP
		client net.IP
		num    int
	}

	testCases := []struct {
		name string
		sCr  []searchCriterion
		want []tcAssertion
	}{{
		name: "all",
		sCr:  []searchCriterion{},
		want: []tcAssertion{
			{num: 0, host: ".", answer: net.IPv4(1, 1, 1, 5), client: net.IPv4(2, 2, 2, 5)},
			{num: 1, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
			{num: 2, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
			{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
			{num: 4, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
		},
	}, {
		name: "by_domain_strict",
		sCr: []searchCriterion{{
			criterionType: ctTerm,
			strict:        true,
			value:         "TEST.example.org",
		}},
		want: []tcAssertion{{
			num:    0,
			host:   "test.example.org",
			answer: net.IPv4(1, 1, 1, 3),
			client: net.IPv4(2, 2, 2, 3),
		}},
	}, {
		name: "by_domain_non-strict",
		sCr: []searchCriterion{{
			criterionType: ctTerm,
			strict:        false,
			value:         "example.ORG",
		}},
		want: []tcAssertion{
			{num: 0, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
			{num: 1, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
			{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
		},
	}, {
		name: "by_client_ip_strict",
		sCr: []searchCriterion{{
			criterionType: ctTerm,
			strict:        true,
			value:         "2.2.2.2",
		}},
		want: []tcAssertion{{
			num:    0,
			host:   "example.org",
			answer: net.IPv4(1, 1, 1, 2),
			client: net.IPv4(2, 2, 2, 2),
		}},
	}, {
		name: "by_client_ip_non-strict",
		sCr: []searchCriterion{{
			criterionType: ctTerm,
			strict:        false,
			value:         "2.2.2",
		}},
		want: []tcAssertion{
			{num: 0, host: ".", answer: net.IPv4(1, 1, 1, 5), client: net.IPv4(2, 2, 2, 5)},
			{num: 1, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
			{num: 2, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
			{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
			{num: 4, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
		},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			params := newSearchParams()
			params.searchCriteria = tc.sCr

			entries, _ := l.search(params)
			// The length check must pass before entries is indexed by
			// want.num below.
			require.Len(t, entries, len(tc.want))

			for _, want := range tc.want {
				assertLogEntry(t, entries[want.num], want.host, want.answer, want.client)
			}
		})
	}
}

// TestQueryLogOffsetLimit tests paging through search results with the offset
// and limit search parameters.  The entries added last (firstPageDomain) are
// expected on the first page, and the entries added and flushed earlier
// (secondPageDomain) on the following ones.
func TestQueryLogOffsetLimit(t *testing.T) {
	l, err := newQueryLog(Config{
		Enabled:     true,
		RotationIvl: timeutil.Day,
		MemSize:     100,
		BaseDir:     t.TempDir(),
	})
	require.NoError(t, err)

	const (
		entNum           = 10
		firstPageDomain  = "first.example.org"
		secondPageDomain = "second.example.org"
	)

	// Add entries to the log.
	for i := 0; i < entNum; i++ {
		addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
	}
	// Write them to the first file.
	require.NoError(t, l.flushLogBuffer())
	// Add more to the in-memory part of log.
	for i := 0; i < entNum; i++ {
		addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
	}

	// params is shared by all subtests; each one overwrites offset and limit
	// before searching.
	params := newSearchParams()

	testCases := []struct {
		name    string
		want    string
		wantLen int
		offset  int
		limit   int
	}{{
		name:    "page_1",
		want:    firstPageDomain,
		wantLen: 10,
		offset:  0,
		limit:   10,
	}, {
		name:    "page_2",
		want:    secondPageDomain,
		wantLen: 10,
		offset:  10,
		limit:   10,
	}, {
		name:    "page_2.5",
		want:    secondPageDomain,
		wantLen: 5,
		offset:  15,
		limit:   10,
	}, {
		name:    "page_3",
		want:    "",
		wantLen: 0,
		offset:  20,
		limit:   10,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			params.offset = tc.offset
			params.limit = tc.limit
			entries, _ := l.search(params)

			require.Len(t, entries, tc.wantLen)

			if tc.wantLen > 0 {
				// Expected value goes first so that failure output labels
				// the values correctly.
				assert.Equal(t, tc.want, entries[0].QHost)
				assert.Equal(t, tc.want, entries[tc.wantLen-1].QHost)
			}
		})
	}
}

// TestQueryLogMaxFileScanEntries tests searching flushed entries with the
// maxFileScanEntries search parameter set to 5 and then to 0.  Each run
// expects entNum-maxFileScanEntries entries, so the run with 0 expects all of
// them.
func TestQueryLogMaxFileScanEntries(t *testing.T) {
	l, err := newQueryLog(Config{
		Enabled:     true,
		FileEnabled: true,
		RotationIvl: timeutil.Day,
		MemSize:     100,
		BaseDir:     t.TempDir(),
	})
	require.NoError(t, err)

	const entNum = 10
	// Add entries to the log.
	for i := 0; i < entNum; i++ {
		addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
	}
	// Write them to disk.
	require.NoError(t, l.flushLogBuffer())

	params := newSearchParams()

	for _, maxFileScanEntries := range []int{5, 0} {
		t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) {
			params.maxFileScanEntries = maxFileScanEntries
			entries, _ := l.search(params)
			assert.Len(t, entries, entNum-maxFileScanEntries)
		})
	}
}

// TestQueryLogFileDisabled tests a query log with FileEnabled set to false and
// MemSize set to 2.  After three entries are added, a search is expected to
// return only the two most recent ones, newest first.
func TestQueryLogFileDisabled(t *testing.T) {
	l, err := newQueryLog(Config{
		Enabled:     true,
		FileEnabled: false,
		RotationIvl: timeutil.Day,
		MemSize:     2,
		BaseDir:     t.TempDir(),
	})
	require.NoError(t, err)

	addEntry(l, "example1.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
	addEntry(l, "example2.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
	// The oldest entry is going to be removed from memory buffer.
	addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))

	params := newSearchParams()
	ll, _ := l.search(params)
	require.Len(t, ll, 2)
	assert.Equal(t, "example3.org", ll[0].QHost)
	assert.Equal(t, "example2.org", ll[1].QHost)
}

// TestQueryLogShouldLog tests the result of ShouldLog for a regular host, for
// hosts from the Ignored set, and for a client that the FindClient stub
// reports with IgnoreQueryLog set.
func TestQueryLogShouldLog(t *testing.T) {
	const (
		ignored1 = "ignor.ed"
		ignored2 = "ignored.to"
	)
	set := stringutil.NewSet(ignored1, ignored2)

	// findClient is a stub that sets IgnoreQueryLog only when the first ID is
	// "no_log".  It reads ids[0] unconditionally, so ids must not be empty.
	findClient := func(ids []string) (c *Client, err error) {
		ignore := ids[0] == "no_log"

		return &Client{IgnoreQueryLog: ignore}, nil
	}

	l, err := newQueryLog(Config{
		Ignored:     set,
		Enabled:     true,
		RotationIvl: timeutil.Day,
		MemSize:     100,
		BaseDir:     t.TempDir(),
		FindClient:  findClient,
	})
	require.NoError(t, err)

	testCases := []struct {
		name    string
		host    string
		ids     []string
		wantLog bool
	}{{
		name:    "log",
		host:    "example.com",
		ids:     []string{"whatever"},
		wantLog: true,
	}, {
		name:    "no_log_ignored_1",
		host:    ignored1,
		ids:     []string{"whatever"},
		wantLog: false,
	}, {
		name:    "no_log_ignored_2",
		host:    ignored2,
		ids:     []string{"whatever"},
		wantLog: false,
	}, {
		name:    "no_log_client_ignore",
		host:    "example.com",
		ids:     []string{"no_log"},
		wantLog: false,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			res := l.ShouldLog(tc.host, dns.TypeA, dns.ClassINET, tc.ids)

			assert.Equal(t, tc.wantLog, res)
		})
	}
}

// addEntry adds a single entry to l.  The entry consists of an A/IN question
// for host with a trailing dot appended, an answer with one A record that
// holds answerStr, a fixed filtering result, and client as the client IP.
// Despite its name, answerStr is an IP address and not a string.
func addEntry(l *queryLog, host string, answerStr, client net.IP) {
	q := dns.Msg{
		Question: []dns.Question{{
			Name:   host + ".",
			Qtype:  dns.TypeA,
			Qclass: dns.ClassINET,
		}},
	}

	a := dns.Msg{
		Question: q.Question,
		Answer: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Name:   q.Question[0].Name,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
			},
			A: answerStr,
		}},
	}

	res := filtering.Result{
		ServiceName: "SomeService",
		Rules: []*filtering.ResultRule{{
			FilterListID: 1,
			Text:         "SomeRule",
		}},
		Reason:     filtering.Rewritten,
		IsFiltered: true,
	}

	// The same message is used as both the answer and the original answer.
	params := &AddParams{
		Question:   &q,
		Answer:     &a,
		OrigAnswer: &a,
		Result:     &res,
		Upstream:   "upstream",
		ClientIP:   client,
	}

	l.Add(params)
}

// assertLogEntry checks that entry is not nil, that its host and client IP
// equal host and client, that its question type and class are "A" and "IN",
// and that its packed answer unpacks into exactly one A record equal to
// answer.
func assertLogEntry(t *testing.T, entry *logEntry, host string, answer, client net.IP) {
	t.Helper()

	require.NotNil(t, entry)

	assert.Equal(t, host, entry.QHost)
	assert.Equal(t, client, entry.IP)
	assert.Equal(t, "A", entry.QType)
	assert.Equal(t, "IN", entry.QClass)

	msg := &dns.Msg{}
	require.NoError(t, msg.Unpack(entry.Answer))
	require.Len(t, msg.Answer, 1)

	a := testutil.RequireTypeAssert[*dns.A](t, msg.Answer[0])
	// Compare in the 16-byte form.
	assert.Equal(t, answer, a.A.To16())
}