From 98f6843aab8ff69606f44f6a5721d5f2a5707a0f Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Thu, 29 Oct 2020 19:39:11 +0300 Subject: [PATCH] * home: improve naming in mobileconfig handlers Updates #2235. --- home/home.go | 1 + home/mobileconfig.go | 45 ++++++++++++++++++++++++++++----------- home/mobileconfig_test.go | 33 ++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 12 deletions(-) create mode 100644 home/mobileconfig_test.go diff --git a/home/home.go b/home/home.go index 10235664..09df0d57 100644 --- a/home/home.go +++ b/home/home.go @@ -1,3 +1,4 @@ +// Package home contains AdGuard Home's HTTP API methods. package home import ( diff --git a/home/mobileconfig.go b/home/mobileconfig.go index e828f117..81fe4b1a 100644 --- a/home/mobileconfig.go +++ b/home/mobileconfig.go @@ -2,6 +2,7 @@ package home import ( "fmt" + "net" "net/http" uuid "github.com/satori/go.uuid" @@ -14,7 +15,7 @@ type DNSSettings struct { ServerName string `plist:",omitempty"` } -type PayloadContent = struct { +type PayloadContent struct { Name string PayloadDescription string PayloadDisplayName string @@ -25,7 +26,7 @@ type PayloadContent = struct { DNSSettings DNSSettings } -type MobileConfig = struct { +type MobileConfig struct { PayloadContent []PayloadContent PayloadDescription string PayloadDisplayName string @@ -40,8 +41,21 @@ func genUUIDv4() string { return uuid.NewV4().String() } -func getMobileConfig(r *http.Request, d DNSSettings) ([]byte, error) { - name := fmt.Sprintf("%s DNS over %s", r.Host, d.DNSProtocol) +const ( + dnsProtoHTTPS = "HTTPS" + dnsProtoTLS = "TLS" +) + +func getMobileConfig(d DNSSettings) ([]byte, error) { + var name string + switch d.DNSProtocol { + case dnsProtoHTTPS: + name = fmt.Sprintf("%s DoH", d.ServerName) + case dnsProtoTLS: + name = fmt.Sprintf("%s DoT", d.ServerName) + default: + return nil, fmt.Errorf("bad dns protocol %q", d.DNSProtocol) + } data := MobileConfig{ PayloadContent: []PayloadContent{{ @@ -66,9 +80,8 @@ func getMobileConfig(r *http.Request, d DNSSettings) ([]byte, error) { return plist.MarshalIndent(data, plist.XMLFormat, "\t") } -func handleMobileConfig(w http.ResponseWriter, r *http.Request, d DNSSettings) { - mobileconfig, err := getMobileConfig(r, d) - +func handleMobileConfig(w http.ResponseWriter, d DNSSettings) { + mobileconfig, err := getMobileConfig(d) if err != nil { httpError(w, http.StatusInternalServerError, "plist.MarshalIndent: %s", err) } @@ -78,15 +91,23 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, d DNSSettings) { } func handleMobileConfigDoh(w http.ResponseWriter, r *http.Request) { - handleMobileConfig(w, r, DNSSettings{ - DNSProtocol: "HTTPS", + handleMobileConfig(w, DNSSettings{ + DNSProtocol: dnsProtoHTTPS, ServerURL: fmt.Sprintf("https://%s/dns-query", r.Host), }) } func handleMobileConfigDot(w http.ResponseWriter, r *http.Request) { - handleMobileConfig(w, r, DNSSettings{ - DNSProtocol: "TLS", - ServerName: r.Host, + var err error + + var host string + host, _, err = net.SplitHostPort(r.Host) + if err != nil { + httpError(w, http.StatusBadRequest, "getting host: %s", err) + } + + handleMobileConfig(w, DNSSettings{ + DNSProtocol: dnsProtoTLS, + ServerName: host, }) } diff --git a/home/mobileconfig_test.go b/home/mobileconfig_test.go new file mode 100644 index 00000000..f58f4e99 --- /dev/null +++ b/home/mobileconfig_test.go @@ -0,0 +1,33 @@ +package home + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "howett.net/plist" +) + +func TestHandleMobileConfigDot(t *testing.T) { + var err error + + var r *http.Request + r, err = http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil) + assert.Nil(t, err) + + w := httptest.NewRecorder() + + handleMobileConfigDot(w, r) + assert.Equal(t, http.StatusOK, w.Code) + + var mc MobileConfig + _, err = plist.Unmarshal(w.Body.Bytes(), &mc) + assert.Nil(t, err) + + if assert.Equal(t, 1, len(mc.PayloadContent)) { + assert.Equal(t, "example.com DoT", mc.PayloadContent[0].Name) + assert.Equal(t, "example.com DoT", mc.PayloadContent[0].PayloadDisplayName) + assert.Equal(t, "example.com", mc.PayloadContent[0].DNSSettings.ServerName) + } +}