ipn/ipnlocal: renew certificates based on lifetime
Instead of renewing certificates based on whether or not they're expired at a fixed 14-day period in the future, renew based on whether or not we're more than 2/3 of the way through the certificate's lifetime. This properly handles shorter-lived certificates without issue. Updates #8204 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: I5e82a9cadc427c010d04ce58c7f932e80dd571ea
This commit is contained in:
parent
d06fac0ede
commit
07eacdfe92
|
@ -101,11 +101,13 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
|
|||
}
|
||||
|
||||
if pair, err := getCertPEMCached(cs, domain, now); err == nil {
|
||||
future := now.AddDate(0, 0, 14)
|
||||
if b.shouldStartDomainRenewal(cs, domain, future) {
|
||||
shouldRenew, err := shouldStartDomainRenewal(domain, now, pair)
|
||||
if err != nil {
|
||||
logf("error checking for certificate renewal: %v", err)
|
||||
} else if shouldRenew {
|
||||
logf("starting async renewal")
|
||||
// Start renewal in the background.
|
||||
go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, future)
|
||||
go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now)
|
||||
}
|
||||
return pair, nil
|
||||
}
|
||||
|
@ -118,18 +120,41 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
|
|||
return pair, nil
|
||||
}
|
||||
|
||||
func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, future time.Time) bool {
|
||||
func shouldStartDomainRenewal(domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) {
|
||||
renewMu.Lock()
|
||||
defer renewMu.Unlock()
|
||||
now := time.Now()
|
||||
if last, ok := lastRenewCheck[domain]; ok && now.Sub(last) < time.Minute {
|
||||
// We checked very recently. Don't bother reparsing &
|
||||
// validating the x509 cert.
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
lastRenewCheck[domain] = now
|
||||
_, err := getCertPEMCached(cs, domain, future)
|
||||
return errors.Is(err, errCertExpired)
|
||||
|
||||
block, _ := pem.Decode(pair.CertPEM)
|
||||
if block == nil {
|
||||
return false, fmt.Errorf("parsing certificate PEM")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing certificate: %w", err)
|
||||
}
|
||||
|
||||
certLifetime := cert.NotAfter.Sub(cert.NotBefore)
|
||||
if certLifetime < 0 {
|
||||
return false, fmt.Errorf("negative certificate lifetime %v", certLifetime)
|
||||
}
|
||||
|
||||
// Per https://github.com/tailscale/tailscale/issues/8204, check
|
||||
// whether we're more than 2/3 of the way through the certificate's
|
||||
// lifetime, which is the officially-recommended best practice by Let's
|
||||
// Encrypt.
|
||||
renewalDuration := certLifetime * 2 / 3
|
||||
renewAt := cert.NotBefore.Add(renewalDuration)
|
||||
|
||||
if now.After(renewAt) {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// certStore provides a way to perist and retrieve TLS certificates.
|
||||
|
|
|
@ -6,12 +6,19 @@
|
|||
package ipnlocal
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"embed"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"golang.org/x/exp/maps"
|
||||
"tailscale.com/ipn/store/mem"
|
||||
)
|
||||
|
||||
|
@ -100,3 +107,94 @@ func TestCertStoreRoundTrip(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldStartDomainRenewal(t *testing.T) {
|
||||
reset := func() {
|
||||
renewMu.Lock()
|
||||
defer renewMu.Unlock()
|
||||
maps.Clear(lastRenewCheck)
|
||||
}
|
||||
|
||||
mustMakePair := func(template *x509.Certificate) *TLSCertKeyPair {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
b, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: b,
|
||||
})
|
||||
|
||||
return &TLSCertKeyPair{
|
||||
Cached: false,
|
||||
CertPEM: certPEM,
|
||||
KeyPEM: []byte("unused"),
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Unix(1685714838, 0)
|
||||
subject := pkix.Name{
|
||||
Organization: []string{"Tailscale, Inc."},
|
||||
Country: []string{"CA"},
|
||||
Province: []string{"ON"},
|
||||
Locality: []string{"Toronto"},
|
||||
StreetAddress: []string{"290 Bremner Blvd"},
|
||||
PostalCode: []string{"M5V 3L9"},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
notBefore time.Time
|
||||
lifetime time.Duration
|
||||
want bool
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "should renew",
|
||||
notBefore: now.AddDate(0, 0, -89),
|
||||
lifetime: 90 * 24 * time.Hour,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "short-lived renewal",
|
||||
notBefore: now.AddDate(0, 0, -7),
|
||||
lifetime: 10 * 24 * time.Hour,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no renew",
|
||||
notBefore: now.AddDate(0, 0, -59), // 59 days ago == not 2/3rds of the way through 90 days yet
|
||||
lifetime: 90 * 24 * time.Hour,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reset()
|
||||
|
||||
ret, err := shouldStartDomainRenewal("example.com", now, mustMakePair(&x509.Certificate{
|
||||
SerialNumber: big.NewInt(2019),
|
||||
Subject: subject,
|
||||
NotBefore: tt.notBefore,
|
||||
NotAfter: tt.notBefore.Add(tt.lifetime),
|
||||
}))
|
||||
|
||||
if tt.wantErr != "" {
|
||||
if err == nil {
|
||||
t.Errorf("wanted error, got nil")
|
||||
} else if err.Error() != tt.wantErr {
|
||||
t.Errorf("got err=%q, want %q", err.Error(), tt.wantErr)
|
||||
}
|
||||
} else {
|
||||
if ret != tt.want {
|
||||
t.Errorf("got ret=%v, want %v", ret, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue