dnsforward -- implement ratelimit and refuseany

This commit is contained in:
Eugene Bujak 2018-12-05 18:47:03 +03:00
parent 15f0dee719
commit 478ce03386
4 changed files with 187 additions and 28 deletions

View File

@ -46,14 +46,11 @@ type coreDNSConfig struct {
dnsforward.FilteringConfig `yaml:",inline"` dnsforward.FilteringConfig `yaml:",inline"`
QueryLogEnabled bool `yaml:"querylog_enabled"` Pprof string `yaml:"-"`
Ratelimit int `yaml:"ratelimit"` Cache string `yaml:"-"`
RefuseAny bool `yaml:"refuse_any"` Prometheus string `yaml:"-"`
Pprof string `yaml:"-"` BootstrapDNS string `yaml:"bootstrap_dns"`
Cache string `yaml:"-"` UpstreamDNS []string `yaml:"upstream_dns"`
Prometheus string `yaml:"-"`
BootstrapDNS string `yaml:"bootstrap_dns"`
UpstreamDNS []string `yaml:"upstream_dns"`
} }
var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"}
@ -71,14 +68,14 @@ var config = configuration{
ProtectionEnabled: true, // whether or not use any of dnsfilter features ProtectionEnabled: true, // whether or not use any of dnsfilter features
FilteringEnabled: true, // whether or not use filter lists FilteringEnabled: true, // whether or not use filter lists
BlockedResponseTTL: 10, // in seconds BlockedResponseTTL: 10, // in seconds
QueryLogEnabled: true,
Ratelimit: 20,
RefuseAny: true,
}, },
QueryLogEnabled: true, BootstrapDNS: "8.8.8.8:53",
Ratelimit: 20, UpstreamDNS: defaultDNS,
RefuseAny: true, Cache: "cache",
BootstrapDNS: "8.8.8.8:53", Prometheus: "prometheus :9153",
UpstreamDNS: defaultDNS,
Cache: "cache",
Prometheus: "prometheus :9153",
}, },
Filters: []filter{ Filters: []filter{
{Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, {Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"},

View File

@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/joomcode/errorx" "github.com/joomcode/errorx"
"github.com/miekg/dns" "github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
) )
// Server is the main way to start a DNS server. // Server is the main way to start a DNS server.
@ -31,6 +32,8 @@ type Server struct {
cache cache cache cache
ratelimitBuckets *gocache.Cache // where the ratelimiters are stored, per IP
sync.RWMutex sync.RWMutex
ServerConfig ServerConfig
} }
@ -76,9 +79,13 @@ func (s *Server) RUnlock() {
*/ */
type FilteringConfig struct { type FilteringConfig struct {
ProtectionEnabled bool `yaml:"protection_enabled"` ProtectionEnabled bool `yaml:"protection_enabled"`
FilteringEnabled bool `yaml:"filtering_enabled"` FilteringEnabled bool `yaml:"filtering_enabled"`
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
QueryLogEnabled bool `yaml:"querylog_enabled"`
Ratelimit int `yaml:"ratelimit"`
RatelimitWhitelist []string `yaml:"ratelimit_whitelist"`
RefuseAny bool `yaml:"refuse_any"`
dnsfilter.Config `yaml:",inline"` dnsfilter.Config `yaml:",inline"`
} }
@ -92,6 +99,7 @@ type ServerConfig struct {
FilteringConfig FilteringConfig
} }
// if any of ServerConfig values are zero, then default values from below are used
var defaultValues = ServerConfig{ var defaultValues = ServerConfig{
UDPListenAddr: &net.UDPAddr{Port: 53}, UDPListenAddr: &net.UDPAddr{Port: 53},
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
@ -413,6 +421,10 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP
return s.genServerFailure(msg), nil, nil, nil return s.genServerFailure(msg), nil, nil, nil
} }
if msg.Question[0].Qtype == dns.TypeANY && s.RefuseAny {
return s.genNotImpl(msg), nil, nil, nil
}
// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise // use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise
host := strings.TrimSuffix(msg.Question[0].Name, ".") host := strings.TrimSuffix(msg.Question[0].Name, ".")
res, err := s.dnsFilter.CheckHost(host) res, err := s.dnsFilter.CheckHost(host)
@ -450,16 +462,36 @@ func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDP
func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
start := time.Now() start := time.Now()
ip, _, err := net.SplitHostPort(addr.String())
if err != nil {
log.Printf("Failed to split %v into host/port: %s", addr, err)
// not a fatal error, move on
}
// ratelimit based on IP only, protects CPU cycles and outbound connections
if s.isRatelimited(ip) {
// log.Printf("Ratelimiting %s based on IP only", ip)
return // do nothing, don't reply, we got ratelimited
}
msg := &dns.Msg{} msg := &dns.Msg{}
err := msg.Unpack(p) err = msg.Unpack(p)
if err != nil { if err != nil {
log.Printf("got invalid DNS packet: %s", err) log.Printf("got invalid DNS packet: %s", err)
return // do nothing return // do nothing
} }
reply, result, upstream, err := s.handlePacketInternal(msg, addr, conn) reply, result, upstream, err := s.handlePacketInternal(msg, addr, conn)
if reply != nil { if reply != nil {
// ratelimit based on reply size now
replysize := reply.Len()
if s.isRatelimitedForReply(ip, replysize) {
log.Printf("Ratelimiting %s based on IP and size %d", ip, replysize)
return // do nothing, don't reply, we got ratelimited
}
// we're good to respond
rerr := s.respond(reply, addr, conn) rerr := s.respond(reply, addr, conn)
if rerr != nil { if rerr != nil {
log.Printf("Couldn't respond to UDP packet: %s", err) log.Printf("Couldn't respond to UDP packet: %s", err)
@ -467,16 +499,14 @@ func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) {
} }
// query logging and stats counters // query logging and stats counters
elapsed := time.Since(start) if s.QueryLogEnabled {
upstreamAddr := "" elapsed := time.Since(start)
if upstream != nil { upstreamAddr := ""
upstreamAddr = upstream.Address() if upstream != nil {
upstreamAddr = upstream.Address()
}
logRequest(msg, reply, result, elapsed, ip, upstreamAddr)
} }
host, _, err := net.SplitHostPort(addr.String())
if err != nil {
log.Printf("Failed to split %v into host/port: %s", addr, err)
}
logRequest(msg, reply, result, elapsed, host, upstreamAddr)
} }
// //
@ -506,12 +536,22 @@ func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
resp := dns.Msg{} resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeServerFailure) resp.SetRcode(request, dns.RcodeServerFailure)
resp.RecursionAvailable = true
return &resp
}
func (s *Server) genNotImpl(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeNotImplemented)
resp.RecursionAvailable = true
resp.SetEdns0(1452, false) // NOTIMPL without EDNS is treated as 'we don't support EDNS', so explicitly set it
return &resp return &resp
} }
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg { func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
resp := dns.Msg{} resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeNameError) resp.SetRcode(request, dns.RcodeNameError)
resp.RecursionAvailable = true
resp.Ns = s.genSOA(request) resp.Ns = s.genSOA(request)
return &resp return &resp
} }

