ipn/localapi: refactor some cert code in prep for a move

I want to move the guts (after the HTTP layer) of the certificate
fetching into the ipnlocal package, out of localapi.

As prep, refactor a bit:

* add a method to do the fetch-from-cert-or-as-needed-with-refresh,
  rather than doing it in the HTTP hander

* convert two methods to funcs, taking the one extra field (LocalBackend)
  then needed from their method receiver. One of the methods needed
  nothing from its receiver.

This will make a future change easier to reason about.

Change-Id: I2a7811e5d7246139927bb86e7db8009bf09b3be3
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-11-07 20:49:46 -08:00 committed by Brad Fitzpatrick
parent 847a8cf917
commit 9be8d15979
1 changed files with 35 additions and 25 deletions

View File

@ -34,6 +34,7 @@ import (
"golang.org/x/crypto/acme"
"tailscale.com/envknob"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/ipn/ipnstate"
"tailscale.com/types/logger"
"tailscale.com/util/strs"
@ -79,13 +80,6 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
http.Error(w, "cert access denied", http.StatusForbidden)
return
}
dir, err := h.certDir()
if err != nil {
h.logf("certDir: %v", err)
http.Error(w, "failed to get cert dir", 500)
return
}
domain, ok := strs.CutPrefix(r.URL.Path, "/localapi/v0/cert/")
if !ok {
http.Error(w, "internal handler config wired wrong", 500)
@ -95,8 +89,24 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
http.Error(w, "invalid domain", 400)
return
}
now := time.Now()
pair, err := h.getCertPEM(r.Context(), domain)
if err != nil {
http.Error(w, fmt.Sprint(err), 500)
return
}
serveKeyPair(w, r, pair)
}
// getCertPEM gets the KeyPair for domain, either from cache, via the ACME
// process, or from cache and kicking off an async ACME renewal.
func (h *Handler) getCertPEM(ctx context.Context, domain string) (*keyPair, error) {
logf := logger.WithPrefix(h.logf, fmt.Sprintf("cert(%q): ", domain))
dir, err := h.certDir()
if err != nil {
logf("failed to get certDir: %v", err)
return nil, err
}
now := time.Now()
traceACME := func(v any) {
if !acmeDebug() {
return
@ -105,24 +115,22 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
log.Printf("acme %T: %s", v, j)
}
if pair, ok := h.getCertPEMCached(dir, domain, now); ok {
if pair, ok := getCertPEMCached(dir, domain, now); ok {
future := now.AddDate(0, 0, 14)
if h.shouldStartDomainRenewal(dir, domain, future) {
logf("starting async renewal")
// Start renewal in the background.
go h.getCertPEM(context.Background(), logf, traceACME, dir, domain, future)
go getCertPEM(context.Background(), h.b, logf, traceACME, dir, domain, future)
}
serveKeyPair(w, r, pair)
return
return pair, nil
}
pair, err := h.getCertPEM(r.Context(), logf, traceACME, dir, domain, now)
pair, err := getCertPEM(ctx, h.b, logf, traceACME, dir, domain, now)
if err != nil {
logf("getCertPEM: %v", err)
http.Error(w, fmt.Sprint(err), 500)
return
return nil, err
}
serveKeyPair(w, r, pair)
return pair, nil
}
func (h *Handler) shouldStartDomainRenewal(dir, domain string, future time.Time) bool {
@ -135,7 +143,7 @@ func (h *Handler) shouldStartDomainRenewal(dir, domain string, future time.Time)
return false
}
lastRenewCheck[domain] = now
_, ok := h.getCertPEMCached(dir, domain, future)
_, ok := getCertPEMCached(dir, domain, future)
return !ok
}
@ -154,10 +162,12 @@ func serveKeyPair(w http.ResponseWriter, r *http.Request, p *keyPair) {
}
}
// keyPair is a TLS public and private key, and whether they were obtained
// from cache or freshly obtained.
type keyPair struct {
certPEM []byte
keyPEM []byte
cached bool
certPEM []byte // public key, in PEM form
keyPEM []byte // private key, in PEM form
cached bool // whether result came from cache
}
func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") }
@ -166,7 +176,7 @@ func certFile(dir, domain string) string { return filepath.Join(dir, domain+".cr
// getCertPEMCached returns a non-nil keyPair and true if a cached
// keypair for domain exists on disk in dir that is valid at the
// provided now time.
func (h *Handler) getCertPEMCached(dir, domain string, now time.Time) (p *keyPair, ok bool) {
func getCertPEMCached(dir, domain string, now time.Time) (p *keyPair, ok bool) {
if !validLookingCertDomain(domain) {
// Before we read files from disk using it, validate it's halfway
// reasonable looking.
@ -181,11 +191,11 @@ func (h *Handler) getCertPEMCached(dir, domain string, now time.Time) (p *keyPai
return nil, false
}
func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME func(any), dir, domain string, now time.Time) (*keyPair, error) {
func getCertPEM(ctx context.Context, lb *ipnlocal.LocalBackend, logf logger.Logf, traceACME func(any), dir, domain string, now time.Time) (*keyPair, error) {
acmeMu.Lock()
defer acmeMu.Unlock()
if p, ok := h.getCertPEMCached(dir, domain, now); ok {
if p, ok := getCertPEMCached(dir, domain, now); ok {
return p, nil
}
@ -223,7 +233,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu
}
// Before hitting LetsEncrypt, see if this is a domain that Tailscale will do DNS challenges for.
st := h.b.StatusWithoutPeers()
st := lb.StatusWithoutPeers()
if err := checkCertDomain(st, domain); err != nil {
return nil, err
}
@ -260,7 +270,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu
}
if !ok {
logf("starting SetDNS call...")
err = h.b.SetDNS(ctx, key, rec)
err = lb.SetDNS(ctx, key, rec)
if err != nil {
return nil, fmt.Errorf("SetDNS %q => %q: %w", key, rec, err)
}