ipn/{ipnlocal,localapi}: actually renew certs before expiry (#8731)
While our `shouldStartDomainRenewal` check is correct, `getCertPEM` would always bail if the existing cert is not expired. Add the same `shouldStartDomainRenewal` check to `getCertPEM` to make it proceed with renewal when existing certs are still valid but should be renewed. The extra check is expensive (ARI request towards LetsEncrypt), so cache the last check result for 1hr to not degrade `tailscale serve` performance. Also, asynchronous renewal is great for `tailscale serve` but confusing for `tailscale cert`. Add an explicit flag to `GetCertPEM` to force a synchronous renewal for `tailscale cert`. Fixes #8725 Signed-off-by: Andrew Lytvynov <awly@tailscale.com>
This commit is contained in:
parent
aa37be70cf
commit
c1ecae13ab
|
@ -53,8 +53,8 @@ var (
|
||||||
// populate the on-disk cache and the rest should use that.
|
// populate the on-disk cache and the rest should use that.
|
||||||
acmeMu sync.Mutex
|
acmeMu sync.Mutex
|
||||||
|
|
||||||
renewMu sync.Mutex // lock order: don't hold acmeMu and renewMu at the same time
|
renewMu sync.Mutex // lock order: acmeMu before renewMu
|
||||||
lastRenewCheck = map[string]time.Time{}
|
renewCertAt = map[string]time.Time{}
|
||||||
)
|
)
|
||||||
|
|
||||||
// certDir returns (creating if needed) the directory in which cached
|
// certDir returns (creating if needed) the directory in which cached
|
||||||
|
@ -80,9 +80,15 @@ func (b *LocalBackend) certDir() (string, error) {
|
||||||
|
|
||||||
var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME")
|
var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME")
|
||||||
|
|
||||||
// getCertPEM gets the KeyPair for domain, either from cache, via the ACME
|
// GetCertPEM gets the TLSCertKeyPair for domain, either from cache or via the
|
||||||
// process, or from cache and kicking off an async ACME renewal.
|
// ACME process. ACME process is used for new domain certs, existing expired
|
||||||
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) {
|
// certs or existing certs that should get renewed due to upcoming expiry.
|
||||||
|
//
|
||||||
|
// syncRenewal changes renewal behavior for existing certs that are still valid
|
||||||
|
// but need renewal. When syncRenewal is set, the method blocks until a new
|
||||||
|
// cert is issued. When syncRenewal is not set, existing cert is returned right
|
||||||
|
// away and renewal is kicked off in a background goroutine.
|
||||||
|
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) {
|
||||||
if !validLookingCertDomain(domain) {
|
if !validLookingCertDomain(domain) {
|
||||||
return nil, errors.New("invalid domain")
|
return nil, errors.New("invalid domain")
|
||||||
}
|
}
|
||||||
|
@ -105,12 +111,15 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
|
||||||
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair)
|
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf("error checking for certificate renewal: %v", err)
|
logf("error checking for certificate renewal: %v", err)
|
||||||
} else if shouldRenew {
|
} else if !shouldRenew {
|
||||||
|
return pair, nil
|
||||||
|
}
|
||||||
|
if !syncRenewal {
|
||||||
logf("starting async renewal")
|
logf("starting async renewal")
|
||||||
// Start renewal in the background.
|
// Start renewal in the background.
|
||||||
go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now)
|
go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now)
|
||||||
}
|
}
|
||||||
return pair, nil
|
// Synchronous renewal happens below.
|
||||||
}
|
}
|
||||||
|
|
||||||
pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now)
|
pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now)
|
||||||
|
@ -124,37 +133,43 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
|
||||||
func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) {
|
func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) {
|
||||||
renewMu.Lock()
|
renewMu.Lock()
|
||||||
defer renewMu.Unlock()
|
defer renewMu.Unlock()
|
||||||
if last, ok := lastRenewCheck[domain]; ok && now.Sub(last) < time.Minute {
|
if renewAt, ok := renewCertAt[domain]; ok {
|
||||||
// We checked very recently. Don't bother reparsing &
|
return now.After(renewAt), nil
|
||||||
// validating the x509 cert.
|
|
||||||
return false, nil
|
|
||||||
}
|
}
|
||||||
lastRenewCheck[domain] = now
|
|
||||||
|
|
||||||
renew, err := b.shouldStartDomainRenewalByARI(cs, now, pair)
|
renewTime, err := b.domainRenewalTimeByARI(cs, pair)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log any ARI failure and fall back to checking for renewal by expiry.
|
// Log any ARI failure and fall back to checking for renewal by expiry.
|
||||||
b.logf("acme: ARI check failed: %v; falling back to expiry-based check", err)
|
b.logf("acme: ARI check failed: %v; falling back to expiry-based check", err)
|
||||||
} else {
|
renewTime, err = b.domainRenewalTimeByExpiry(pair)
|
||||||
return renew, nil
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return b.shouldStartDomainRenewalByExpiry(now, pair)
|
renewCertAt[domain] = renewTime
|
||||||
|
return now.After(renewTime), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLSCertKeyPair) (bool, error) {
|
func (b *LocalBackend) domainRenewed(domain string) {
|
||||||
|
renewMu.Lock()
|
||||||
|
defer renewMu.Unlock()
|
||||||
|
delete(renewCertAt, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *LocalBackend) domainRenewalTimeByExpiry(pair *TLSCertKeyPair) (time.Time, error) {
|
||||||
block, _ := pem.Decode(pair.CertPEM)
|
block, _ := pem.Decode(pair.CertPEM)
|
||||||
if block == nil {
|
if block == nil {
|
||||||
return false, fmt.Errorf("parsing certificate PEM")
|
return time.Time{}, fmt.Errorf("parsing certificate PEM")
|
||||||
}
|
}
|
||||||
cert, err := x509.ParseCertificate(block.Bytes)
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("parsing certificate: %w", err)
|
return time.Time{}, fmt.Errorf("parsing certificate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
certLifetime := cert.NotAfter.Sub(cert.NotBefore)
|
certLifetime := cert.NotAfter.Sub(cert.NotBefore)
|
||||||
if certLifetime < 0 {
|
if certLifetime < 0 {
|
||||||
return false, fmt.Errorf("negative certificate lifetime %v", certLifetime)
|
return time.Time{}, fmt.Errorf("negative certificate lifetime %v", certLifetime)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Per https://github.com/tailscale/tailscale/issues/8204, check
|
// Per https://github.com/tailscale/tailscale/issues/8204, check
|
||||||
|
@ -163,36 +178,32 @@ func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLS
|
||||||
// Encrypt.
|
// Encrypt.
|
||||||
renewalDuration := certLifetime * 2 / 3
|
renewalDuration := certLifetime * 2 / 3
|
||||||
renewAt := cert.NotBefore.Add(renewalDuration)
|
renewAt := cert.NotBefore.Add(renewalDuration)
|
||||||
|
return renewAt, nil
|
||||||
if now.After(renewAt) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time, pair *TLSCertKeyPair) (bool, error) {
|
func (b *LocalBackend) domainRenewalTimeByARI(cs certStore, pair *TLSCertKeyPair) (time.Time, error) {
|
||||||
var blocks []*pem.Block
|
var blocks []*pem.Block
|
||||||
rest := pair.CertPEM
|
rest := pair.CertPEM
|
||||||
for len(rest) > 0 {
|
for len(rest) > 0 {
|
||||||
var block *pem.Block
|
var block *pem.Block
|
||||||
block, rest = pem.Decode(rest)
|
block, rest = pem.Decode(rest)
|
||||||
if block == nil {
|
if block == nil {
|
||||||
return false, fmt.Errorf("parsing certificate PEM")
|
return time.Time{}, fmt.Errorf("parsing certificate PEM")
|
||||||
}
|
}
|
||||||
blocks = append(blocks, block)
|
blocks = append(blocks, block)
|
||||||
}
|
}
|
||||||
if len(blocks) < 2 {
|
if len(blocks) < 2 {
|
||||||
return false, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks))
|
return time.Time{}, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks))
|
||||||
}
|
}
|
||||||
ac, err := acmeClient(cs)
|
ac, err := acmeClient(cs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return time.Time{}, err
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second)
|
ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ri, err := ac.FetchRenewalInfo(ctx, blocks[0].Bytes, blocks[1].Bytes)
|
ri, err := ac.FetchRenewalInfo(ctx, blocks[0].Bytes, blocks[1].Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err)
|
return time.Time{}, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err)
|
||||||
}
|
}
|
||||||
if acmeDebug() {
|
if acmeDebug() {
|
||||||
b.logf("acme: ARI response: %+v", ri)
|
b.logf("acme: ARI response: %+v", ri)
|
||||||
|
@ -203,7 +214,7 @@ func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time
|
||||||
// https://datatracker.ietf.org/doc/draft-ietf-acme-ari/
|
// https://datatracker.ietf.org/doc/draft-ietf-acme-ari/
|
||||||
start, end := ri.SuggestedWindow.Start, ri.SuggestedWindow.End
|
start, end := ri.SuggestedWindow.Start, ri.SuggestedWindow.End
|
||||||
renewTime := start.Add(time.Duration(insecurerand.Int63n(int64(end.Sub(start)))))
|
renewTime := start.Add(time.Duration(insecurerand.Int63n(int64(end.Sub(start)))))
|
||||||
return now.After(renewTime), nil
|
return renewTime, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// certStore provides a way to perist and retrieve TLS certificates.
|
// certStore provides a way to perist and retrieve TLS certificates.
|
||||||
|
@ -371,8 +382,18 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
|
||||||
acmeMu.Lock()
|
acmeMu.Lock()
|
||||||
defer acmeMu.Unlock()
|
defer acmeMu.Unlock()
|
||||||
|
|
||||||
|
// In case this method was triggered multiple times in parallel (when
|
||||||
|
// serving incoming requests), check whether one of the other goroutines
|
||||||
|
// already renewed the cert before us.
|
||||||
if p, err := getCertPEMCached(cs, domain, now); err == nil {
|
if p, err := getCertPEMCached(cs, domain, now); err == nil {
|
||||||
return p, nil
|
// shouldStartDomainRenewal caches its result so it's OK to call this
|
||||||
|
// frequently.
|
||||||
|
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p)
|
||||||
|
if err != nil {
|
||||||
|
logf("error checking for certificate renewal: %v", err)
|
||||||
|
} else if !shouldRenew {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
} else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) {
|
} else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -509,6 +530,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
|
||||||
if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil {
|
if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
b.domainRenewed(domain)
|
||||||
|
|
||||||
return &TLSCertKeyPair{CertPEM: certPEM.Bytes(), KeyPEM: privPEM.Bytes()}, nil
|
return &TLSCertKeyPair{CertPEM: certPEM.Bytes(), KeyPEM: privPEM.Bytes()}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,6 @@ type TLSCertKeyPair struct {
|
||||||
CertPEM, KeyPEM []byte
|
CertPEM, KeyPEM []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) {
|
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) {
|
||||||
return nil, errors.New("not implemented for js/wasm")
|
return nil, errors.New("not implemented for js/wasm")
|
||||||
}
|
}
|
||||||
|
|
|
@ -112,7 +112,7 @@ func TestShouldStartDomainRenewal(t *testing.T) {
|
||||||
reset := func() {
|
reset := func() {
|
||||||
renewMu.Lock()
|
renewMu.Lock()
|
||||||
defer renewMu.Unlock()
|
defer renewMu.Unlock()
|
||||||
maps.Clear(lastRenewCheck)
|
maps.Clear(renewCertAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
mustMakePair := func(template *x509.Certificate) *TLSCertKeyPair {
|
mustMakePair := func(template *x509.Certificate) *TLSCertKeyPair {
|
||||||
|
@ -178,7 +178,7 @@ func TestShouldStartDomainRenewal(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
reset()
|
reset()
|
||||||
|
|
||||||
ret, err := b.shouldStartDomainRenewalByExpiry(now, mustMakePair(&x509.Certificate{
|
ret, err := b.domainRenewalTimeByExpiry(mustMakePair(&x509.Certificate{
|
||||||
SerialNumber: big.NewInt(2019),
|
SerialNumber: big.NewInt(2019),
|
||||||
Subject: subject,
|
Subject: subject,
|
||||||
NotBefore: tt.notBefore,
|
NotBefore: tt.notBefore,
|
||||||
|
@ -192,8 +192,9 @@ func TestShouldStartDomainRenewal(t *testing.T) {
|
||||||
t.Errorf("got err=%q, want %q", err.Error(), tt.wantErr)
|
t.Errorf("got err=%q, want %q", err.Error(), tt.wantErr)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if ret != tt.want {
|
renew := now.After(ret)
|
||||||
t.Errorf("got ret=%v, want %v", ret, tt.want)
|
if renew != tt.want {
|
||||||
|
t.Errorf("got renew=%v (ret=%v), want renew %v", renew, ret, tt.want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -372,7 +372,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort)
|
||||||
GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
pair, err := b.GetCertPEM(ctx, sni)
|
pair, err := b.GetCertPEM(ctx, sni, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -675,7 +675,7 @@ func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHe
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
pair, err := b.GetCertPEM(ctx, hi.ServerName)
|
pair, err := b.GetCertPEM(ctx, hi.ServerName, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Error(w, "internal handler config wired wrong", 500)
|
http.Error(w, "internal handler config wired wrong", 500)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
pair, err := h.b.GetCertPEM(r.Context(), domain)
|
pair, err := h.b.GetCertPEM(r.Context(), domain, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO(bradfitz): 500 is a little lazy here. The errors returned from
|
// TODO(bradfitz): 500 is a little lazy here. The errors returned from
|
||||||
// GetCertPEM (and everywhere) should carry info info to get whether
|
// GetCertPEM (and everywhere) should carry info info to get whether
|
||||||
|
|
Loading…
Reference in New Issue