diff --git a/prober/tls.go b/prober/tls.go index 97c53298e..db25c9fd5 100644 --- a/prober/tls.go +++ b/prober/tls.go @@ -12,6 +12,7 @@ import ( "io" "net" "net/http" + "net/netip" "time" "github.com/pkg/errors" @@ -23,25 +24,33 @@ const expiresSoon = 7 * 24 * time.Hour // 7 days from now // TLS returns a Probe that healthchecks a TLS endpoint. // -// The ProbeFunc connects to a hostname (host:port string), does a TLS +// The ProbeFunc connects to a hostPort (host:port string), does a TLS // handshake, verifies that the hostname matches the presented certificate, // checks certificate validity time and OCSP revocation status. -func TLS(hostname string) ProbeFunc { +func TLS(hostPort string) ProbeFunc { return func(ctx context.Context) error { - return probeTLS(ctx, hostname) + certDomain, _, err := net.SplitHostPort(hostPort) + if err != nil { + return err + } + return probeTLS(ctx, certDomain, hostPort) } } -func probeTLS(ctx context.Context, hostname string) error { - host, _, err := net.SplitHostPort(hostname) - if err != nil { - return err +// TLSWithIP is like TLS, but dials the provided dialAddr instead +// of using DNS resolution. The certDomain is the expected name in +// the cert (and the SNI name to send). +func TLSWithIP(certDomain string, dialAddr netip.AddrPort) ProbeFunc { + return func(ctx context.Context) error { + return probeTLS(ctx, certDomain, dialAddr.String()) } +} - dialer := &tls.Dialer{Config: &tls.Config{ServerName: host}} - conn, err := dialer.DialContext(ctx, "tcp", hostname) +func probeTLS(ctx context.Context, certDomain string, dialHostPort string) error { + dialer := &tls.Dialer{Config: &tls.Config{ServerName: certDomain}} + conn, err := dialer.DialContext(ctx, "tcp", dialHostPort) if err != nil { - return fmt.Errorf("connecting to %q: %w", hostname, err) + return fmt.Errorf("connecting to %q: %w", dialHostPort, err) } defer conn.Close() diff --git a/prober/tls_test.go b/prober/tls_test.go index 6c6601ee0..5bfb739db 100644 --- a/prober/tls_test.go +++ b/prober/tls_test.go @@ -85,7 +85,7 @@ func TestTLSConnection(t *testing.T) { srv.StartTLS() defer srv.Close() - err = probeTLS(context.Background(), srv.Listener.Addr().String()) + err = probeTLS(context.Background(), "fail.example.com", srv.Listener.Addr().String()) // The specific error message here is platform-specific ("certificate is not trusted" // on macOS and "certificate signed by unknown authority" on Linux), so only check // that it contains the word 'certificate'.