AdGuardHome/dnsforward/cache_test.go

145 lines
3.3 KiB
Go

package dnsforward
import (
"strings"
"testing"
"github.com/go-test/deep"
"github.com/miekg/dns"
)
func RR(rr string) dns.RR {
r, err := dns.NewRR(rr)
if err != nil {
panic(err)
}
return r
}
// deepEqual is same as deep.Equal, except:
// * ignores Id when comparing
// * question names are not case sensetive
func deepEqualMsg(left *dns.Msg, right *dns.Msg) []string {
temp := *left
temp.Id = right.Id
for i := range left.Question {
left.Question[i].Name = strings.ToLower(left.Question[i].Name)
}
for i := range right.Question {
right.Question[i].Name = strings.ToLower(right.Question[i].Name)
}
return deep.Equal(&temp, right)
}
func TestCacheSanity(t *testing.T) {
cache := cache{}
request := dns.Msg{}
request.SetQuestion("google.com.", dns.TypeA)
_, ok := cache.Get(&request)
if ok {
t.Fatal("empty cache replied with positive response")
}
}
type tests struct {
cache []testEntry
cases []testCase
}
type testEntry struct {
q string
t uint16
a []dns.RR
}
type testCase struct {
q string
t uint16
a []dns.RR
ok bool
}
func TestCache(t *testing.T) {
tests := tests{
cache: []testEntry{
{q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}},
},
cases: []testCase{
{q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
{q: "google.com.", t: dns.TypeMX, ok: false},
},
}
runTests(t, tests)
}
func TestCacheMixedCase(t *testing.T) {
tests := tests{
cache: []testEntry{
{q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}},
},
cases: []testCase{
{q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
{q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
{q: "GOOGLE.COM.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true},
{q: "gOOgle.com.", t: dns.TypeMX, ok: false},
{q: "google.com.", t: dns.TypeMX, ok: false},
{q: "GOOGLE.COM.", t: dns.TypeMX, ok: false},
},
}
runTests(t, tests)
}
func TestZeroTTL(t *testing.T) {
tests := tests{
cache: []testEntry{
{q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 0 IN A 8.8.8.8")}},
},
cases: []testCase{
{q: "google.com.", t: dns.TypeA, ok: false},
{q: "google.com.", t: dns.TypeA, ok: false},
{q: "google.com.", t: dns.TypeA, ok: false},
{q: "google.com.", t: dns.TypeMX, ok: false},
{q: "google.com.", t: dns.TypeMX, ok: false},
{q: "google.com.", t: dns.TypeMX, ok: false},
},
}
runTests(t, tests)
}
func runTests(t *testing.T, tests tests) {
t.Helper()
cache := cache{}
for _, tc := range tests.cache {
reply := dns.Msg{}
reply.SetQuestion(tc.q, tc.t)
reply.Response = true
reply.Answer = tc.a
cache.Set(&reply)
}
for _, tc := range tests.cases {
request := dns.Msg{}
request.SetQuestion(tc.q, tc.t)
val, ok := cache.Get(&request)
if diff := deep.Equal(ok, tc.ok); diff != nil {
t.Error(diff)
}
if tc.a != nil {
if ok == false {
continue
}
reply := dns.Msg{}
reply.SetQuestion(tc.q, tc.t)
reply.Response = true
reply.Answer = tc.a
cache.Set(&reply)
if diff := deepEqualMsg(val, &reply); diff != nil {
t.Error(diff)
} else {
if diff := deep.Equal(val, reply); diff == nil {
t.Error("different message ID were not caught")
}
}
}
}
}