* dnsforward: use separate ServerConfig object

This commit is contained in:
Simon Zolin 2019-05-15 15:08:15 +03:00
parent 36e273dfd5
commit 9644f79a03
2 changed files with 33 additions and 33 deletions

View File

@ -44,7 +44,7 @@ type Server struct {
once sync.Once once sync.Once
sync.RWMutex sync.RWMutex
ServerConfig conf ServerConfig
} }
// NewServer creates a new instance of the dnsforward.Server // NewServer creates a new instance of the dnsforward.Server
@ -123,7 +123,7 @@ func (s *Server) Start(config *ServerConfig) error {
// startInternal starts without locking // startInternal starts without locking
func (s *Server) startInternal(config *ServerConfig) error { func (s *Server) startInternal(config *ServerConfig) error {
if config != nil { if config != nil {
s.ServerConfig = *config s.conf = *config
} }
if s.dnsFilter != nil || s.dnsProxy != nil { if s.dnsFilter != nil || s.dnsProxy != nil {
@ -158,21 +158,21 @@ func (s *Server) startInternal(config *ServerConfig) error {
}) })
proxyConfig := proxy.Config{ proxyConfig := proxy.Config{
UDPListenAddr: s.UDPListenAddr, UDPListenAddr: s.conf.UDPListenAddr,
TCPListenAddr: s.TCPListenAddr, TCPListenAddr: s.conf.TCPListenAddr,
Ratelimit: s.Ratelimit, Ratelimit: s.conf.Ratelimit,
RatelimitWhitelist: s.RatelimitWhitelist, RatelimitWhitelist: s.conf.RatelimitWhitelist,
RefuseAny: s.RefuseAny, RefuseAny: s.conf.RefuseAny,
CacheEnabled: true, CacheEnabled: true,
Upstreams: s.Upstreams, Upstreams: s.conf.Upstreams,
DomainsReservedUpstreams: s.DomainsReservedUpstreams, DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams,
Handler: s.handleDNSRequest, Handler: s.handleDNSRequest,
AllServers: s.AllServers, AllServers: s.conf.AllServers,
} }
if s.TLSListenAddr != nil && s.CertificateChain != "" && s.PrivateKey != "" { if s.conf.TLSListenAddr != nil && s.conf.CertificateChain != "" && s.conf.PrivateKey != "" {
proxyConfig.TLSListenAddr = s.TLSListenAddr proxyConfig.TLSListenAddr = s.conf.TLSListenAddr
keypair, err := tls.X509KeyPair([]byte(s.CertificateChain), []byte(s.PrivateKey)) keypair, err := tls.X509KeyPair([]byte(s.conf.CertificateChain), []byte(s.conf.PrivateKey))
if err != nil { if err != nil {
return errorx.Decorate(err, "Failed to parse TLS keypair") return errorx.Decorate(err, "Failed to parse TLS keypair")
} }
@ -202,10 +202,10 @@ func (s *Server) startInternal(config *ServerConfig) error {
// Initializes the DNS filter // Initializes the DNS filter
func (s *Server) initDNSFilter() error { func (s *Server) initDNSFilter() error {
log.Tracef("Creating dnsfilter") log.Tracef("Creating dnsfilter")
s.dnsFilter = dnsfilter.New(&s.Config) s.dnsFilter = dnsfilter.New(&s.conf.Config)
// add rules only if they are enabled // add rules only if they are enabled
if s.FilteringEnabled { if s.conf.FilteringEnabled {
err := s.dnsFilter.AddRules(s.Filters) err := s.dnsFilter.AddRules(s.conf.Filters)
if err != nil { if err != nil {
return errorx.Decorate(err, "could not initialize dnsfilter") return errorx.Decorate(err, "could not initialize dnsfilter")
} }
@ -336,11 +336,11 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
msg := d.Req msg := d.Req
// don't log ANY request if refuseAny is enabled // don't log ANY request if refuseAny is enabled
if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.RefuseAny { if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.conf.RefuseAny {
shouldLog = false shouldLog = false
} }
if s.QueryLogEnabled && shouldLog { if s.conf.QueryLogEnabled && shouldLog {
elapsed := time.Since(start) elapsed := time.Since(start)
upstreamAddr := "" upstreamAddr := ""
if d.Upstream != nil { if d.Upstream != nil {
@ -361,7 +361,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error
host := strings.TrimSuffix(msg.Question[0].Name, ".") host := strings.TrimSuffix(msg.Question[0].Name, ".")
s.RLock() s.RLock()
protectionEnabled := s.ProtectionEnabled protectionEnabled := s.conf.ProtectionEnabled
dnsFilter := s.dnsFilter dnsFilter := s.dnsFilter
s.RUnlock() s.RUnlock()
@ -402,7 +402,7 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu
return s.genARecord(m, result.IP) return s.genARecord(m, result.IP)
} }
if s.BlockingMode == "null_ip" { if s.conf.BlockingMode == "null_ip" {
return s.genARecord(m, net.IPv4zero) return s.genARecord(m, net.IPv4zero)
} }
@ -420,7 +420,7 @@ func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg { func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg {
resp := dns.Msg{} resp := dns.Msg{}
resp.SetReply(request) resp.SetReply(request)
answer, err := dns.NewRR(fmt.Sprintf("%s %d A %s", request.Question[0].Name, s.BlockedResponseTTL, ip.String())) answer, err := dns.NewRR(fmt.Sprintf("%s %d A %s", request.Question[0].Name, s.conf.BlockedResponseTTL, ip.String()))
if err != nil { if err != nil {
log.Printf("Couldn't generate A record for replacement host '%s': %s", ip.String(), err) log.Printf("Couldn't generate A record for replacement host '%s': %s", ip.String(), err)
return s.genServerFailure(request) return s.genServerFailure(request)
@ -489,7 +489,7 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR {
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: zone, Name: zone,
Rrtype: dns.TypeSOA, Rrtype: dns.TypeSOA,
Ttl: s.BlockedResponseTTL, Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET, Class: dns.ClassINET,
}, },
Mbox: "hostmaster.", // zone will be appended later if it's not empty or "." Mbox: "hostmaster.", // zone will be appended later if it's not empty or "."

View File

@ -86,7 +86,7 @@ func TestDotServer(t *testing.T) {
s := createTestServer(t) s := createTestServer(t)
defer removeDataDir(t) defer removeDataDir(t)
s.TLSConfig = TLSConfig{ s.conf.TLSConfig = TLSConfig{
TLSListenAddr: &net.TCPAddr{Port: 0}, TLSListenAddr: &net.TCPAddr{Port: 0},
CertificateChain: string(certPem), CertificateChain: string(certPem),
PrivateKey: string(keyPem), PrivateKey: string(keyPem),
@ -149,7 +149,7 @@ func TestServerRace(t *testing.T) {
func TestSafeSearch(t *testing.T) { func TestSafeSearch(t *testing.T) {
s := createTestServer(t) s := createTestServer(t)
s.SafeSearchEnabled = true s.conf.SafeSearchEnabled = true
defer removeDataDir(t) defer removeDataDir(t)
err := s.Start(nil) err := s.Start(nil)
if err != nil { if err != nil {
@ -295,7 +295,7 @@ func TestBlockedRequest(t *testing.T) {
func TestNullBlockedRequest(t *testing.T) { func TestNullBlockedRequest(t *testing.T) {
s := createTestServer(t) s := createTestServer(t)
s.FilteringConfig.BlockingMode = "null_ip" s.conf.FilteringConfig.BlockingMode = "null_ip"
defer removeDataDir(t) defer removeDataDir(t)
err := s.Start(nil) err := s.Start(nil)
if err != nil { if err != nil {
@ -451,14 +451,14 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
func createTestServer(t *testing.T) *Server { func createTestServer(t *testing.T) *Server {
s := NewServer(createDataDir(t)) s := NewServer(createDataDir(t))
s.UDPListenAddr = &net.UDPAddr{Port: 0} s.conf.UDPListenAddr = &net.UDPAddr{Port: 0}
s.TCPListenAddr = &net.TCPAddr{Port: 0} s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
s.QueryLogEnabled = true s.conf.QueryLogEnabled = true
s.FilteringConfig.FilteringEnabled = true s.conf.FilteringConfig.FilteringEnabled = true
s.FilteringConfig.ProtectionEnabled = true s.conf.FilteringConfig.ProtectionEnabled = true
s.FilteringConfig.SafeBrowsingEnabled = true s.conf.FilteringConfig.SafeBrowsingEnabled = true
s.Filters = make([]dnsfilter.Filter, 0) s.conf.Filters = make([]dnsfilter.Filter, 0)
rules := []string{ rules := []string{
"||nxdomain.example.org^", "||nxdomain.example.org^",
@ -466,7 +466,7 @@ func createTestServer(t *testing.T) *Server {
"127.0.0.1 host.example.org", "127.0.0.1 host.example.org",
} }
filter := dnsfilter.Filter{ID: 1, Rules: rules} filter := dnsfilter.Filter{ID: 1, Rules: rules}
s.Filters = append(s.Filters, filter) s.conf.Filters = append(s.conf.Filters, filter)
return s return s
} }