Pull request: 4925-refactor-tls-vol-2

Updates #4925.

Squashed commit of the following:

commit 4b221936ea6c2a244c404e95fa2a033571e07168
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Oct 14 19:03:42 2022 +0300

    all: refactor tls
This commit is contained in:
Ainar Garipov 2022-10-14 19:37:14 +03:00
parent a1acfbbae4
commit fee81b31ec
5 changed files with 399 additions and 319 deletions

View File

@ -424,7 +424,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
// moment we'll allow setting up TLS in the initial configuration or the // moment we'll allow setting up TLS in the initial configuration or the
// configuration itself will use HTTPS protocol, because the underlying // configuration itself will use HTTPS protocol, because the underlying
// functions potentially restart the HTTPS server. // functions potentially restart the HTTPS server.
err = StartMods() err = startMods()
if err != nil { if err != nil {
Context.firstRun = true Context.firstRun = true
copyInstallSettings(config, curConfig) copyInstallSettings(config, curConfig)

View File

@ -59,7 +59,7 @@ type homeContext struct {
auth *Auth // HTTP authentication module auth *Auth // HTTP authentication module
filters *filtering.DNSFilter // DNS filtering module filters *filtering.DNSFilter // DNS filtering module
web *Web // Web (HTTP, HTTPS) module web *Web // Web (HTTP, HTTPS) module
tls *TLSMod // TLS module tls *tlsManager // TLS module
// etcHosts is an IP-hostname pairs set taken from system configuration // etcHosts is an IP-hostname pairs set taken from system configuration
// (e.g. /etc/hosts) files. // (e.g. /etc/hosts) files.
etcHosts *aghnet.HostsContainer etcHosts *aghnet.HostsContainer
@ -117,7 +117,7 @@ func Main(clientBuildFS fs.FS) {
switch sig { switch sig {
case syscall.SIGHUP: case syscall.SIGHUP:
Context.clients.Reload() Context.clients.Reload()
Context.tls.Reload() Context.tls.reload()
default: default:
cleanup(context.Background()) cleanup(context.Background())
@ -495,9 +495,9 @@ func run(opts options, clientBuildFS fs.FS) {
} }
config.Users = nil config.Users = nil
Context.tls = tlsCreate(config.TLS) Context.tls, err = newTLSManager(config.TLS)
if Context.tls == nil { if err != nil {
log.Fatalf("Can't initialize TLS module") log.Fatalf("initializing tls: %s", err)
} }
Context.web, err = initWeb(opts, clientBuildFS) Context.web, err = initWeb(opts, clientBuildFS)
@ -507,7 +507,7 @@ func run(opts options, clientBuildFS fs.FS) {
err = initDNSServer() err = initDNSServer()
fatalOnError(err) fatalOnError(err)
Context.tls.Start() Context.tls.start()
go func() { go func() {
serr := startDNSServer() serr := startDNSServer()
@ -531,20 +531,22 @@ func run(opts options, clientBuildFS fs.FS) {
select {} select {}
} }
// StartMods initializes and starts the DNS server after installation. // startMods initializes and starts the DNS server after installation.
func StartMods() error { func startMods() error {
err := initDNSServer() err := initDNSServer()
if err != nil { if err != nil {
return err return err
} }
Context.tls.Start() Context.tls.start()
err = startDNSServer() err = startDNSServer()
if err != nil { if err != nil {
closeDNSServer() closeDNSServer()
return err return err
} }
return nil return nil
} }
@ -728,7 +730,6 @@ func cleanup(ctx context.Context) {
} }
if Context.tls != nil { if Context.tls != nil {
Context.tls.Close()
Context.tls = nil Context.tls = nil
} }
} }
@ -738,7 +739,8 @@ func cleanupAlways() {
if len(Context.pidFileName) != 0 { if len(Context.pidFileName) != 0 {
_ = os.Remove(Context.pidFileName) _ = os.Remove(Context.pidFileName)
} }
log.Info("Stopped")
log.Info("stopped")
} }
func exitWithError() { func exitWithError() {

View File

@ -32,7 +32,7 @@ func setupDNSIPs(t testing.TB) {
}, },
} }
Context.tls = &TLSMod{} Context.tls = &tlsManager{}
} }
func TestHandleMobileConfigDoH(t *testing.T) { func TestHandleMobileConfigDoH(t *testing.T) {
@ -65,7 +65,7 @@ func TestHandleMobileConfigDoH(t *testing.T) {
oldTLSConf := Context.tls oldTLSConf := Context.tls
t.Cleanup(func() { Context.tls = oldTLSConf }) t.Cleanup(func() { Context.tls = oldTLSConf })
Context.tls = &TLSMod{conf: tlsConfigSettings{}} Context.tls = &tlsManager{conf: tlsConfigSettings{}}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
require.NoError(t, err) require.NoError(t, err)
@ -137,7 +137,7 @@ func TestHandleMobileConfigDoT(t *testing.T) {
oldTLSConf := Context.tls oldTLSConf := Context.tls
t.Cleanup(func() { Context.tls = oldTLSConf }) t.Cleanup(func() { Context.tls = oldTLSConf })
Context.tls = &TLSMod{conf: tlsConfigSettings{}} Context.tls = &tlsManager{conf: tlsConfigSettings{}}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil) r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
require.NoError(t, err) require.NoError(t, err)

View File

@ -26,216 +26,256 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
) )
var tlsWebHandlersRegistered = false // tlsManager contains the current configuration and state of AdGuard Home TLS
// encryption.
type tlsManager struct {
// status is the current status of the configuration. It is never nil.
status *tlsConfigStatus
// TLSMod - TLS module object // certLastMod is the last modification time of the certificate file.
type TLSMod struct { certLastMod time.Time
certLastMod time.Time // last modification time of the certificate file
status tlsConfigStatus confLock sync.Mutex
confLock sync.Mutex conf tlsConfigSettings
conf tlsConfigSettings
} }
// Create TLS module // newTLSManager initializes the TLS configuration.
func tlsCreate(conf tlsConfigSettings) *TLSMod { func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
t := &TLSMod{} m = &tlsManager{
t.conf = conf status: &tlsConfigStatus{},
if t.conf.Enabled { conf: conf,
if !t.load() { }
// Something is not valid - return an empty TLS config
return &TLSMod{conf: tlsConfigSettings{ if m.conf.Enabled {
Enabled: conf.Enabled, err = m.load()
ServerName: conf.ServerName, if err != nil {
PortHTTPS: conf.PortHTTPS, return nil, err
PortDNSOverTLS: conf.PortDNSOverTLS,
PortDNSOverQUIC: conf.PortDNSOverQUIC,
AllowUnencryptedDoH: conf.AllowUnencryptedDoH,
}}
} }
t.setCertFileTime()
m.setCertFileTime()
} }
return t
return m, nil
} }
func (t *TLSMod) load() bool { // load reloads the TLS configuration from files or data from the config file.
if !tlsLoadConfig(&t.conf, &t.status) { func (m *tlsManager) load() (err error) {
log.Error("failed to load TLS config: %s", t.status.WarningValidation) err = loadTLSConf(&m.conf, m.status)
return false if err != nil {
return fmt.Errorf("loading config: %w", err)
} }
// validate current TLS config and update warnings (it could have been loaded from file) return nil
data := validateCertificates(string(t.conf.CertificateChainData), string(t.conf.PrivateKeyData), t.conf.ServerName)
if !data.ValidPair {
log.Error("failed to validate certificate: %s", data.WarningValidation)
return false
}
t.status = data
return true
}
// Close - close module
func (t *TLSMod) Close() {
} }
// WriteDiskConfig - write config // WriteDiskConfig - write config
func (t *TLSMod) WriteDiskConfig(conf *tlsConfigSettings) { func (m *tlsManager) WriteDiskConfig(conf *tlsConfigSettings) {
t.confLock.Lock() m.confLock.Lock()
*conf = t.conf *conf = m.conf
t.confLock.Unlock() m.confLock.Unlock()
} }
func (t *TLSMod) setCertFileTime() { // setCertFileTime sets t.certLastMod from the certificate. If there are
if len(t.conf.CertificatePath) == 0 { // errors, setCertFileTime logs them.
func (m *tlsManager) setCertFileTime() {
if len(m.conf.CertificatePath) == 0 {
return return
} }
fi, err := os.Stat(t.conf.CertificatePath)
fi, err := os.Stat(m.conf.CertificatePath)
if err != nil { if err != nil {
log.Error("TLS: %s", err) log.Error("tls: looking up certificate path: %s", err)
return return
} }
t.certLastMod = fi.ModTime().UTC()
m.certLastMod = fi.ModTime().UTC()
} }
// Start updates the configuration of TLSMod and starts it. // start updates the configuration of t and starts it.
func (t *TLSMod) Start() { func (m *tlsManager) start() {
if !tlsWebHandlersRegistered { m.registerWebHandlers()
tlsWebHandlersRegistered = true
t.registerWebHandlers()
}
t.confLock.Lock() m.confLock.Lock()
tlsConf := t.conf tlsConf := m.conf
t.confLock.Unlock() m.confLock.Unlock()
// The background context is used because the TLSConfigChanged wraps // The background context is used because the TLSConfigChanged wraps context
// context with timeout on its own and shuts down the server, which // with timeout on its own and shuts down the server, which handles current
// handles current request. // request.
Context.web.TLSConfigChanged(context.Background(), tlsConf) Context.web.TLSConfigChanged(context.Background(), tlsConf)
} }
// Reload updates the configuration of TLSMod and restarts it. // reload updates the configuration and restarts t.
func (t *TLSMod) Reload() { func (m *tlsManager) reload() {
t.confLock.Lock() m.confLock.Lock()
tlsConf := t.conf tlsConf := m.conf
t.confLock.Unlock() m.confLock.Unlock()
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 { if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 {
return return
} }
fi, err := os.Stat(tlsConf.CertificatePath) fi, err := os.Stat(tlsConf.CertificatePath)
if err != nil { if err != nil {
log.Error("TLS: %s", err) log.Error("tls: %s", err)
return
}
if fi.ModTime().UTC().Equal(t.certLastMod) {
log.Debug("TLS: certificate file isn't modified")
return
}
log.Debug("TLS: certificate file is modified")
t.confLock.Lock()
r := t.load()
t.confLock.Unlock()
if !r {
return return
} }
t.certLastMod = fi.ModTime().UTC() if fi.ModTime().UTC().Equal(m.certLastMod) {
log.Debug("tls: certificate file isn't modified")
return
}
log.Debug("tls: certificate file is modified")
m.confLock.Lock()
err = m.load()
m.confLock.Unlock()
if err != nil {
log.Error("tls: reloading: %s", err)
return
}
m.certLastMod = fi.ModTime().UTC()
_ = reconfigureDNSServer() _ = reconfigureDNSServer()
t.confLock.Lock() m.confLock.Lock()
tlsConf = t.conf tlsConf = m.conf
t.confLock.Unlock() m.confLock.Unlock()
// The background context is used because the TLSConfigChanged wraps
// context with timeout on its own and shuts down the server, which // The background context is used because the TLSConfigChanged wraps context
// handles current request. // with timeout on its own and shuts down the server, which handles current
// request.
Context.web.TLSConfigChanged(context.Background(), tlsConf) Context.web.TLSConfigChanged(context.Background(), tlsConf)
} }
// Set certificate and private key data // loadTLSConf loads and validates the TLS configuration. The returned error is
func tlsLoadConfig(tls *tlsConfigSettings, status *tlsConfigStatus) bool { // also set in status.WarningValidation.
tls.CertificateChainData = []byte(tls.CertificateChain) func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) {
tls.PrivateKeyData = []byte(tls.PrivateKey) defer func() {
var err error
if tls.CertificatePath != "" {
if tls.CertificateChain != "" {
status.WarningValidation = "certificate data and file can't be set together"
return false
}
tls.CertificateChainData, err = os.ReadFile(tls.CertificatePath)
if err != nil { if err != nil {
status.WarningValidation = err.Error() status.WarningValidation = err.Error()
return false
} }
}()
tlsConf.CertificateChainData = []byte(tlsConf.CertificateChain)
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
if tlsConf.CertificatePath != "" {
if tlsConf.CertificateChain != "" {
return errors.Error("certificate data and file can't be set together")
}
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
if err != nil {
return fmt.Errorf("reading cert file: %w", err)
}
status.ValidCert = true status.ValidCert = true
} }
if tls.PrivateKeyPath != "" { if tlsConf.PrivateKeyPath != "" {
if tls.PrivateKey != "" { if tlsConf.PrivateKey != "" {
status.WarningValidation = "private key data and file can't be set together" return errors.Error("private key data and file can't be set together")
return false
} }
tls.PrivateKeyData, err = os.ReadFile(tls.PrivateKeyPath)
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
if err != nil { if err != nil {
status.WarningValidation = err.Error() return fmt.Errorf("reading key file: %w", err)
return false
} }
status.ValidKey = true status.ValidKey = true
} }
return true err = validateCertificates(
status,
tlsConf.CertificateChainData,
tlsConf.PrivateKeyData,
tlsConf.ServerName,
)
if err != nil {
return fmt.Errorf("validating certificate pair: %w", err)
}
return nil
} }
// tlsConfigStatus contains the status of a certificate chain and key pair.
type tlsConfigStatus struct { type tlsConfigStatus struct {
ValidCert bool `json:"valid_cert"` // ValidCert is true if the specified certificates chain is a valid chain of X509 certificates // Subject is the subject of the first certificate in the chain.
ValidChain bool `json:"valid_chain"` // ValidChain is true if the specified certificates chain is verified and issued by a known CA Subject string `json:"subject,omitempty"`
Subject string `json:"subject,omitempty"` // Subject is the subject of the first certificate in the chain
Issuer string `json:"issuer,omitempty"` // Issuer is the issuer of the first certificate in the chain
NotBefore time.Time `json:"not_before,omitempty"` // NotBefore is the NotBefore field of the first certificate in the chain
NotAfter time.Time `json:"not_after,omitempty"` // NotAfter is the NotAfter field of the first certificate in the chain
DNSNames []string `json:"dns_names"` // DNSNames is the value of SubjectAltNames field of the first certificate in the chain
// key status // Issuer is the issuer of the first certificate in the chain.
ValidKey bool `json:"valid_key"` // ValidKey is true if the key is a valid private key Issuer string `json:"issuer,omitempty"`
KeyType string `json:"key_type,omitempty"` // KeyType is one of RSA or ECDSA
// is usable? set by validator // KeyType is the type of the private key.
ValidPair bool `json:"valid_pair"` // ValidPair is true if both certificate and private key are correct KeyType string `json:"key_type,omitempty"`
// warnings // NotBefore is the NotBefore field of the first certificate in the chain.
WarningValidation string `json:"warning_validation,omitempty"` // WarningValidation is a validation warning message with the issue description NotBefore time.Time `json:"not_before,omitempty"`
// NotAfter is the NotAfter field of the first certificate in the chain.
NotAfter time.Time `json:"not_after,omitempty"`
// WarningValidation is a validation warning message with the issue
// description.
WarningValidation string `json:"warning_validation,omitempty"`
// DNSNames is the value of SubjectAltNames field of the first certificate
// in the chain.
DNSNames []string `json:"dns_names"`
// ValidCert is true if the specified certificate chain is a valid chain of
// X509 certificates.
ValidCert bool `json:"valid_cert"`
// ValidChain is true if the specified certificate chain is verified and
// issued by a known CA.
ValidChain bool `json:"valid_chain"`
// ValidKey is true if the key is a valid private key.
ValidKey bool `json:"valid_key"`
// ValidPair is true if both certificate and private key are correct for
// each other.
ValidPair bool `json:"valid_pair"`
} }
// field ordering is important -- yaml fields will mirror ordering from here // tlsConfig is the TLS configuration and status response.
type tlsConfig struct { type tlsConfig struct {
tlsConfigStatus `json:",inline"` *tlsConfigStatus `json:",inline"`
tlsConfigSettingsExt `json:",inline"` tlsConfigSettingsExt `json:",inline"`
} }
// tlsConfigSettingsExt is used to (un)marshal PrivateKeySaved to ensure that // tlsConfigSettingsExt is used to (un)marshal the PrivateKeySaved field to
// clients don't send and receive previously saved private keys. // ensure that clients don't send and receive previously saved private keys.
type tlsConfigSettingsExt struct { type tlsConfigSettingsExt struct {
tlsConfigSettings `json:",inline"` tlsConfigSettings `json:",inline"`
// If private key saved as a string, we set this flag to true
// and omit key from answer. // PrivateKeySaved is true if the private key is saved as a string and omit
// key from answer.
PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"` PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"`
} }
func (t *TLSMod) handleTLSStatus(w http.ResponseWriter, r *http.Request) { func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
t.confLock.Lock() m.confLock.Lock()
data := tlsConfig{ data := tlsConfig{
tlsConfigSettingsExt: tlsConfigSettingsExt{ tlsConfigSettingsExt: tlsConfigSettingsExt{
tlsConfigSettings: t.conf, tlsConfigSettings: m.conf,
}, },
tlsConfigStatus: t.status, tlsConfigStatus: m.status,
} }
t.confLock.Unlock() m.confLock.Unlock()
marshalTLS(w, r, data) marshalTLS(w, r, data)
} }
func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) { func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
setts, err := unmarshalTLS(r) setts, err := unmarshalTLS(r)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
@ -244,7 +284,7 @@ func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
} }
if setts.PrivateKeySaved { if setts.PrivateKeySaved {
setts.PrivateKey = t.conf.PrivateKey setts.PrivateKey = m.conf.PrivateKey
} }
if setts.Enabled { if setts.Enabled {
@ -276,75 +316,74 @@ func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
return return
} }
status := tlsConfigStatus{} // Skip the error check, since we are only interested in the value of
if tlsLoadConfig(&setts.tlsConfigSettings, &status) { // status.WarningValidation.
status = validateCertificates(string(setts.CertificateChainData), string(setts.PrivateKeyData), setts.ServerName) status := &tlsConfigStatus{}
} _ = loadTLSConf(&setts.tlsConfigSettings, status)
resp := tlsConfig{
data := tlsConfig{
tlsConfigSettingsExt: setts, tlsConfigSettingsExt: setts,
tlsConfigStatus: status, tlsConfigStatus: status,
} }
marshalTLS(w, r, data) marshalTLS(w, r, resp)
} }
func (t *TLSMod) setConfig(newConf tlsConfigSettings, status tlsConfigStatus) (restartHTTPS bool) { func (m *tlsManager) setConfig(newConf tlsConfigSettings, status *tlsConfigStatus) (restartHTTPS bool) {
t.confLock.Lock() m.confLock.Lock()
defer t.confLock.Unlock() defer m.confLock.Unlock()
// Reset the DNSCrypt data before comparing, since we currently do not // Reset the DNSCrypt data before comparing, since we currently do not
// accept these from the frontend. // accept these from the frontend.
// //
// TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig. // TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig.
newConf.DNSCryptConfigFile = t.conf.DNSCryptConfigFile newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile
newConf.PortDNSCrypt = t.conf.PortDNSCrypt newConf.PortDNSCrypt = m.conf.PortDNSCrypt
if !cmp.Equal(t.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) { if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
log.Info("tls config has changed, restarting https server") log.Info("tls config has changed, restarting https server")
restartHTTPS = true restartHTTPS = true
} else { } else {
log.Info("tls config has not changed") log.Info("tls: config has not changed")
} }
// Note: don't do just `t.conf = data` because we must preserve all other members of t.conf // Note: don't do just `t.conf = data` because we must preserve all other members of t.conf
t.conf.Enabled = newConf.Enabled m.conf.Enabled = newConf.Enabled
t.conf.ServerName = newConf.ServerName m.conf.ServerName = newConf.ServerName
t.conf.ForceHTTPS = newConf.ForceHTTPS m.conf.ForceHTTPS = newConf.ForceHTTPS
t.conf.PortHTTPS = newConf.PortHTTPS m.conf.PortHTTPS = newConf.PortHTTPS
t.conf.PortDNSOverTLS = newConf.PortDNSOverTLS m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
t.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
t.conf.CertificateChain = newConf.CertificateChain m.conf.CertificateChain = newConf.CertificateChain
t.conf.CertificatePath = newConf.CertificatePath m.conf.CertificatePath = newConf.CertificatePath
t.conf.CertificateChainData = newConf.CertificateChainData m.conf.CertificateChainData = newConf.CertificateChainData
t.conf.PrivateKey = newConf.PrivateKey m.conf.PrivateKey = newConf.PrivateKey
t.conf.PrivateKeyPath = newConf.PrivateKeyPath m.conf.PrivateKeyPath = newConf.PrivateKeyPath
t.conf.PrivateKeyData = newConf.PrivateKeyData m.conf.PrivateKeyData = newConf.PrivateKeyData
t.status = status m.status = status
return restartHTTPS return restartHTTPS
} }
func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
data, err := unmarshalTLS(r) req, err := unmarshalTLS(r)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return return
} }
if data.PrivateKeySaved { if req.PrivateKeySaved {
data.PrivateKey = t.conf.PrivateKey req.PrivateKey = m.conf.PrivateKey
} }
if data.Enabled { if req.Enabled {
err = validatePorts( err = validatePorts(
tcpPort(config.BindPort), tcpPort(config.BindPort),
tcpPort(config.BetaBindPort), tcpPort(config.BetaBindPort),
tcpPort(data.PortHTTPS), tcpPort(req.PortHTTPS),
tcpPort(data.PortDNSOverTLS), tcpPort(req.PortDNSOverTLS),
tcpPort(data.PortDNSCrypt), tcpPort(req.PortDNSCrypt),
udpPort(config.DNS.Port), udpPort(config.DNS.Port),
udpPort(data.PortDNSOverQUIC), udpPort(req.PortDNSOverQUIC),
) )
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@ -354,33 +393,33 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
} }
// TODO(e.burkov): Investigate and perhaps check other ports. // TODO(e.burkov): Investigate and perhaps check other ports.
if !webCheckPortAvailable(data.PortHTTPS) { if !webCheckPortAvailable(req.PortHTTPS) {
aghhttp.Error( aghhttp.Error(
r, r,
w, w,
http.StatusBadRequest, http.StatusBadRequest,
"port %d is not available, cannot enable HTTPS on it", "port %d is not available, cannot enable https on it",
data.PortHTTPS, req.PortHTTPS,
) )
return return
} }
status := tlsConfigStatus{} status := &tlsConfigStatus{}
if !tlsLoadConfig(&data.tlsConfigSettings, &status) { err = loadTLSConf(&req.tlsConfigSettings, status)
data2 := tlsConfig{ if err != nil {
tlsConfigSettingsExt: data, resp := tlsConfig{
tlsConfigStatus: t.status, tlsConfigSettingsExt: req,
tlsConfigStatus: status,
} }
marshalTLS(w, r, data2)
marshalTLS(w, r, resp)
return return
} }
status = validateCertificates(string(data.CertificateChainData), string(data.PrivateKeyData), data.ServerName) restartHTTPS := m.setConfig(req.tlsConfigSettings, status)
m.setCertFileTime()
restartHTTPS := t.setConfig(data.tlsConfigSettings, status)
t.setCertFileTime()
onConfigModified() onConfigModified()
err = reconfigureDNSServer() err = reconfigureDNSServer()
@ -390,12 +429,12 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
return return
} }
data2 := tlsConfig{ resp := tlsConfig{
tlsConfigSettingsExt: data, tlsConfigSettingsExt: req,
tlsConfigStatus: t.status, tlsConfigStatus: m.status,
} }
marshalTLS(w, r, data2) marshalTLS(w, r, resp)
if f, ok := w.(http.Flusher); ok { if f, ok := w.(http.Flusher); ok {
f.Flush() f.Flush()
} }
@ -406,7 +445,7 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
// same reason. // same reason.
if restartHTTPS { if restartHTTPS {
go func() { go func() {
Context.web.TLSConfigChanged(context.Background(), data.tlsConfigSettings) Context.web.TLSConfigChanged(context.Background(), req.tlsConfigSettings)
}() }()
} }
} }
@ -443,89 +482,105 @@ func validatePorts(
return nil return nil
} }
func verifyCertChain(data *tlsConfigStatus, certChain, serverName string) error { // validateCertChain validates the certificate chain and sets data in status.
log.Tracef("TLS: got certificate: %d bytes", len(certChain)) // The returned error is also set in status.WarningValidation.
func validateCertChain(status *tlsConfigStatus, certChain []byte, serverName string) (err error) {
defer func() {
if err != nil {
status.WarningValidation = err.Error()
}
}()
// now do a more extended validation log.Debug("tls: got certificate chain: %d bytes", len(certChain))
var certs []*pem.Block // PEM-encoded certificates
pemblock := []byte(certChain) var certs []*pem.Block
pemblock := certChain
for { for {
var decoded *pem.Block var decoded *pem.Block
decoded, pemblock = pem.Decode(pemblock) decoded, pemblock = pem.Decode(pemblock)
if decoded == nil { if decoded == nil {
break break
} }
if decoded.Type == "CERTIFICATE" { if decoded.Type == "CERTIFICATE" {
certs = append(certs, decoded) certs = append(certs, decoded)
} }
} }
var parsedCerts []*x509.Certificate parsedCerts, err := parsePEMCerts(certs)
if err != nil {
for _, cert := range certs { return err
parsed, err := x509.ParseCertificate(cert.Bytes)
if err != nil {
data.WarningValidation = fmt.Sprintf("Failed to parse certificate: %s", err)
return errors.Error(data.WarningValidation)
}
parsedCerts = append(parsedCerts, parsed)
} }
if len(parsedCerts) == 0 { status.ValidCert = true
data.WarningValidation = "You have specified an empty certificate"
return errors.Error(data.WarningValidation)
}
data.ValidCert = true
// spew.Dump(parsedCerts)
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
DNSName: serverName, DNSName: serverName,
Roots: Context.tlsRoots, Roots: Context.tlsRoots,
} }
log.Printf("number of certs - %d", len(parsedCerts)) log.Info("tls: number of certs: %d", len(parsedCerts))
if len(parsedCerts) > 1 {
// set up an intermediate pool := x509.NewCertPool()
pool := x509.NewCertPool() for _, cert := range parsedCerts[1:] {
for _, cert := range parsedCerts[1:] { log.Info("tls: got an intermediate cert")
log.Printf("got an intermediate cert") pool.AddCert(cert)
pool.AddCert(cert)
}
opts.Intermediates = pool
} }
// TODO: save it as a warning rather than error it out -- shouldn't be a big problem opts.Intermediates = pool
mainCert := parsedCerts[0] mainCert := parsedCerts[0]
_, err := mainCert.Verify(opts) _, err = mainCert.Verify(opts)
if err != nil { if err != nil {
// let self-signed certs through // Let self-signed certs through and don't return this error.
data.WarningValidation = fmt.Sprintf("Your certificate does not verify: %s", err) status.WarningValidation = fmt.Sprintf("certificate does not verify: %s", err)
} else { } else {
data.ValidChain = true status.ValidChain = true
} }
// spew.Dump(chains)
// update status
if mainCert != nil { if mainCert != nil {
notAfter := mainCert.NotAfter status.Subject = mainCert.Subject.String()
data.Subject = mainCert.Subject.String() status.Issuer = mainCert.Issuer.String()
data.Issuer = mainCert.Issuer.String() status.NotAfter = mainCert.NotAfter
data.NotAfter = notAfter status.NotBefore = mainCert.NotBefore
data.NotBefore = mainCert.NotBefore status.DNSNames = mainCert.DNSNames
data.DNSNames = mainCert.DNSNames
} }
return nil return nil
} }
func validatePkey(data *tlsConfigStatus, pkey string) error { // parsePEMCerts parses multiple PEM-encoded certificates.
// now do a more extended validation func parsePEMCerts(certs []*pem.Block) (parsedCerts []*x509.Certificate, err error) {
var key *pem.Block // PEM-encoded certificates for i, cert := range certs {
var parsed *x509.Certificate
parsed, err = x509.ParseCertificate(cert.Bytes)
if err != nil {
return nil, fmt.Errorf("parsing certificate at index %d: %w", i, err)
}
// go through all pem blocks, but take first valid pem block and drop the rest parsedCerts = append(parsedCerts, parsed)
}
if len(parsedCerts) == 0 {
return nil, errors.Error("empty certificate")
}
return parsedCerts, nil
}
// validatePKey validates the private key and sets data in status. The returned
// error is also set in status.WarningValidation.
func validatePKey(status *tlsConfigStatus, pkey []byte) (err error) {
defer func() {
if err != nil {
status.WarningValidation = err.Error()
}
}()
var key *pem.Block
// Go through all pem blocks, but take first valid pem block and drop the
// rest.
pemblock := []byte(pkey) pemblock := []byte(pkey)
for { for {
var decoded *pem.Block var decoded *pem.Block
@ -542,61 +597,77 @@ func validatePkey(data *tlsConfigStatus, pkey string) error {
} }
if key == nil { if key == nil {
data.WarningValidation = "No valid keys were found" return errors.Error("no valid keys were found")
return errors.Error(data.WarningValidation)
} }
// parse the decoded key
_, keyType, err := parsePrivateKey(key.Bytes) _, keyType, err := parsePrivateKey(key.Bytes)
if err != nil { if err != nil {
data.WarningValidation = fmt.Sprintf("Failed to parse private key: %s", err) return fmt.Errorf("parsing private key: %w", err)
return errors.Error(data.WarningValidation)
} else if keyType == keyTypeED25519 {
data.WarningValidation = "ED25519 keys are not supported by browsers; " +
"did you mean to use X25519 for key exchange?"
return errors.Error(data.WarningValidation)
} }
data.ValidKey = true if keyType == keyTypeED25519 {
data.KeyType = keyType return errors.Error(
"ED25519 keys are not supported by browsers; " +
"did you mean to use X25519 for key exchange?",
)
}
status.ValidKey = true
status.KeyType = keyType
return nil return nil
} }
// validateCertificates processes certificate data and its private key. All // validateCertificates processes certificate data and its private key. All
// parameters are optional. On error, validateCertificates returns a partially // parameters are optional. status must not be nil. The returned error is also
// set object with field WarningValidation containing error description. // set in status.WarningValidation.
func validateCertificates(certChain, pkey, serverName string) tlsConfigStatus { func validateCertificates(
var data tlsConfigStatus status *tlsConfigStatus,
certChain []byte,
// check only public certificate separately from the key pkey []byte,
if certChain != "" { serverName string,
if verifyCertChain(&data, certChain, serverName) != nil { ) (err error) {
return data defer func() {
// Capitalize the warning for the UI. Assume that warnings are all
// ASCII-only.
//
// TODO(a.garipov): Figure out a better way to do this. Perhaps a
// custom string or error type.
if w := status.WarningValidation; w != "" {
status.WarningValidation = strings.ToUpper(w[:1]) + w[1:]
} }
} }()
// validate private key (right now the only validation possible is just parsing it) // Check only the public certificate separately from the key.
if pkey != "" { if len(certChain) > 0 {
if validatePkey(&data, pkey) != nil { err = validateCertChain(status, certChain, serverName)
return data
}
}
// if both are set, validate both in unison
if pkey != "" && certChain != "" {
_, err := tls.X509KeyPair([]byte(certChain), []byte(pkey))
if err != nil { if err != nil {
data.WarningValidation = fmt.Sprintf("Invalid certificate or key: %s", err) return err
return data
} }
data.ValidPair = true
} }
return data // Validate the private key by parsing it.
if len(pkey) > 0 {
err = validatePKey(status, pkey)
if err != nil {
return err
}
}
// If both are set, validate together.
if len(certChain) > 0 && len(pkey) > 0 {
_, err = tls.X509KeyPair(certChain, pkey)
if err != nil {
err = fmt.Errorf("certificate-key pair: %w", err)
status.WarningValidation = err.Error()
return err
}
status.ValidPair = true
}
return nil
} }
// Key types. // Key types.
@ -691,9 +762,9 @@ func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
_ = aghhttp.WriteJSONResponse(w, r, data) _ = aghhttp.WriteJSONResponse(w, r, data)
} }
// registerWebHandlers registers HTTP handlers for TLS configuration // registerWebHandlers registers HTTP handlers for TLS configuration.
func (t *TLSMod) registerWebHandlers() { func (m *tlsManager) registerWebHandlers() {
httpRegister(http.MethodGet, "/control/tls/status", t.handleTLSStatus) httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
httpRegister(http.MethodPost, "/control/tls/configure", t.handleTLSConfigure) httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
httpRegister(http.MethodPost, "/control/tls/validate", t.handleTLSValidate) httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
} }

View File

@ -7,8 +7,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
const ( var testCertChainData = []byte(`-----BEGIN CERTIFICATE-----
CertificateChain = `-----BEGIN CERTIFICATE-----
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3 BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
@ -21,8 +20,9 @@ eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8
LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ== Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
-----END CERTIFICATE-----` -----END CERTIFICATE-----`)
PrivateKey = `-----BEGIN PRIVATE KEY-----
var testPrivateKeyData = []byte(`-----BEGIN PRIVATE KEY-----
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
@ -37,36 +37,43 @@ An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp
O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
kXS9jgARhhiWXJrk kXS9jgARhhiWXJrk
-----END PRIVATE KEY-----` -----END PRIVATE KEY-----`)
)
func TestValidateCertificates(t *testing.T) { func TestValidateCertificates(t *testing.T) {
t.Run("bad_certificate", func(t *testing.T) { t.Run("bad_certificate", func(t *testing.T) {
data := validateCertificates("bad cert", "", "") status := &tlsConfigStatus{}
assert.NotEmpty(t, data.WarningValidation) err := validateCertificates(status, []byte("bad cert"), nil, "")
assert.False(t, data.ValidCert) assert.Error(t, err)
assert.False(t, data.ValidChain) assert.NotEmpty(t, status.WarningValidation)
assert.False(t, status.ValidCert)
assert.False(t, status.ValidChain)
}) })
t.Run("bad_private_key", func(t *testing.T) { t.Run("bad_private_key", func(t *testing.T) {
data := validateCertificates("", "bad priv key", "") status := &tlsConfigStatus{}
assert.NotEmpty(t, data.WarningValidation) err := validateCertificates(status, nil, []byte("bad priv key"), "")
assert.False(t, data.ValidKey) assert.Error(t, err)
assert.NotEmpty(t, status.WarningValidation)
assert.False(t, status.ValidKey)
}) })
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
data := validateCertificates(CertificateChain, PrivateKey, "") status := &tlsConfigStatus{}
notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z") err := validateCertificates(status, testCertChainData, testPrivateKeyData, "")
notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z") assert.NoError(t, err)
assert.NotEmpty(t, data.WarningValidation)
assert.True(t, data.ValidCert) notBefore := time.Date(2019, 2, 27, 9, 24, 23, 0, time.UTC)
assert.False(t, data.ValidChain) notAfter := time.Date(2046, 7, 14, 9, 24, 23, 0, time.UTC)
assert.True(t, data.ValidKey)
assert.Equal(t, "RSA", data.KeyType) assert.NotEmpty(t, status.WarningValidation)
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", data.Subject) assert.True(t, status.ValidCert)
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", data.Issuer) assert.False(t, status.ValidChain)
assert.Equal(t, notBefore, data.NotBefore) assert.True(t, status.ValidKey)
assert.Equal(t, notAfter, data.NotAfter) assert.Equal(t, "RSA", status.KeyType)
assert.True(t, data.ValidPair) assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", status.Subject)
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", status.Issuer)
assert.Equal(t, notBefore, status.NotBefore)
assert.Equal(t, notAfter, status.NotAfter)
assert.True(t, status.ValidPair)
}) })
} }