diff --git a/cmd/tailscale/cli/cert.go b/cmd/tailscale/cli/cert.go index 305368892..87c9357d6 100644 --- a/cmd/tailscale/cli/cert.go +++ b/cmd/tailscale/cli/cert.go @@ -23,11 +23,11 @@ import ( "log" "os" "path/filepath" - "strconv" "strings" "golang.org/x/crypto/acme" "tailscale.com/client/tailscale" + "tailscale.com/ipn/ipnstate" ) func jout(v interface{}) { @@ -38,7 +38,45 @@ func jout(v interface{}) { fmt.Printf("%T: %s\n", v, j) } +func checkCertDomain(st *ipnstate.Status, domain string) error { + if domain == "" { + return errors.New("missing domain name") + } + for _, d := range st.CertDomains { + if d == domain { + return nil + } + } + // Transitional way while server doesn't yet populate CertDomains: also permit the client + // attempting Self.DNSName. + okay := st.CertDomains[:len(st.CertDomains):len(st.CertDomains)] + if st.Self != nil { + if v := strings.Trim(st.Self.DNSName, "."); v != "" { + if v == domain { + return nil + } + okay = append(okay, v) + } + } + switch len(okay) { + case 0: + return errors.New("your Tailscale account does not support getting TLS certs") + case 1: + return fmt.Errorf("invalid domain %q; only %q is permitted", domain, okay[0]) + default: + return fmt.Errorf("invalid domain %q; must be one of %q", domain, okay) + } +} + func debugGetCert(ctx context.Context, cert string) error { + st, err := tailscale.Status(ctx) + if err != nil { + return fmt.Errorf("getting tailscale status: %w", err) + } + if err := checkCertDomain(st, cert); err != nil { + return err + } + key, err := acmeKey() if err != nil { return err @@ -46,18 +84,31 @@ func debugGetCert(ctx context.Context, cert string) error { ac := &acme.Client{ Key: key, } - d, err := ac.Discover(ctx) - if err != nil { - return err - } - jout(d) - if reg, _ := strconv.ParseBool(os.Getenv("TS_DEBUG_ACME_REGISTER")); reg { - acct, err := ac.Register(ctx, new(acme.Account), acme.AcceptTOS) - if err != nil { - return fmt.Errorf("Register: %v", err) + logf := log.Printf + + a, err := ac.GetReg(ctx, "unused") + switch { + case err == nil: + // Great, already registered. + logf("Already had ACME account.") + case err == acme.ErrNoAccount: + a, err = ac.Register(ctx, new(acme.Account), acme.AcceptTOS) + if err == acme.ErrAccountAlreadyExists { + // Potential race. Double check. + a, err = ac.GetReg(ctx, "unused") } - jout(acct) + if err != nil { + return fmt.Errorf("acme.Register: %w", err) + } + logf("Registered ACME account.") + jout(a) + default: + return fmt.Errorf("acme.GetReg: %w", err) + + } + if a.Status != acme.StatusValid { + return fmt.Errorf("unexpected ACME account status %q", a.Status) } order, err := ac.AuthorizeOrder(ctx, []acme.AuthzID{{Type: "dns", Value: cert}})