ipn/{ipnlocal,localapi}: actually renew certs before expiry (#8731)

While our `shouldStartDomainRenewal` check is correct, `getCertPEM`
would always bail if the existing cert is not expired. Add the same
`shouldStartDomainRenewal` check to `getCertPEM` to make it proceed with
renewal when existing certs are still valid but should be renewed.

The extra check is expensive (ARI request towards LetsEncrypt), so cache
the last check result for 1hr to not degrade `tailscale serve`
performance.

Also, asynchronous renewal is great for `tailscale serve` but confusing
for `tailscale cert`. Add an explicit flag to `GetCertPEM` to force a
synchronous renewal for `tailscale cert`.

Fixes #8725

Signed-off-by: Andrew Lytvynov <awly@tailscale.com>
This commit is contained in:
Andrew Lytvynov 2023-07-27 12:29:40 -07:00 committed by GitHub
parent aa37be70cf
commit c1ecae13ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 40 deletions

View File

@ -53,8 +53,8 @@ var (
// populate the on-disk cache and the rest should use that.
acmeMu sync.Mutex
renewMu sync.Mutex // lock order: don't hold acmeMu and renewMu at the same time
lastRenewCheck = map[string]time.Time{}
renewMu sync.Mutex // lock order: acmeMu before renewMu
renewCertAt = map[string]time.Time{}
)
// certDir returns (creating if needed) the directory in which cached
@ -80,9 +80,15 @@ func (b *LocalBackend) certDir() (string, error) {
var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME")
// getCertPEM gets the KeyPair for domain, either from cache, via the ACME
// process, or from cache and kicking off an async ACME renewal.
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) {
// GetCertPEM gets the TLSCertKeyPair for domain, either from cache or via the
// ACME process. ACME process is used for new domain certs, existing expired
// certs or existing certs that should get renewed due to upcoming expiry.
//
// syncRenewal changes renewal behavior for existing certs that are still valid
// but need renewal. When syncRenewal is set, the method blocks until a new
// cert is issued. When syncRenewal is not set, existing cert is returned right
// away and renewal is kicked off in a background goroutine.
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) {
if !validLookingCertDomain(domain) {
return nil, errors.New("invalid domain")
}
@ -105,12 +111,15 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair)
if err != nil {
logf("error checking for certificate renewal: %v", err)
} else if shouldRenew {
} else if !shouldRenew {
return pair, nil
}
if !syncRenewal {
logf("starting async renewal")
// Start renewal in the background.
go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now)
}
return pair, nil
// Synchronous renewal happens below.
}
pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now)
@ -124,37 +133,43 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) {
renewMu.Lock()
defer renewMu.Unlock()
if last, ok := lastRenewCheck[domain]; ok && now.Sub(last) < time.Minute {
// We checked very recently. Don't bother reparsing &
// validating the x509 cert.
return false, nil
if renewAt, ok := renewCertAt[domain]; ok {
return now.After(renewAt), nil
}
lastRenewCheck[domain] = now
renew, err := b.shouldStartDomainRenewalByARI(cs, now, pair)
renewTime, err := b.domainRenewalTimeByARI(cs, pair)
if err != nil {
// Log any ARI failure and fall back to checking for renewal by expiry.
b.logf("acme: ARI check failed: %v; falling back to expiry-based check", err)
} else {
return renew, nil
renewTime, err = b.domainRenewalTimeByExpiry(pair)
if err != nil {
return false, err
}
}
return b.shouldStartDomainRenewalByExpiry(now, pair)
renewCertAt[domain] = renewTime
return now.After(renewTime), nil
}
func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLSCertKeyPair) (bool, error) {
func (b *LocalBackend) domainRenewed(domain string) {
renewMu.Lock()
defer renewMu.Unlock()
delete(renewCertAt, domain)
}
func (b *LocalBackend) domainRenewalTimeByExpiry(pair *TLSCertKeyPair) (time.Time, error) {
block, _ := pem.Decode(pair.CertPEM)
if block == nil {
return false, fmt.Errorf("parsing certificate PEM")
return time.Time{}, fmt.Errorf("parsing certificate PEM")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return false, fmt.Errorf("parsing certificate: %w", err)
return time.Time{}, fmt.Errorf("parsing certificate: %w", err)
}
certLifetime := cert.NotAfter.Sub(cert.NotBefore)
if certLifetime < 0 {
return false, fmt.Errorf("negative certificate lifetime %v", certLifetime)
return time.Time{}, fmt.Errorf("negative certificate lifetime %v", certLifetime)
}
// Per https://github.com/tailscale/tailscale/issues/8204, check
@ -163,36 +178,32 @@ func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLS
// Encrypt.
renewalDuration := certLifetime * 2 / 3
renewAt := cert.NotBefore.Add(renewalDuration)
if now.After(renewAt) {
return true, nil
}
return false, nil
return renewAt, nil
}
func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time, pair *TLSCertKeyPair) (bool, error) {
func (b *LocalBackend) domainRenewalTimeByARI(cs certStore, pair *TLSCertKeyPair) (time.Time, error) {
var blocks []*pem.Block
rest := pair.CertPEM
for len(rest) > 0 {
var block *pem.Block
block, rest = pem.Decode(rest)
if block == nil {
return false, fmt.Errorf("parsing certificate PEM")
return time.Time{}, fmt.Errorf("parsing certificate PEM")
}
blocks = append(blocks, block)
}
if len(blocks) < 2 {
return false, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks))
return time.Time{}, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks))
}
ac, err := acmeClient(cs)
if err != nil {
return false, err
return time.Time{}, err
}
ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second)
defer cancel()
ri, err := ac.FetchRenewalInfo(ctx, blocks[0].Bytes, blocks[1].Bytes)
if err != nil {
return false, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err)
return time.Time{}, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err)
}
if acmeDebug() {
b.logf("acme: ARI response: %+v", ri)
@ -203,7 +214,7 @@ func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time
// https://datatracker.ietf.org/doc/draft-ietf-acme-ari/
start, end := ri.SuggestedWindow.Start, ri.SuggestedWindow.End
renewTime := start.Add(time.Duration(insecurerand.Int63n(int64(end.Sub(start)))))
return now.After(renewTime), nil
return renewTime, nil
}
// certStore provides a way to perist and retrieve TLS certificates.
@ -371,8 +382,18 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
acmeMu.Lock()
defer acmeMu.Unlock()
// In case this method was triggered multiple times in parallel (when
// serving incoming requests), check whether one of the other goroutines
// already renewed the cert before us.
if p, err := getCertPEMCached(cs, domain, now); err == nil {
return p, nil
// shouldStartDomainRenewal caches its result so it's OK to call this
// frequently.
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p)
if err != nil {
logf("error checking for certificate renewal: %v", err)
} else if !shouldRenew {
return p, nil
}
} else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) {
return nil, err
}
@ -509,6 +530,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil {
return nil, err
}
b.domainRenewed(domain)
return &TLSCertKeyPair{CertPEM: certPEM.Bytes(), KeyPEM: privPEM.Bytes()}, nil
}

