// Package aghtest provides mock DNS upstreams for use in tests.
package aghtest

import (
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"

	"github.com/miekg/dns"
)

// TestUpstream is a mock of real upstream.  All maps are looked up with the
// question name exactly as it appears in the request, so keys must be written
// in the same form the requests use.
type TestUpstream struct {
	// Addr is the address for Address method.
	Addr string

	// CName is a map of hostname to canonical name.
	CName map[string]string

	// IPv4 is a map of hostname to IPv4.
	IPv4 map[string][]net.IP

	// IPv6 is a map of hostname to IPv6.
	IPv6 map[string][]net.IP

	// Reverse is a map of address to domain name.
	Reverse map[string][]string
}

// Exchange implements upstream.Upstream interface for *TestUpstream.
//
// Only the first question of m is considered.  If CName has an entry for the
// question name, a CNAME record is always put first into the answer section.
// After that, depending on the question type, A, AAAA, or PTR records are
// appended from IPv4, IPv6, or Reverse respectively.  Other question types
// produce no additional records.
//
// When the answer section ends up empty, the response code is NXDOMAIN unless
// the name is present in the address map for the requested type (with an empty
// list of addresses), in which case the code is NOERROR.
//
// It returns a nil response and a non-nil error if m has no questions.
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
	// Validate the request before allocating the reply.
	if len(m.Question) == 0 {
		return nil, fmt.Errorf("question should not be empty")
	}

	resp = &dns.Msg{}
	resp.SetReply(m)

	q := m.Question[0]
	name := q.Name

	if cname, ok := u.CName[name]; ok {
		resp.Answer = append(resp.Answer, &dns.CNAME{
			Hdr: dns.RR_Header{
				Name:   name,
				Rrtype: dns.TypeCNAME,
			},
			Target: cname,
		})
	}

	// Every record produced by the switch below has the same type as the
	// question, so a single header is shared between them.  This also makes
	// sure that PTR records carry dns.TypePTR rather than a zero type.
	hdr := dns.RR_Header{
		Name:   name,
		Rrtype: q.Qtype,
	}

	// hasRec reports whether the name is known for the requested address
	// family, even if the list of addresses is empty.
	var hasRec bool

	switch q.Qtype {
	case dns.TypeA:
		var ips []net.IP
		ips, hasRec = u.IPv4[name]
		for _, ip := range ips {
			resp.Answer = append(resp.Answer, &dns.A{Hdr: hdr, A: ip})
		}
	case dns.TypeAAAA:
		var ips []net.IP
		ips, hasRec = u.IPv6[name]
		for _, ip := range ips {
			// Use the AAAA record type so that the record's Go type agrees
			// with its header.
			resp.Answer = append(resp.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip})
		}
	case dns.TypePTR:
		// Ranging over a missing key is a no-op, so no explicit presence
		// check is needed.
		for _, ptr := range u.Reverse[name] {
			resp.Answer = append(resp.Answer, &dns.PTR{Hdr: hdr, Ptr: ptr})
		}
	}

	if len(resp.Answer) > 0 {
		return resp, nil
	}

	// Set no error RCode if there are some records for given Qname but we
	// didn't apply them.  Set NXDomain RCode otherwise.
	rcode := dns.RcodeNameError
	if hasRec {
		rcode = dns.RcodeSuccess
	}
	resp.SetRcode(m, rcode)

	return resp, nil
}

// Address implements upstream.Upstream interface for *TestUpstream.
func (u *TestUpstream) Address() string {
	return u.Addr
}

// TestBlockUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestBlockUpstream struct {
	// Hostname is the value hashed to build the TXT payload returned by
	// Exchange.
	Hostname string

	// Block selects which payload Exchange returns: the full hex-encoded
	// SHA-256 of Hostname when true, or only its first two characters followed
	// by a fixed filler when false.
	Block bool

	// requestsCount is the number of Exchange calls made so far.  It is
	// protected by lock.
	requestsCount int

	// lock protects requestsCount.
	lock sync.RWMutex
}

// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
// pair.  The message contains a single TXT record named after the first
// question of r.
//
// Every call is counted, including the ones that fail.  It returns a nil
// message and a non-nil error if r has no questions.
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
	u.lock.Lock()
	defer u.lock.Unlock()

	u.requestsCount++

	// Report a request without questions as an error instead of panicking on
	// the index expression below.
	if len(r.Question) == 0 {
		return nil, errors.New("question should not be empty")
	}

	hash := sha256.Sum256([]byte(u.Hostname))
	hashToReturn := hex.EncodeToString(hash[:])
	if !u.Block {
		// Keep only the first two characters of the real hash and pad the
		// rest with a fixed pattern.
		hashToReturn = hashToReturn[:2] + strings.Repeat("ab", 28)
	}

	m := &dns.Msg{}
	m.Answer = []dns.RR{
		&dns.TXT{
			Hdr: dns.RR_Header{
				Name: r.Question[0].Name,
			},
			Txt: []string{
				hashToReturn,
			},
		},
	}

	return m, nil
}

// Address always returns an empty string.
func (u *TestBlockUpstream) Address() string {
	return ""
}

// RequestsCount returns the number of handled requests.  It's safe for
// concurrent use.
func (u *TestBlockUpstream) RequestsCount() int {
	// Reading the counter only needs the shared lock.
	u.lock.RLock()
	defer u.lock.RUnlock()

	return u.requestsCount
}

// TestErrUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestErrUpstream struct{}

// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
	// An agherr.Error is deliberately not used here to avoid an import cycle:
	// this package provides testing utilities that agherr (and any other
	// testable package) should be able to use.
	return nil, errors.New("bad")
}

// Address always returns an empty string.
func (u *TestErrUpstream) Address() string {
	return ""
}