diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 356b835f..beacbabf 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -233,6 +233,13 @@ func (s *Server) startInternal() error { func (s *Server) Prepare(config *ServerConfig) error { if config != nil { s.conf = *config + if s.conf.BlockingMode == "custom_ip" { + s.conf.BlockingIPAddrv4 = net.ParseIP(s.conf.BlockingIPv4) + s.conf.BlockingIPAddrv6 = net.ParseIP(s.conf.BlockingIPv6) + if s.conf.BlockingIPAddrv4 == nil || s.conf.BlockingIPAddrv6 == nil { + return fmt.Errorf("DNS: invalid custom blocking IP address specified") + } + } } if len(s.conf.UpstreamDNS) == 0 { diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 155d6f95..88f7fb78 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -424,6 +424,55 @@ func TestNullBlockedRequest(t *testing.T) { } } +func TestBlockedCustomIP(t *testing.T) { + rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n" + filters := map[int]string{} + filters[0] = rules + c := dnsfilter.Config{} + + f := dnsfilter.New(&c, filters) + s := NewServer(f, nil, nil) + conf := ServerConfig{} + conf.UDPListenAddr = &net.UDPAddr{Port: 0} + conf.TCPListenAddr = &net.TCPAddr{Port: 0} + conf.ProtectionEnabled = true + conf.BlockingMode = "custom_ip" + conf.BlockingIPv4 = "bad IP" + conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"} + err := s.Prepare(&conf) + assert.True(t, err != nil) // invalid BlockingIPv4 + + conf.BlockingIPv4 = "0.0.0.1" + conf.BlockingIPv6 = "::1" + err = s.Prepare(&conf) + assert.True(t, err == nil) + err = s.Start() + assert.True(t, err == nil, "%s", err) + + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + + req := createTestMessageWithType("null.example.org.", dns.TypeA) + reply, err := dns.Exchange(req, addr.String()) + assert.True(t, err == nil) + assert.True(t, len(reply.Answer) == 1) + a, ok := reply.Answer[0].(*dns.A) + assert.True(t, ok) + assert.True(t, a.A.String() == "0.0.0.1") + + req = createTestMessageWithType("null.example.org.", dns.TypeAAAA) + reply, err = dns.Exchange(req, addr.String()) + assert.True(t, err == nil) + assert.True(t, len(reply.Answer) == 1) + a6, ok := reply.Answer[0].(*dns.AAAA) + assert.True(t, ok) + assert.True(t, a6.AAAA.String() == "::1") + + err = s.Stop() + if err != nil { + t.Fatalf("DNS server failed to stop: %s", err) + } +} + func TestBlockedByHosts(t *testing.T) { s := createTestServer(t) err := s.Start() @@ -652,6 +701,16 @@ func createTestMessage(host string) *dns.Msg { return &req } +func createTestMessageWithType(host string, qtype uint16) *dns.Msg { + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: host, Qtype: qtype, Qclass: dns.ClassINET}, + } + return &req +} + func assertGoogleAResponse(t *testing.T, reply *dns.Msg) { assertResponse(t, reply, "8.8.8.8") } diff --git a/home/dns.go b/home/dns.go index d6fc8b8a..1a0666fb 100644 --- a/home/dns.go +++ b/home/dns.go @@ -70,6 +70,9 @@ func initDNSServer() error { sessFilename := filepath.Join(baseDir, "sessions.db") config.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60) + if config.auth == nil { + return fmt.Errorf("Couldn't initialize Auth module") + } config.Users = nil Context.rdns = InitRDNS(Context.dnsServer, &Context.clients) @@ -254,6 +257,10 @@ func reconfigureDNSServer() error { } func stopDNSServer() error { + if !isRunning() { + return nil + } + err := Context.dnsServer.Stop() if err != nil { return errorx.Decorate(err, "Couldn't stop forwarding DNS server")