View File

@ -12,6 +12,6 @@ type TLSCertKeyPair struct {
CertPEM, KeyPEM []byte
}
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) {
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) {
return nil, errors.New("not implemented for js/wasm")
}

View File

@ -112,7 +112,7 @@ func TestShouldStartDomainRenewal(t *testing.T) {
reset := func() {
renewMu.Lock()
defer renewMu.Unlock()
maps.Clear(lastRenewCheck)
maps.Clear(renewCertAt)
}
mustMakePair := func(template *x509.Certificate) *TLSCertKeyPair {
@ -178,7 +178,7 @@ func TestShouldStartDomainRenewal(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
reset()
ret, err := b.shouldStartDomainRenewalByExpiry(now, mustMakePair(&x509.Certificate{
ret, err := b.domainRenewalTimeByExpiry(mustMakePair(&x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: subject,
NotBefore: tt.notBefore,
@ -192,8 +192,9 @@ func TestShouldStartDomainRenewal(t *testing.T) {
t.Errorf("got err=%q, want %q", err.Error(), tt.wantErr)
}
} else {
if ret != tt.want {
t.Errorf("got ret=%v, want %v", ret, tt.want)
renew := now.After(ret)
if renew != tt.want {
t.Errorf("got renew=%v (ret=%v), want renew %v", renew, ret, tt.want)
}
}
})

View File

@ -372,7 +372,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort)
GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
pair, err := b.GetCertPEM(ctx, sni)
pair, err := b.GetCertPEM(ctx, sni, false)
if err != nil {
return nil, err
}
@ -675,7 +675,7 @@ func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHe
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
pair, err := b.GetCertPEM(ctx, hi.ServerName)
pair, err := b.GetCertPEM(ctx, hi.ServerName, false)
if err != nil {
return nil, err
}

View File

@ -23,7 +23,7 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
http.Error(w, "internal handler config wired wrong", 500)
return
}
pair, err := h.b.GetCertPEM(r.Context(), domain)
pair, err := h.b.GetCertPEM(r.Context(), domain, true)
if err != nil {
// TODO(bradfitz): 500 is a little lazy here. The errors returned from
// GetCertPEM (and everywhere) should carry info info to get whether