package dnsforward import ( "strings" "testing" "github.com/go-test/deep" "github.com/miekg/dns" ) func RR(rr string) dns.RR { r, err := dns.NewRR(rr) if err != nil { panic(err) } return r } // deepEqual is same as deep.Equal, except: // * ignores Id when comparing // * question names are not case sensetive func deepEqualMsg(left *dns.Msg, right *dns.Msg) []string { temp := *left temp.Id = right.Id for i := range left.Question { left.Question[i].Name = strings.ToLower(left.Question[i].Name) } for i := range right.Question { right.Question[i].Name = strings.ToLower(right.Question[i].Name) } return deep.Equal(&temp, right) } func TestCacheSanity(t *testing.T) { cache := cache{} request := dns.Msg{} request.SetQuestion("google.com.", dns.TypeA) _, ok := cache.Get(&request) if ok { t.Fatal("empty cache replied with positive response") } } type tests struct { cache []testEntry cases []testCase } type testEntry struct { q string t uint16 a []dns.RR } type testCase struct { q string t uint16 a []dns.RR ok bool } func TestCache(t *testing.T) { tests := tests{ cache: []testEntry{ {q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}}, }, cases: []testCase{ {q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, {q: "google.com.", t: dns.TypeMX, ok: false}, }, } runTests(t, tests) } func TestCacheMixedCase(t *testing.T) { tests := tests{ cache: []testEntry{ {q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}}, }, cases: []testCase{ {q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, {q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, {q: "GOOGLE.COM.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, {q: "gOOgle.com.", t: dns.TypeMX, ok: false}, {q: "google.com.", t: dns.TypeMX, ok: false}, {q: "GOOGLE.COM.", t: dns.TypeMX, ok: false}, }, } runTests(t, tests) } func TestZeroTTL(t *testing.T) { tests := tests{ cache: []testEntry{ {q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 0 IN A 8.8.8.8")}}, }, cases: []testCase{ {q: "google.com.", t: dns.TypeA, ok: false}, {q: "google.com.", t: dns.TypeA, ok: false}, {q: "google.com.", t: dns.TypeA, ok: false}, {q: "google.com.", t: dns.TypeMX, ok: false}, {q: "google.com.", t: dns.TypeMX, ok: false}, {q: "google.com.", t: dns.TypeMX, ok: false}, }, } runTests(t, tests) } func runTests(t *testing.T, tests tests) { t.Helper() cache := cache{} for _, tc := range tests.cache { reply := dns.Msg{} reply.SetQuestion(tc.q, tc.t) reply.Response = true reply.Answer = tc.a cache.Set(&reply) } for _, tc := range tests.cases { request := dns.Msg{} request.SetQuestion(tc.q, tc.t) val, ok := cache.Get(&request) if diff := deep.Equal(ok, tc.ok); diff != nil { t.Error(diff) } if tc.a != nil { if ok == false { continue } reply := dns.Msg{} reply.SetQuestion(tc.q, tc.t) reply.Response = true reply.Answer = tc.a cache.Set(&reply) if diff := deepEqualMsg(val, &reply); diff != nil { t.Error(diff) } else { if diff := deep.Equal(val, reply); diff == nil { t.Error("different message ID were not caught") } } } } }