diff --git a/internal/home/tls.go b/internal/home/tls.go index 57df2da6..709ada19 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -19,6 +19,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/google/go-cmp/cmp" @@ -30,9 +31,9 @@ var tlsWebHandlersRegistered = false // TLSMod - TLS module object type TLSMod struct { certLastMod time.Time // last modification time of the certificate file - conf tlsConfigSettings - confLock sync.Mutex status tlsConfigStatus + confLock sync.Mutex + conf tlsConfigSettings } // Create TLS module @@ -209,8 +210,8 @@ type tlsConfigStatus struct { // field ordering is important -- yaml fields will mirror ordering from here type tlsConfig struct { - tlsConfigSettings `json:",inline"` tlsConfigStatus `json:",inline"` + tlsConfigSettings `json:",inline"` } func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, _ *http.Request) { @@ -247,6 +248,41 @@ func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) { marshalTLS(w, data) } +func (t *TLSMod) setConfig(newConf tlsConfigSettings, status tlsConfigStatus) (restartHTTPS bool) { + t.confLock.Lock() + defer t.confLock.Unlock() + + // Reset the DNSCrypt data before comparing, since we currently do not + // accept these from the frontend. + // + // TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig. + newConf.DNSCryptConfigFile = t.conf.DNSCryptConfigFile + newConf.PortDNSCrypt = t.conf.PortDNSCrypt + if !cmp.Equal(t.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) { + log.Info("tls config has changed, restarting https server") + restartHTTPS = true + } else { + log.Info("tls config has not changed") + } + + // Note: don't do just `t.conf = data` because we must preserve all other members of t.conf + t.conf.Enabled = newConf.Enabled + t.conf.ServerName = newConf.ServerName + t.conf.ForceHTTPS = newConf.ForceHTTPS + t.conf.PortHTTPS = newConf.PortHTTPS + t.conf.PortDNSOverTLS = newConf.PortDNSOverTLS + t.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC + t.conf.CertificateChain = newConf.CertificateChain + t.conf.CertificatePath = newConf.CertificatePath + t.conf.CertificateChainData = newConf.CertificateChainData + t.conf.PrivateKey = newConf.PrivateKey + t.conf.PrivateKeyPath = newConf.PrivateKeyPath + t.conf.PrivateKeyData = newConf.PrivateKeyData + t.status = status + + return restartHTTPS +} + func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { data, err := unmarshalTLS(r) if err != nil { @@ -266,42 +302,28 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { tlsConfigStatus: t.status, } marshalTLS(w, data2) + return } - status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName) - restartHTTPS := false - t.confLock.Lock() - if !cmp.Equal(t.conf, data) { - log.Printf("tls config settings have changed, will restart HTTPS server") - restartHTTPS = true - } - // Note: don't do just `t.conf = data` because we must preserve all other members of t.conf - t.conf.Enabled = data.Enabled - t.conf.ServerName = data.ServerName - t.conf.ForceHTTPS = data.ForceHTTPS - t.conf.PortHTTPS = data.PortHTTPS - t.conf.PortDNSOverTLS = data.PortDNSOverTLS - t.conf.PortDNSOverQUIC = data.PortDNSOverQUIC - t.conf.CertificateChain = data.CertificateChain - t.conf.CertificatePath = data.CertificatePath - t.conf.CertificateChainData = data.CertificateChainData - t.conf.PrivateKey = data.PrivateKey - t.conf.PrivateKeyPath = data.PrivateKeyPath - t.conf.PrivateKeyData = data.PrivateKeyData - t.status = status - t.confLock.Unlock() + status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName) + + restartHTTPS := t.setConfig(data, status) t.setCertFileTime() onConfigModified() + err = reconfigureDNSServer() if err != nil { httpError(w, http.StatusInternalServerError, "%s", err) + return } + data2 := tlsConfig{ tlsConfigSettings: data, tlsConfigStatus: t.status, } + marshalTLS(w, data2) if f, ok := w.(http.Flusher); ok { f.Flush()