Merge: + DNS: Allow DOH queries via unencrypted HTTP

Close #1276

* commit '91c3149ee2dc902a5081345431f586ae72362963':
  + allow_unencrypted_doh: add test
  + DNS: Allow DOH queries via unencrypted HTTP (e.g. for reverse proxying)
This commit is contained in:
Simon Zolin 2019-12-20 15:14:41 +03:00
commit ceab5d4c41
3 changed files with 31 additions and 1 deletions

View File

@ -117,6 +117,9 @@ type tlsConfigSettings struct {
PortHTTPS int `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled PortHTTPS int `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
PortDNSOverTLS int `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DOT will be disabled PortDNSOverTLS int `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DOT will be disabled
// Allow DOH queries via unencrypted HTTP (e.g. for reverse proxying)
AllowUnencryptedDOH bool `yaml:"allow_unencrypted_doh" json:"allow_unencrypted_doh"`
dnsforward.TLSConfig `yaml:",inline" json:",inline"` dnsforward.TLSConfig `yaml:",inline" json:",inline"`
} }

View File

@ -144,7 +144,7 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
// DNS-over-HTTPS // DNS-over-HTTPS
// -------------- // --------------
func handleDOH(w http.ResponseWriter, r *http.Request) { func handleDOH(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil { if !config.TLS.AllowUnencryptedDOH && r.TLS == nil {
httpError(w, http.StatusNotFound, "Not Found") httpError(w, http.StatusNotFound, "Not Found")
return return
} }

View File

@ -2,6 +2,7 @@ package home
import ( import (
"context" "context"
"encoding/base64"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
@ -9,7 +10,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -61,6 +64,7 @@ tls:
force_https: false force_https: false
port_https: 443 port_https: 443
port_dns_over_tls: 853 port_dns_over_tls: 853
allow_unencrypted_doh: true
certificate_chain: "" certificate_chain: ""
private_key: "" private_key: ""
certificate_path: "" certificate_path: ""
@ -99,6 +103,7 @@ schema_version: 5
// . Start AGH instance // . Start AGH instance
// . Check Web server // . Check Web server
// . Check DNS server // . Check DNS server
// . Check DNS server with DOH
// . Wait until the filters are downloaded // . Wait until the filters are downloaded
// . Stop and cleanup // . Stop and cleanup
func TestHome(t *testing.T) { func TestHome(t *testing.T) {
@ -131,12 +136,34 @@ func TestHome(t *testing.T) {
assert.Truef(t, err == nil, "%s", err) assert.Truef(t, err == nil, "%s", err)
assert.Equal(t, 200, resp.StatusCode) assert.Equal(t, 200, resp.StatusCode)
// test DNS over UDP
r := upstream.NewResolver("127.0.0.1:5354", 3*time.Second) r := upstream.NewResolver("127.0.0.1:5354", 3*time.Second)
addrs, err := r.LookupIPAddr(context.TODO(), "static.adguard.com") addrs, err := r.LookupIPAddr(context.TODO(), "static.adguard.com")
assert.Truef(t, err == nil, "%s", err) assert.Truef(t, err == nil, "%s", err)
haveIP := len(addrs) != 0 haveIP := len(addrs) != 0
assert.True(t, haveIP) assert.True(t, haveIP)
// test DNS over HTTP without encryption
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{{Name: "static.adguard.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}
buf, err := req.Pack()
assert.True(t, err == nil, "%s", err)
requestURL := "http://127.0.0.1:3000/dns-query?dns=" + base64.RawURLEncoding.EncodeToString(buf)
resp, err = http.DefaultClient.Get(requestURL)
assert.True(t, err == nil, "%s", err)
body, err := ioutil.ReadAll(resp.Body)
assert.True(t, err == nil, "%s", err)
assert.True(t, resp.StatusCode == http.StatusOK)
response := dns.Msg{}
err = response.Unpack(body)
assert.True(t, err == nil, "%s", err)
addrs = nil
proxyutil.AppendIPAddrs(&addrs, response.Answer)
haveIP = len(addrs) != 0
assert.True(t, haveIP)
for i := 1; ; i++ { for i := 1; ; i++ {
st, err := os.Stat(filepath.Join(dir, "data", "filters", "1.txt")) st, err := os.Stat(filepath.Join(dir, "data", "filters", "1.txt"))
if err == nil && st.Size() != 0 { if err == nil && st.Size() != 0 {