ipn/localapi: refactor some cert code in prep for a move
I want to move the guts (after the HTTP layer) of the certificate fetching into the ipnlocal package, out of localapi. As prep, refactor a bit: * add a method to do the fetch-from-cert-or-as-needed-with-refresh, rather than doing it in the HTTP hander * convert two methods to funcs, taking the one extra field (LocalBackend) then needed from their method receiver. One of the methods needed nothing from its receiver. This will make a future change easier to reason about. Change-Id: I2a7811e5d7246139927bb86e7db8009bf09b3be3 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
847a8cf917
commit
9be8d15979
|
@ -34,6 +34,7 @@ import (
|
|||
|
||||
"golang.org/x/crypto/acme"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/ipn/ipnlocal"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/strs"
|
||||
|
@ -79,13 +80,6 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
|
|||
http.Error(w, "cert access denied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
dir, err := h.certDir()
|
||||
if err != nil {
|
||||
h.logf("certDir: %v", err)
|
||||
http.Error(w, "failed to get cert dir", 500)
|
||||
return
|
||||
}
|
||||
|
||||
domain, ok := strs.CutPrefix(r.URL.Path, "/localapi/v0/cert/")
|
||||
if !ok {
|
||||
http.Error(w, "internal handler config wired wrong", 500)
|
||||
|
@ -95,8 +89,24 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
|
|||
http.Error(w, "invalid domain", 400)
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
pair, err := h.getCertPEM(r.Context(), domain)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprint(err), 500)
|
||||
return
|
||||
}
|
||||
serveKeyPair(w, r, pair)
|
||||
}
|
||||
|
||||
// getCertPEM gets the KeyPair for domain, either from cache, via the ACME
|
||||
// process, or from cache and kicking off an async ACME renewal.
|
||||
func (h *Handler) getCertPEM(ctx context.Context, domain string) (*keyPair, error) {
|
||||
logf := logger.WithPrefix(h.logf, fmt.Sprintf("cert(%q): ", domain))
|
||||
dir, err := h.certDir()
|
||||
if err != nil {
|
||||
logf("failed to get certDir: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now()
|
||||
traceACME := func(v any) {
|
||||
if !acmeDebug() {
|
||||
return
|
||||
|
@ -105,24 +115,22 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
|
|||
log.Printf("acme %T: %s", v, j)
|
||||
}
|
||||
|
||||
if pair, ok := h.getCertPEMCached(dir, domain, now); ok {
|
||||
if pair, ok := getCertPEMCached(dir, domain, now); ok {
|
||||
future := now.AddDate(0, 0, 14)
|
||||
if h.shouldStartDomainRenewal(dir, domain, future) {
|
||||
logf("starting async renewal")
|
||||
// Start renewal in the background.
|
||||
go h.getCertPEM(context.Background(), logf, traceACME, dir, domain, future)
|
||||
go getCertPEM(context.Background(), h.b, logf, traceACME, dir, domain, future)
|
||||
}
|
||||
serveKeyPair(w, r, pair)
|
||||
return
|
||||
return pair, nil
|
||||
}
|
||||
|
||||
pair, err := h.getCertPEM(r.Context(), logf, traceACME, dir, domain, now)
|
||||
pair, err := getCertPEM(ctx, h.b, logf, traceACME, dir, domain, now)
|
||||
if err != nil {
|
||||
logf("getCertPEM: %v", err)
|
||||
http.Error(w, fmt.Sprint(err), 500)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
serveKeyPair(w, r, pair)
|
||||
return pair, nil
|
||||
}
|
||||
|
||||
func (h *Handler) shouldStartDomainRenewal(dir, domain string, future time.Time) bool {
|
||||
|
@ -135,7 +143,7 @@ func (h *Handler) shouldStartDomainRenewal(dir, domain string, future time.Time)
|
|||
return false
|
||||
}
|
||||
lastRenewCheck[domain] = now
|
||||
_, ok := h.getCertPEMCached(dir, domain, future)
|
||||
_, ok := getCertPEMCached(dir, domain, future)
|
||||
return !ok
|
||||
}
|
||||
|
||||
|
@ -154,10 +162,12 @@ func serveKeyPair(w http.ResponseWriter, r *http.Request, p *keyPair) {
|
|||
}
|
||||
}
|
||||
|
||||
// keyPair is a TLS public and private key, and whether they were obtained
|
||||
// from cache or freshly obtained.
|
||||
type keyPair struct {
|
||||
certPEM []byte
|
||||
keyPEM []byte
|
||||
cached bool
|
||||
certPEM []byte // public key, in PEM form
|
||||
keyPEM []byte // private key, in PEM form
|
||||
cached bool // whether result came from cache
|
||||
}
|
||||
|
||||
func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") }
|
||||
|
@ -166,7 +176,7 @@ func certFile(dir, domain string) string { return filepath.Join(dir, domain+".cr
|
|||
// getCertPEMCached returns a non-nil keyPair and true if a cached
|
||||
// keypair for domain exists on disk in dir that is valid at the
|
||||
// provided now time.
|
||||
func (h *Handler) getCertPEMCached(dir, domain string, now time.Time) (p *keyPair, ok bool) {
|
||||
func getCertPEMCached(dir, domain string, now time.Time) (p *keyPair, ok bool) {
|
||||
if !validLookingCertDomain(domain) {
|
||||
// Before we read files from disk using it, validate it's halfway
|
||||
// reasonable looking.
|
||||
|
@ -181,11 +191,11 @@ func (h *Handler) getCertPEMCached(dir, domain string, now time.Time) (p *keyPai
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME func(any), dir, domain string, now time.Time) (*keyPair, error) {
|
||||
func getCertPEM(ctx context.Context, lb *ipnlocal.LocalBackend, logf logger.Logf, traceACME func(any), dir, domain string, now time.Time) (*keyPair, error) {
|
||||
acmeMu.Lock()
|
||||
defer acmeMu.Unlock()
|
||||
|
||||
if p, ok := h.getCertPEMCached(dir, domain, now); ok {
|
||||
if p, ok := getCertPEMCached(dir, domain, now); ok {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
|
@ -223,7 +233,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu
|
|||
}
|
||||
|
||||
// Before hitting LetsEncrypt, see if this is a domain that Tailscale will do DNS challenges for.
|
||||
st := h.b.StatusWithoutPeers()
|
||||
st := lb.StatusWithoutPeers()
|
||||
if err := checkCertDomain(st, domain); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -260,7 +270,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu
|
|||
}
|
||||
if !ok {
|
||||
logf("starting SetDNS call...")
|
||||
err = h.b.SetDNS(ctx, key, rec)
|
||||
err = lb.SetDNS(ctx, key, rec)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SetDNS %q => %q: %w", key, rec, err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue