diff --git a/prober/http.go b/prober/http.go
new file mode 100644
index 000000000..d79b2132e
--- /dev/null
+++ b/prober/http.go
@@ -0,0 +1,62 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package prober
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+)
+
+const maxHTTPBody = 4 << 20 // 4 MiB
+
+// HTTP returns a Probe that healthchecks an HTTP URL.
+//
+// The Probe sends a GET request for url, expects an HTTP 200
+// response, and verifies that wantText is present in the response
+// body, reading at most maxHTTPBody bytes of it. Certificate expiry
+// for HTTPS URLs is checked separately by the TLS probe.
+func HTTP(url, wantText string) Probe {
+	return func(ctx context.Context) error {
+		return probeHTTP(ctx, url, []byte(wantText))
+	}
+}
+
+func probeHTTP(ctx context.Context, url string, want []byte) error {
+	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+	if err != nil {
+		return fmt.Errorf("constructing request: %w", err)
+	}
+
+	// Get a completely new transport each time, so we don't reuse a
+	// past connection.
+	tr := http.DefaultTransport.(*http.Transport).Clone()
+	defer tr.CloseIdleConnections()
+	c := &http.Client{
+		Transport: tr,
+	}
+
+	resp, err := c.Do(req)
+	if err != nil {
+		return fmt.Errorf("fetching %q: %w", url, err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != 200 {
+		return fmt.Errorf("fetching %q: status code %d, want 200", url, resp.StatusCode)
+	}
+
+	bs, err := io.ReadAll(&io.LimitedReader{R: resp.Body, N: maxHTTPBody})
+	if err != nil {
+		return fmt.Errorf("reading body of %q: %w", url, err)
+	}
+
+	if !bytes.Contains(bs, want) {
+		return fmt.Errorf("body of %q does not contain %q", url, want)
+	}
+
+	return nil
+}
diff --git a/prober/prober.go b/prober/prober.go
new file mode 100644
index 000000000..d4ea1f651
--- /dev/null
+++ b/prober/prober.go
@@ -0,0 +1,235 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package prober implements a simple blackbox prober. Each probe runs
+// in its own goroutine, and run results are recorded as Prometheus
+// metrics.
+package prober
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"log"
+	"sync"
+	"time"
+
+	"tailscale.com/metrics"
+)
+
+// Probe is a function that probes something and reports whether the
+// probe succeeded. The provided context must be used to ensure timely
+// cancellation and timeout behavior.
+type Probe func(context.Context) error
+
+// A Prober manages a set of probes and keeps track of their results.
+type Prober struct {
+	// Time-related functions that get faked out during tests.
+	now       func() time.Time
+	newTicker func(time.Duration) ticker
+
+	// lastStart is the time, in seconds since epoch, of the last time
+	// each probe started a probe cycle.
+	lastStart metrics.LabelMap
+	// lastEnd is the time, in seconds since epoch, of the last time
+	// each probe finished a probe cycle.
+	lastEnd metrics.LabelMap
+	// lastResult records whether probes succeeded. A successful probe
+	// is recorded as 1, a failure as 0.
+	lastResult metrics.LabelMap
+	// lastLatency records how long the last probe cycle took for each
+	// probe, in milliseconds.
+	lastLatency metrics.LabelMap
+	// probeInterval records the time in seconds between successive
+	// runs of each probe.
+	//
+	// This is to help Prometheus figure out how long a probe should
+	// be failing before it fires an alert for it. To avoid random
+	// background noise, you want it to wait for more than one
+	// datapoint, but you also can't use a fixed interval because some
+	// probes might run every few seconds, while e.g. TLS certificate
+	// expiry might only run once a day.
+	//
+	// So, for each probe, the prober tells Prometheus how often it
+	// runs, so that the alert can autotune itself to eliminate noise
+	// without being excessively delayed.
+	probeInterval metrics.LabelMap
+
+	mu            sync.Mutex // protects all following fields
+	activeProbeCh map[string]chan struct{}
+}
+
+// New returns a new Prober.
+func New() *Prober {
+	return newForTest(time.Now, newRealTicker)
+}
+
+func newForTest(now func() time.Time, newTicker func(time.Duration) ticker) *Prober {
+	return &Prober{
+		now:           now,
+		newTicker:     newTicker,
+		lastStart:     metrics.LabelMap{Label: "probe"},
+		lastEnd:       metrics.LabelMap{Label: "probe"},
+		lastResult:    metrics.LabelMap{Label: "probe"},
+		lastLatency:   metrics.LabelMap{Label: "probe"},
+		probeInterval: metrics.LabelMap{Label: "probe"},
+		activeProbeCh: map[string]chan struct{}{},
+	}
+}
+
+// Expvar returns the metrics for running probes.
+func (p *Prober) Expvar() *metrics.Set {
+	ret := new(metrics.Set)
+	ret.Set("start_secs", &p.lastStart)
+	ret.Set("end_secs", &p.lastEnd)
+	ret.Set("result", &p.lastResult)
+	ret.Set("latency_millis", &p.lastLatency)
+	ret.Set("interval_secs", &p.probeInterval)
+	return ret
+}
+
+// Run executes fun every interval, and exports probe results under name.
+//
+// fun is given a context.Context that, if obeyed, ensures that fun
+// ends within interval. If fun disregards the context, it will not be
+// run again until it does finish, and metrics will reflect that the
+// probe function is stuck.
+//
+// Run returns a context.CancelFunc that stops the probe when
+// invoked. Probe shutdown and removal happens-before the CancelFunc
+// returns.
+//
+// Registering a probe under an already-registered name panics.
+func (p *Prober) Run(name string, interval time.Duration, fun Probe) context.CancelFunc {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	ticker := p.registerLocked(name, interval)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	go p.probeLoop(ctx, name, interval, ticker, fun)
+
+	return func() {
+		p.mu.Lock()
+		stopped := p.activeProbeCh[name]
+		p.mu.Unlock()
+		cancel()
+		<-stopped
+	}
+}
+
+// probeLoop invokes runProbe on fun every interval. The first probe
+// is run after interval.
+func (p *Prober) probeLoop(ctx context.Context, name string, interval time.Duration, tick ticker, fun Probe) {
+	defer func() {
+		p.unregister(name)
+		tick.Stop()
+	}()
+
+	for {
+		select {
+		case <-tick.Chan():
+			p.runProbe(ctx, name, interval, fun)
+		case <-ctx.Done():
+			return
+		}
+	}
+}
+
+// runProbe invokes fun and records the results.
+//
+// fun is invoked with a timeout slightly less than interval, so that
+// the probe either succeeds or fails before the next cycle is
+// scheduled to start.
+func (p *Prober) runProbe(ctx context.Context, name string, interval time.Duration, fun Probe) {
+	start := p.start(name)
+	defer func() {
+		// Prevent a panic within one probe function from killing the
+		// entire prober, so that a single buggy probe doesn't destroy
+		// our entire ability to monitor anything. A panic is recorded
+		// as a probe failure, so panicking probes will trigger an
+		// alert for debugging.
+		if r := recover(); r != nil {
+			log.Printf("probe %s panicked: %v", name, r)
+			p.end(name, start, errors.New("panic"))
+		}
+	}()
+	timeout := time.Duration(float64(interval) * 0.8)
+	ctx, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+
+	err := fun(ctx)
+	p.end(name, start, err)
+	if err != nil {
+		log.Printf("probe %s: %v", name, err)
+	}
+}
+
+func (p *Prober) registerLocked(name string, interval time.Duration) ticker {
+	if _, ok := p.activeProbeCh[name]; ok {
+		panic(fmt.Sprintf("probe named %q already registered", name))
+	}
+
+	stoppedCh := make(chan struct{})
+	p.activeProbeCh[name] = stoppedCh
+	p.probeInterval.Get(name).Set(int64(interval.Seconds()))
+	// Create and return a ticker from here, while the Prober is
+	// locked. This ensures that our fake time in tests always sees
+	// the new fake ticker being created before seeing that a new
+	// probe is registered.
+	return p.newTicker(interval)
+}
+
+func (p *Prober) unregister(name string) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	close(p.activeProbeCh[name])
+	delete(p.activeProbeCh, name)
+	p.lastStart.Delete(name)
+	p.lastEnd.Delete(name)
+	p.lastResult.Delete(name)
+	p.lastLatency.Delete(name)
+	p.probeInterval.Delete(name)
+}
+
+func (p *Prober) start(name string) time.Time {
+	st := p.now()
+	p.lastStart.Get(name).Set(st.Unix())
+	return st
+}
+
+func (p *Prober) end(name string, start time.Time, err error) {
+	end := p.now()
+	p.lastEnd.Get(name).Set(end.Unix())
+	p.lastLatency.Get(name).Set(end.Sub(start).Milliseconds())
+	v := int64(1)
+	if err != nil {
+		v = 0
+	}
+	p.lastResult.Get(name).Set(v)
+}
+
+// activeProbes reports the number of registered probes. For tests only.
+func (p *Prober) activeProbes() int {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	return len(p.activeProbeCh)
+}
+
+// ticker wraps a time.Ticker in a way that can be faked for tests.
+type ticker interface {
+	Chan() <-chan time.Time
+	Stop()
+}
+
+type realTicker struct {
+	*time.Ticker
+}
+
+func (t *realTicker) Chan() <-chan time.Time {
+	return t.Ticker.C
+}
+
+func newRealTicker(d time.Duration) ticker {
+	return &realTicker{time.NewTicker(d)}
+}
diff --git a/prober/prober_test.go b/prober/prober_test.go
new file mode 100644
index 000000000..76f0264ec
--- /dev/null
+++ b/prober/prober_test.go
@@ -0,0 +1,293 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package prober
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"tailscale.com/syncs"
+	"tailscale.com/tstest"
+)
+
+const (
+	probeInterval        = 10 * time.Second // So expvars that are integer numbers of seconds change
+	halfProbeInterval    = probeInterval / 2
+	quarterProbeInterval = probeInterval / 4
+	convergenceTimeout   = time.Second
+	convergenceSleep     = time.Millisecond
+)
+
+var epoch = time.Unix(0, 0)
+
+func TestProberTiming(t *testing.T) {
+	clk := newFakeTime()
+	p := newForTest(clk.Now, clk.NewTicker)
+
+	invoked := make(chan struct{}, 1)
+
+	notCalled := func() {
+		t.Helper()
+		select {
+		case <-invoked:
+			t.Fatal("probe was invoked earlier than expected")
+		default:
+		}
+	}
+	called := func() {
+		t.Helper()
+		select {
+		case <-invoked:
+		case <-time.After(2 * time.Second):
+			t.Fatal("probe wasn't invoked as expected")
+		}
+	}
+
+	p.Run("test-probe", probeInterval, func(context.Context) error {
+		invoked <- struct{}{}
+		return nil
+	})
+
+	waitActiveProbes(t, p, 1)
+
+	notCalled()
+	clk.Advance(probeInterval + halfProbeInterval)
+	called()
+	notCalled()
+	clk.Advance(quarterProbeInterval)
+	notCalled()
+	clk.Advance(probeInterval)
+	called()
+	notCalled()
+}
+
+func TestProberRun(t *testing.T) {
+	clk := newFakeTime()
+	p := newForTest(clk.Now, clk.NewTicker)
+
+	var (
+		mu  sync.Mutex
+		cnt int
+	)
+
+	const startingProbes = 100
+	cancels := []context.CancelFunc{}
+
+	for i := 0; i < startingProbes; i++ {
+		cancels = append(cancels, p.Run(fmt.Sprintf("probe%d", i), probeInterval, func(context.Context) error {
+			mu.Lock()
+			defer mu.Unlock()
+			cnt++
+			return nil
+		}))
+	}
+
+	checkCnt := func(want int) {
+		err := tstest.WaitFor(convergenceTimeout, func() error {
+			mu.Lock()
+			defer mu.Unlock()
+			if cnt == want {
+				cnt = 0
+				return nil
+			}
+			return fmt.Errorf("wrong number of probe counter increments, got %d want %d", cnt, want)
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	waitActiveProbes(t, p, startingProbes)
+	clk.Advance(probeInterval + halfProbeInterval)
+	checkCnt(startingProbes)
+
+	keep := startingProbes / 2
+
+	for i := keep; i < startingProbes; i++ {
+		cancels[i]()
+	}
+	waitActiveProbes(t, p, keep)
+
+	clk.Advance(probeInterval)
+	checkCnt(keep)
+}
+
+func TestExpvar(t *testing.T) {
+	clk := newFakeTime()
+	p := newForTest(clk.Now, clk.NewTicker)
+
+	const aFewMillis = 20 * time.Millisecond
+	var succeed syncs.AtomicBool
+	p.Run("probe", probeInterval, func(context.Context) error {
+		clk.Advance(aFewMillis)
+		if succeed.Get() {
+			return nil
+		}
+		return errors.New("failing, as instructed by test")
+	})
+
+	waitActiveProbes(t, p, 1)
+	clk.Advance(probeInterval + halfProbeInterval)
+
+	waitExpInt(t, p, "start_secs/probe", int((probeInterval + halfProbeInterval).Seconds()))
+	waitExpInt(t, p, "end_secs/probe", int((probeInterval + halfProbeInterval).Seconds()))
+	waitExpInt(t, p, "interval_secs/probe", int(probeInterval.Seconds()))
+	waitExpInt(t, p, "latency_millis/probe", int(aFewMillis.Milliseconds()))
+	waitExpInt(t, p, "result/probe", 0)
+
+	succeed.Set(true)
+	clk.Advance(probeInterval)
+
+	waitExpInt(t, p, "start_secs/probe", int((probeInterval + probeInterval + halfProbeInterval).Seconds()))
+	waitExpInt(t, p, "end_secs/probe", int((probeInterval + probeInterval + halfProbeInterval).Seconds()))
+	waitExpInt(t, p, "interval_secs/probe", int(probeInterval.Seconds()))
+	waitExpInt(t, p, "latency_millis/probe", int(aFewMillis.Milliseconds()))
+	waitExpInt(t, p, "result/probe", 1)
+}
+
+type fakeTicker struct {
+	ch       chan time.Time
+	interval time.Duration
+
+	sync.Mutex
+	next    time.Time
+	stopped bool
+}
+
+func (t *fakeTicker) Chan() <-chan time.Time {
+	return t.ch
+}
+
+func (t *fakeTicker) Stop() {
+	t.Lock()
+	defer t.Unlock()
+	t.stopped = true
+}
+
+func (t *fakeTicker) fire(now time.Time) {
+	t.Lock()
+	defer t.Unlock()
+	// Slight deviation from the stdlib ticker: time.Ticker will
+	// adjust t.next to make up for missed ticks, whereas we tick on a
+	// fixed interval regardless of receiver behavior. In our case
+	// this is fine, since we're using the ticker as a wakeup
+	// mechanism and not a precise timekeeping system.
+	select {
+	case t.ch <- now:
+	default:
+	}
+	t.next = now.Add(t.interval)
+}
+
+type fakeTime struct {
+	sync.Mutex
+	*sync.Cond
+	curTime time.Time
+	tickers []*fakeTicker
+}
+
+func newFakeTime() *fakeTime {
+	ret := &fakeTime{
+		curTime: epoch,
+	}
+	ret.Cond = &sync.Cond{L: &ret.Mutex}
+	ret.Advance(time.Duration(1)) // so that Now never IsZero
+	return ret
+}
+
+func (t *fakeTime) Now() time.Time {
+	t.Lock()
+	defer t.Unlock()
+	ret := t.curTime
+	// so that time always seems to advance for the program under test
+	t.curTime = t.curTime.Add(time.Microsecond)
+	return ret
+}
+
+func (t *fakeTime) NewTicker(d time.Duration) ticker {
+	t.Lock()
+	defer t.Unlock()
+	ret := &fakeTicker{
+		ch:       make(chan time.Time, 1),
+		interval: d,
+		next:     t.curTime.Add(d),
+	}
+	t.tickers = append(t.tickers, ret)
+	t.Cond.Broadcast()
+	return ret
+}
+
+func (t *fakeTime) Advance(d time.Duration) {
+	t.Lock()
+	defer t.Unlock()
+	t.curTime = t.curTime.Add(d)
+	for _, tick := range t.tickers {
+		if t.curTime.After(tick.next) {
+			tick.fire(t.curTime)
+		}
+	}
+}
+
+func waitExpInt(t *testing.T, p *Prober, path string, want int) {
+	t.Helper()
+	err := tstest.WaitFor(convergenceTimeout, func() error {
+		got, ok := getExpInt(t, p, path)
+		if !ok {
+			return fmt.Errorf("expvar %q did not get set", path)
+		}
+		if got != want {
+			return fmt.Errorf("expvar %q is %d, want %d", path, got, want)
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func getExpInt(t *testing.T, p *Prober, path string) (ret int, ok bool) {
+	t.Helper()
+	s := p.Expvar().String()
+	dec := map[string]interface{}{}
+	if err := json.Unmarshal([]byte(s), &dec); err != nil {
+		t.Fatalf("couldn't unmarshal expvar data: %v", err)
+	}
+	var v interface{} = dec
+	for _, d := range strings.Split(path, "/") {
+		m, ok := v.(map[string]interface{})
+		if !ok {
+			t.Fatalf("expvar path %q ended early with a leaf value", path)
+		}
+		child, ok := m[d]
+		if !ok {
+			return 0, false
+		}
+		v = child
+	}
+	f, ok := v.(float64)
+	if !ok {
+		return 0, false
+	}
+	return int(f), true
+}
+
+func waitActiveProbes(t *testing.T, p *Prober, want int) {
+	t.Helper()
+	err := tstest.WaitFor(convergenceTimeout, func() error {
+		if got := p.activeProbes(); got != want {
+			return fmt.Errorf("active probe count is %d, want %d", got, want)
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
diff --git a/prober/tcp.go b/prober/tcp.go
new file mode 100644
index 000000000..84ff76037
--- /dev/null
+++ b/prober/tcp.go
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package prober
+
+import (
+	"context"
+	"fmt"
+	"net"
+)
+
+// TCP returns a Probe that healthchecks a TCP endpoint.
+//
+// The Probe reports whether it can successfully connect to addr.
+func TCP(addr string) Probe {
+	return func(ctx context.Context) error {
+		return probeTCP(ctx, addr)
+	}
+}
+
+func probeTCP(ctx context.Context, addr string) error {
+	var d net.Dialer
+	conn, err := d.DialContext(ctx, "tcp", addr)
+	if err != nil {
+		return fmt.Errorf("dialing %q: %w", addr, err)
+	}
+	conn.Close()
+	return nil
+}
diff --git a/prober/tls.go b/prober/tls.go
new file mode 100644
index 000000000..285be6935
--- /dev/null
+++ b/prober/tls.go
@@ -0,0 +1,48 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package prober
+
+import (
+	"context"
+	"crypto/tls"
+	"fmt"
+	"net"
+	"time"
+)
+
+// TLS returns a Probe that healthchecks a TLS endpoint.
+//
+// The Probe connects to hostname on port 443, does a TLS handshake,
+// verifies that the hostname matches the presented certificate, and
+// that the certificate expires more than 7 days from the probe time.
+func TLS(hostname string) Probe {
+	return func(ctx context.Context) error {
+		return probeTLS(ctx, hostname)
+	}
+}
+
+func probeTLS(ctx context.Context, hostname string) error {
+	// Propagate the probe's deadline to the dialer so that a stuck
+	// dial or handshake cannot outlive the probe cycle.
+	var d net.Dialer
+	if deadline, ok := ctx.Deadline(); ok {
+		d.Deadline = deadline
+	}
+	conn, err := tls.DialWithDialer(&d, "tcp", hostname+":443", nil)
+	if err != nil {
+		return fmt.Errorf("connecting to %q: %w", hostname, err)
+	}
+	defer conn.Close()
+	if err := conn.VerifyHostname(hostname); err != nil {
+		return fmt.Errorf("TLS verification of %q failed: %w", hostname, err)
+	}
+
+	latestAllowedExpiration := time.Now().Add(7 * 24 * time.Hour) // 7 days from now
+	if expires := conn.ConnectionState().PeerCertificates[0].NotAfter; latestAllowedExpiration.After(expires) {
+		return fmt.Errorf("TLS certificate for %q expires in %v", hostname, time.Until(expires))
+	}
+
+	return nil
+}
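
For reviewers, a minimal sketch of how the new package is meant to be wired into a daemon. The probe names, targets, ports, and intervals below are made up for illustration, and it assumes *metrics.Set can be published directly as an expvar.Var (the tests' use of p.Expvar().String() suggests it implements String); this is not part of the change itself.

package main

import (
	"expvar"
	"log"
	"net/http"
	"time"

	"tailscale.com/prober"
)

func main() {
	p := prober.New()

	// Each Run returns a context.CancelFunc that stops and unregisters
	// the probe; keep it if a probe may need to be removed at runtime.
	// Here the probes run for the life of the process, so the
	// CancelFuncs are discarded. The first run of each probe happens
	// one interval after registration, and Run panics if a name is
	// registered twice.
	p.Run("home-page", 30*time.Second, prober.HTTP("https://example.com/", "Example Domain"))
	p.Run("ssh", 30*time.Second, prober.TCP("example.com:22"))
	p.Run("tls-cert", 12*time.Hour, prober.TLS("example.com"))

	// Expose start_secs, end_secs, result, latency_millis and
	// interval_secs under /debug/vars via the expvar handler.
	expvar.Publish("prober", p.Expvar())
	log.Fatal(http.ListenAndServe("localhost:8080", nil))
}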
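
A custom check only has to satisfy the Probe type. The sketch below, using a hypothetical database dependency, shows the expected shape: all blocking work should go through the provided context, since runProbe caps each run at 80% of the probe's interval.

package probes // hypothetical caller-side package, not part of this change

import (
	"context"
	"database/sql"

	"tailscale.com/prober"
)

// DB returns a Probe that reports whether the database answers a ping.
// Any func(context.Context) error satisfies prober.Probe; honoring ctx
// is what lets the prober time out a stuck check.
func DB(db *sql.DB) prober.Probe {
	return func(ctx context.Context) error {
		return db.PingContext(ctx)
	}
}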