80
dnsforward/ratelimit.go Normal file
View File

@ -0,0 +1,80 @@
package dnsforward
import (
"log"
"sort"
"time"
"github.com/beefsack/go-rate"
gocache "github.com/patrickmn/go-cache"
)
func (s *Server) limiterForIP(ip string) interface{} {
if s.ratelimitBuckets == nil {
s.ratelimitBuckets = gocache.New(time.Hour, time.Hour)
}
// check if ratelimiter for that IP already exists, if not, create
value, found := s.ratelimitBuckets.Get(ip)
if !found {
value = rate.New(s.Ratelimit, time.Second)
s.ratelimitBuckets.Set(ip, value, time.Hour)
}
return value
}
func (s *Server) isRatelimited(ip string) bool {
if s.Ratelimit == 0 { // 0 -- disabled
return false
}
if len(s.RatelimitWhitelist) > 0 {
i := sort.SearchStrings(s.RatelimitWhitelist, ip)
if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip {
// found, don't ratelimit
return false
}
}
value := s.limiterForIP(ip)
rl, ok := value.(*rate.RateLimiter)
if !ok {
log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache")
return false
}
allow, _ := rl.Try()
return !allow
}
func (s *Server) isRatelimitedForReply(ip string, size int) bool {
if s.Ratelimit == 0 { // 0 -- disabled
return false
}
if len(s.RatelimitWhitelist) > 0 {
i := sort.SearchStrings(s.RatelimitWhitelist, ip)
if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip {
// found, don't ratelimit
return false
}
}
value := s.limiterForIP(ip)
rl, ok := value.(*rate.RateLimiter)
if !ok {
log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache")
return false
}
// For large UDP responses we try more times, effectively limiting per bandwidth
// The exact number of times depends on the response size
for i := 0; i < size/1000; i++ {
allow, _ := rl.Try()
if !allow { // not allowed -> ratelimited
return true
}
}
return false
}

View File

@ -0,0 +1,42 @@
package dnsforward
import (
"testing"
)
func TestRatelimiting(t *testing.T) {
// rate limit is 1 per sec
p := Server{}
p.Ratelimit = 1
limited := p.isRatelimited("127.0.0.1")
if limited {
t.Fatal("First request must have been allowed")
}
limited = p.isRatelimited("127.0.0.1")
if !limited {
t.Fatal("Second request must have been ratelimited")
}
}
func TestWhitelist(t *testing.T) {
// rate limit is 1 per sec with whitelist
p := Server{}
p.Ratelimit = 1
p.RatelimitWhitelist = []string{"127.0.0.1", "127.0.0.2", "127.0.0.125"}
limited := p.isRatelimited("127.0.0.1")
if limited {
t.Fatal("First request must have been allowed")
}
limited = p.isRatelimited("127.0.0.1")
if limited {
t.Fatal("Second request must have been allowed due to whitelist")
}
}