ipn/{ipnlocal,localapi}: actually renew certs before expiry (#8731)
While our `shouldStartDomainRenewal` check is correct, `getCertPEM` would always bail if the existing cert had not expired. Add the same `shouldStartDomainRenewal` check to `getCertPEM` so that it proceeds with renewal when an existing cert is still valid but should be renewed. The extra check is expensive (an ARI request to Let's Encrypt), so cache the last check result for 1hr to avoid degrading `tailscale serve` performance. Also, asynchronous renewal is great for `tailscale serve` but confusing for `tailscale cert`. Add an explicit flag to `GetCertPEM` to force a synchronous renewal for `tailscale cert`. Fixes #8725 Signed-off-by: Andrew Lytvynov <awly@tailscale.com>
This commit is contained in:
parent
aa37be70cf
commit
c1ecae13ab
|
@ -53,8 +53,8 @@ var (
|
|||
// populate the on-disk cache and the rest should use that.
|
||||
acmeMu sync.Mutex
|
||||
|
||||
renewMu sync.Mutex // lock order: don't hold acmeMu and renewMu at the same time
|
||||
lastRenewCheck = map[string]time.Time{}
|
||||
renewMu sync.Mutex // lock order: acmeMu before renewMu
|
||||
renewCertAt = map[string]time.Time{}
|
||||
)
|
||||
|
||||
// certDir returns (creating if needed) the directory in which cached
|
||||
|
@ -80,9 +80,15 @@ func (b *LocalBackend) certDir() (string, error) {
|
|||
|
||||
var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME")
|
||||
|
||||
// getCertPEM gets the KeyPair for domain, either from cache, via the ACME
|
||||
// process, or from cache and kicking off an async ACME renewal.
|
||||
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) {
|
||||
// GetCertPEM gets the TLSCertKeyPair for domain, either from cache or via the
|
||||
// ACME process. ACME process is used for new domain certs, existing expired
|
||||
// certs or existing certs that should get renewed due to upcoming expiry.
|
||||
//
|
||||
// syncRenewal changes renewal behavior for existing certs that are still valid
|
||||
// but need renewal. When syncRenewal is set, the method blocks until a new
|
||||
// cert is issued. When syncRenewal is not set, existing cert is returned right
|
||||
// away and renewal is kicked off in a background goroutine.
|
||||
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) {
|
||||
if !validLookingCertDomain(domain) {
|
||||
return nil, errors.New("invalid domain")
|
||||
}
|
||||
|
@ -105,12 +111,15 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
|
|||
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair)
|
||||
if err != nil {
|
||||
logf("error checking for certificate renewal: %v", err)
|
||||
} else if shouldRenew {
|
||||
} else if !shouldRenew {
|
||||
return pair, nil
|
||||
}
|
||||
if !syncRenewal {
|
||||
logf("starting async renewal")
|
||||
// Start renewal in the background.
|
||||
go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now)
|
||||
}
|
||||
return pair, nil
|
||||
// Synchronous renewal happens below.
|
||||
}
|
||||
|
||||
pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now)
|
||||
|
@ -124,37 +133,43 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
|
|||
func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) {
|
||||
renewMu.Lock()
|
||||
defer renewMu.Unlock()
|
||||
if last, ok := lastRenewCheck[domain]; ok && now.Sub(last) < time.Minute {
|
||||
// We checked very recently. Don't bother reparsing &
|
||||
// validating the x509 cert.
|
||||
return false, nil
|
||||
if renewAt, ok := renewCertAt[domain]; ok {
|
||||
return now.After(renewAt), nil
|
||||
}
|
||||
lastRenewCheck[domain] = now
|
||||
|
||||
renew, err := b.shouldStartDomainRenewalByARI(cs, now, pair)
|
||||
renewTime, err := b.domainRenewalTimeByARI(cs, pair)
|
||||
if err != nil {
|
||||
// Log any ARI failure and fall back to checking for renewal by expiry.
|
||||
b.logf("acme: ARI check failed: %v; falling back to expiry-based check", err)
|
||||
} else {
|
||||
return renew, nil
|
||||
renewTime, err = b.domainRenewalTimeByExpiry(pair)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
return b.shouldStartDomainRenewalByExpiry(now, pair)
|
||||
renewCertAt[domain] = renewTime
|
||||
return now.After(renewTime), nil
|
||||
}
|
||||
|
||||
func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLSCertKeyPair) (bool, error) {
|
||||
func (b *LocalBackend) domainRenewed(domain string) {
|
||||
renewMu.Lock()
|
||||
defer renewMu.Unlock()
|
||||
delete(renewCertAt, domain)
|
||||
}
|
||||
|
||||
func (b *LocalBackend) domainRenewalTimeByExpiry(pair *TLSCertKeyPair) (time.Time, error) {
|
||||
block, _ := pem.Decode(pair.CertPEM)
|
||||
if block == nil {
|
||||
return false, fmt.Errorf("parsing certificate PEM")
|
||||
return time.Time{}, fmt.Errorf("parsing certificate PEM")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing certificate: %w", err)
|
||||
return time.Time{}, fmt.Errorf("parsing certificate: %w", err)
|
||||
}
|
||||
|
||||
certLifetime := cert.NotAfter.Sub(cert.NotBefore)
|
||||
if certLifetime < 0 {
|
||||
return false, fmt.Errorf("negative certificate lifetime %v", certLifetime)
|
||||
return time.Time{}, fmt.Errorf("negative certificate lifetime %v", certLifetime)
|
||||
}
|
||||
|
||||
// Per https://github.com/tailscale/tailscale/issues/8204, check
|
||||
|
@ -163,36 +178,32 @@ func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLS
|
|||
// Encrypt.
|
||||
renewalDuration := certLifetime * 2 / 3
|
||||
renewAt := cert.NotBefore.Add(renewalDuration)
|
||||
|
||||
if now.After(renewAt) {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
return renewAt, nil
|
||||
}
|
||||
|
||||
func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time, pair *TLSCertKeyPair) (bool, error) {
|
||||
func (b *LocalBackend) domainRenewalTimeByARI(cs certStore, pair *TLSCertKeyPair) (time.Time, error) {
|
||||
var blocks []*pem.Block
|
||||
rest := pair.CertPEM
|
||||
for len(rest) > 0 {
|
||||
var block *pem.Block
|
||||
block, rest = pem.Decode(rest)
|
||||
if block == nil {
|
||||
return false, fmt.Errorf("parsing certificate PEM")
|
||||
return time.Time{}, fmt.Errorf("parsing certificate PEM")
|
||||
}
|
||||
blocks = append(blocks, block)
|
||||
}
|
||||
if len(blocks) < 2 {
|
||||
return false, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks))
|
||||
return time.Time{}, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks))
|
||||
}
|
||||
ac, err := acmeClient(cs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return time.Time{}, err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
ri, err := ac.FetchRenewalInfo(ctx, blocks[0].Bytes, blocks[1].Bytes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err)
|
||||
return time.Time{}, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err)
|
||||
}
|
||||
if acmeDebug() {
|
||||
b.logf("acme: ARI response: %+v", ri)
|
||||
|
@ -203,7 +214,7 @@ func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time
|
|||
// https://datatracker.ietf.org/doc/draft-ietf-acme-ari/
|
||||
start, end := ri.SuggestedWindow.Start, ri.SuggestedWindow.End
|
||||
renewTime := start.Add(time.Duration(insecurerand.Int63n(int64(end.Sub(start)))))
|
||||
return now.After(renewTime), nil
|
||||
return renewTime, nil
|
||||
}
|
||||
|
||||
// certStore provides a way to persist and retrieve TLS certificates.
|
||||
|
@ -371,8 +382,18 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
|
|||
acmeMu.Lock()
|
||||
defer acmeMu.Unlock()
|
||||
|
||||
// In case this method was triggered multiple times in parallel (when
|
||||
// serving incoming requests), check whether one of the other goroutines
|
||||
// already renewed the cert before us.
|
||||
if p, err := getCertPEMCached(cs, domain, now); err == nil {
|
||||
return p, nil
|
||||
// shouldStartDomainRenewal caches its result so it's OK to call this
|
||||
// frequently.
|
||||
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p)
|
||||
if err != nil {
|
||||
logf("error checking for certificate renewal: %v", err)
|
||||
} else if !shouldRenew {
|
||||
return p, nil
|
||||
}
|
||||
} else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -509,6 +530,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
|
|||
if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.domainRenewed(domain)
|
||||
|
||||
return &TLSCertKeyPair{CertPEM: certPEM.Bytes(), KeyPEM: privPEM.Bytes()}, nil
|
||||
}
|
||||
|
|
|
@ -12,6 +12,6 @@ type TLSCertKeyPair struct {
|
|||
CertPEM, KeyPEM []byte
|
||||
}
|
||||
|
||||
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) {
|
||||
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) {
|
||||
return nil, errors.New("not implemented for js/wasm")
|
||||
}
|
||||
|
|
|
@ -112,7 +112,7 @@ func TestShouldStartDomainRenewal(t *testing.T) {
|
|||
reset := func() {
|
||||
renewMu.Lock()
|
||||
defer renewMu.Unlock()
|
||||
maps.Clear(lastRenewCheck)
|
||||
maps.Clear(renewCertAt)
|
||||
}
|
||||
|
||||
mustMakePair := func(template *x509.Certificate) *TLSCertKeyPair {
|
||||
|
@ -178,7 +178,7 @@ func TestShouldStartDomainRenewal(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
reset()
|
||||
|
||||
ret, err := b.shouldStartDomainRenewalByExpiry(now, mustMakePair(&x509.Certificate{
|
||||
ret, err := b.domainRenewalTimeByExpiry(mustMakePair(&x509.Certificate{
|
||||
SerialNumber: big.NewInt(2019),
|
||||
Subject: subject,
|
||||
NotBefore: tt.notBefore,
|
||||
|
@ -192,8 +192,9 @@ func TestShouldStartDomainRenewal(t *testing.T) {
|
|||
t.Errorf("got err=%q, want %q", err.Error(), tt.wantErr)
|
||||
}
|
||||
} else {
|
||||
if ret != tt.want {
|
||||
t.Errorf("got ret=%v, want %v", ret, tt.want)
|
||||
renew := now.After(ret)
|
||||
if renew != tt.want {
|
||||
t.Errorf("got renew=%v (ret=%v), want renew %v", renew, ret, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -372,7 +372,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort)
|
|||
GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
pair, err := b.GetCertPEM(ctx, sni)
|
||||
pair, err := b.GetCertPEM(ctx, sni, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -675,7 +675,7 @@ func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHe
|
|||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
pair, err := b.GetCertPEM(ctx, hi.ServerName)
|
||||
pair, err := b.GetCertPEM(ctx, hi.ServerName, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
|
|||
http.Error(w, "internal handler config wired wrong", 500)
|
||||
return
|
||||
}
|
||||
pair, err := h.b.GetCertPEM(r.Context(), domain)
|
||||
pair, err := h.b.GetCertPEM(r.Context(), domain, true)
|
||||
if err != nil {
|
||||
// TODO(bradfitz): 500 is a little lazy here. The errors returned from
|
||||
// GetCertPEM (and everywhere) should carry info to get whether
|
||||
|
|
Loading…
Reference in New Issue