AdGuardHome/internal/aghtest/upstream.go

150 lines
3.3 KiB
Go

package aghtest
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"strings"
"sync"
"github.com/miekg/dns"
)
// Upstream is a mock implementation of upstream.Upstream that answers queries
// from its in-memory maps.  Only the first question of a request is used, and
// every map is keyed by the question name exactly as it appears in the
// request.
type Upstream struct {
	// CName is a map of hostname to canonical name.  Each target is returned
	// as a CNAME record for any question about the hostname, regardless of
	// the question type.
	CName map[string][]string
	// IPv4 is a map of hostname to IPv4.  It is consulted for A questions.
	IPv4 map[string][]net.IP
	// IPv6 is a map of hostname to IPv6.  It is consulted for AAAA questions.
	IPv6 map[string][]net.IP
	// Reverse is a map of address to domain name.  It is consulted for PTR
	// questions and looked up by the question name as is, so presumably the
	// keys are reverse-lookup names like "1.0.0.127.in-addr.arpa." — TODO
	// confirm with callers.
	Reverse map[string][]string
	// Addr is the address for Address method.
	Addr string
}
// Exchange implements the upstream.Upstream interface for *Upstream.  It
// answers the first question of m from u's maps.
//
// CNAME records from u.CName are added for every question type.  They are
// followed by A, AAAA, or PTR records from u.IPv4, u.IPv6, or u.Reverse,
// depending on the question type.  If no records are found, the response code
// is set to NXDOMAIN (dns.RcodeNameError).  An error is returned if m has no
// questions.
//
// TODO(a.garipov): Split further into handlers.
func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
	// Validate the request before building a reply for it.
	if len(m.Question) == 0 {
		return nil, fmt.Errorf("question should not be empty")
	}

	resp = new(dns.Msg).SetReply(m)

	q := m.Question[0]
	name := q.Name
	for _, cname := range u.CName[name] {
		resp.Answer = append(resp.Answer, &dns.CNAME{
			Hdr:    dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME},
			Target: cname,
		})
	}

	qtype := q.Qtype
	hdr := dns.RR_Header{
		Name:   name,
		Rrtype: qtype,
	}

	switch qtype {
	case dns.TypeA:
		for _, ip := range u.IPv4[name] {
			resp.Answer = append(resp.Answer, &dns.A{Hdr: hdr, A: ip})
		}
	case dns.TypeAAAA:
		for _, ip := range u.IPv6[name] {
			resp.Answer = append(resp.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip})
		}
	case dns.TypePTR:
		// Use a distinct variable so that the queried name isn't shadowed.
		for _, ptr := range u.Reverse[name] {
			resp.Answer = append(resp.Answer, &dns.PTR{Hdr: hdr, Ptr: ptr})
		}
	}

	if len(resp.Answer) == 0 {
		resp.SetRcode(m, dns.RcodeNameError)
	}

	return resp, nil
}
// Address implements the upstream.Upstream interface for *Upstream.  It
// reports the configured Addr field verbatim.
func (u *Upstream) Address() (addr string) {
	addr = u.Addr

	return addr
}
// TestBlockUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.  It replies to every request with a single TXT record
// derived from the SHA-256 hash of Hostname and counts the requests it has
// handled.
type TestBlockUpstream struct {
	// Hostname is hashed to produce the TXT record returned by Exchange.
	Hostname string

	// lock protects reqNum.
	lock sync.RWMutex
	// reqNum is the number of times Exchange has been called.
	reqNum int

	// Block, if true, makes Exchange return the full hex-encoded hash of
	// Hostname.  Otherwise only the first two hex characters of the hash are
	// kept and followed by filler, so the result never equals the full hash.
	Block bool
}
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
// pair.  The reply contains a single TXT record named after the first question
// of r.
//
// If u.Block is true, the record's text is the hex-encoded SHA-256 hash of
// u.Hostname.  Otherwise it is the first two hex characters of that hash
// followed by filler.  Every call is counted (see RequestsCount).  An error is
// returned if r has no questions.
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
	u.lock.Lock()
	defer u.lock.Unlock()

	u.reqNum++

	// Guard against indexing an empty question section below.
	if len(r.Question) == 0 {
		return nil, fmt.Errorf("question should not be empty")
	}

	hash := sha256.Sum256([]byte(u.Hostname))
	hashToReturn := hex.EncodeToString(hash[:])
	if !u.Block {
		// Keep only the first byte of the hash so that the text shares the
		// prefix but never matches the full hash.
		hashToReturn = hashToReturn[:2] + strings.Repeat("ab", 28)
	}

	m := &dns.Msg{}
	m.SetReply(r)
	m.Answer = []dns.RR{
		&dns.TXT{
			Hdr: dns.RR_Header{
				Name: r.Question[0].Name,
				// Set the type explicitly so that the record isn't
				// serialized as type 0 if the message is packed.
				Rrtype: dns.TypeTXT,
			},
			Txt: []string{
				hashToReturn,
			},
		},
	}

	return m, nil
}
// Address implements the upstream.Upstream interface for *TestBlockUpstream.
// The mock has no real address, so the result is always the empty string.
func (u *TestBlockUpstream) Address() string {
	var noAddr string

	return noAddr
}
// RequestsCount returns the number of handled requests, that is, the number of
// times Exchange has been called.  It's safe for concurrent use.
func (u *TestBlockUpstream) RequestsCount() int {
	// The counter is only read here, so a shared read lock is enough and
	// concurrent callers of RequestsCount don't block each other.
	u.lock.RLock()
	defer u.lock.RUnlock()

	return u.reqNum
}
// TestErrUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.  Every call to Exchange fails with an error wrapping Err.
type TestErrUpstream struct {
	// Err is wrapped (with %w) by the error returned from Exchange, so the
	// returned error may be unwrapped to Err, e.g. via errors.Is.
	Err error
}
// Exchange implements the upstream.Upstream interface for *TestErrUpstream.
// It ignores the request and always fails.  The returned message is nil, and
// the error wraps u.Err with an "errupstream" prefix.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
	wrapped := fmt.Errorf("errupstream: %w", u.Err)

	return nil, wrapped
}
// Address implements the upstream.Upstream interface for *TestErrUpstream.
// The mock has no real address, so the result is always the empty string.
func (u *TestErrUpstream) Address() (addr string) {
	return addr
}