ssh/tailssh: cache public keys fetched from URLs
Updates #3802 Change-Id: I96715bae02bce6ea19f16b1736d1bbcd7bcf3534 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
3ffd88a84a
commit
93221b4535
|
@ -53,10 +53,21 @@ type server struct {
|
|||
logf logger.Logf
|
||||
tailscaledPath string
|
||||
|
||||
// mu protects activeSessions.
|
||||
pubKeyHTTPClient *http.Client // or nil for http.DefaultClient
|
||||
timeNow func() time.Time // or nil for time.Now
|
||||
|
||||
// mu protects the following
|
||||
mu sync.Mutex
|
||||
activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => that session
|
||||
activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session
|
||||
activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => session
|
||||
activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session
|
||||
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
|
||||
}
|
||||
|
||||
func (srv *server) now() time.Time {
|
||||
if srv.timeNow != nil {
|
||||
return srv.timeNow()
|
||||
}
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -264,7 +275,7 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr netaddr.
|
|||
return nil, nil, "", fmt.Errorf("unknown Tailscale identity from src %v", remoteAddr)
|
||||
}
|
||||
ci := &sshConnInfo{
|
||||
now: time.Now(),
|
||||
now: srv.now(),
|
||||
fetchPublicKeysURL: srv.fetchPublicKeysURL,
|
||||
sshUser: sshUser,
|
||||
src: remoteAddr,
|
||||
|
@ -280,11 +291,58 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr netaddr.
|
|||
return a, ci, localUser, nil
|
||||
}
|
||||
|
||||
// pubKeyCacheEntry is the cache value for an HTTPS URL of public keys (like
|
||||
// "https://github.com/foo.keys")
|
||||
type pubKeyCacheEntry struct {
|
||||
lines []string
|
||||
etag string // if sent by server
|
||||
at time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
pubKeyCacheDuration = time.Minute // how long to cache non-empty public keys
|
||||
pubKeyCacheEmptyDuration = 15 * time.Second // how long to cache empty responses
|
||||
)
|
||||
|
||||
func (srv *server) fetchPublicKeysURLCached(url string) (ce pubKeyCacheEntry, ok bool) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
// Mostly don't care about the size of this cache. Clean rarely.
|
||||
if m := srv.fetchPublicKeysCache; len(m) > 50 {
|
||||
tooOld := srv.now().Add(pubKeyCacheDuration * 10)
|
||||
for k, ce := range m {
|
||||
if ce.at.Before(tooOld) {
|
||||
delete(m, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
ce, ok = srv.fetchPublicKeysCache[url]
|
||||
if !ok {
|
||||
return ce, false
|
||||
}
|
||||
maxAge := pubKeyCacheDuration
|
||||
if len(ce.lines) == 0 {
|
||||
maxAge = pubKeyCacheEmptyDuration
|
||||
}
|
||||
return ce, srv.now().Sub(ce.at) < maxAge
|
||||
}
|
||||
|
||||
func (srv *server) pubKeyClient() *http.Client {
|
||||
if srv.pubKeyHTTPClient != nil {
|
||||
return srv.pubKeyHTTPClient
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
|
||||
if !strings.HasPrefix(url, "https://") {
|
||||
return nil, errors.New("invalid URL scheme")
|
||||
}
|
||||
// TODO(bradfitz): add caching
|
||||
|
||||
ce, ok := srv.fetchPublicKeysURLCached(url)
|
||||
if ok {
|
||||
return ce.lines, nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
@ -292,16 +350,40 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if ce.etag != "" {
|
||||
req.Header.Add("If-None-Match", ce.etag)
|
||||
}
|
||||
res, err := srv.pubKeyClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, errors.New(res.Status)
|
||||
var lines []string
|
||||
var etag string
|
||||
switch res.StatusCode {
|
||||
default:
|
||||
err = fmt.Errorf("unexpected status %v", res.Status)
|
||||
srv.logf("fetching public keys from %s: %v", url, err)
|
||||
case http.StatusNotModified:
|
||||
lines = ce.lines
|
||||
etag = ce.etag
|
||||
case http.StatusOK:
|
||||
var all []byte
|
||||
all, err = io.ReadAll(io.LimitReader(res.Body, 4<<10))
|
||||
if s := strings.TrimSpace(string(all)); s != "" {
|
||||
lines = strings.Split(s, "\n")
|
||||
}
|
||||
etag = res.Header.Get("Etag")
|
||||
}
|
||||
all, err := io.ReadAll(io.LimitReader(res.Body, 4<<10))
|
||||
return strings.Split(string(all), "\n"), err
|
||||
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
mapSet(&srv.fetchPublicKeysCache, url, pubKeyCacheEntry{
|
||||
at: srv.now(),
|
||||
lines: lines,
|
||||
etag: etag,
|
||||
})
|
||||
return lines, err
|
||||
}
|
||||
|
||||
// handleSSH is invoked when a new SSH connection attempt is made.
|
||||
|
@ -523,26 +605,20 @@ func (srv *server) getSessionForContext(sctx ssh.Context) (ss *sshSession, ok bo
|
|||
func (srv *server) startSession(ss *sshSession) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
if srv.activeSessionByH == nil {
|
||||
srv.activeSessionByH = make(map[string]*sshSession)
|
||||
}
|
||||
if srv.activeSessionBySharedID == nil {
|
||||
srv.activeSessionBySharedID = make(map[string]*sshSession)
|
||||
}
|
||||
if ss.idH == "" {
|
||||
panic("empty idH")
|
||||
}
|
||||
if _, dup := srv.activeSessionByH[ss.idH]; dup {
|
||||
panic("dup idH")
|
||||
}
|
||||
if ss.sharedID == "" {
|
||||
panic("empty sharedID")
|
||||
}
|
||||
if _, dup := srv.activeSessionByH[ss.idH]; dup {
|
||||
panic("dup idH")
|
||||
}
|
||||
if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup {
|
||||
panic("dup sharedID")
|
||||
}
|
||||
srv.activeSessionByH[ss.idH] = ss
|
||||
srv.activeSessionBySharedID[ss.sharedID] = ss
|
||||
mapSet(&srv.activeSessionByH, ss.idH, ss)
|
||||
mapSet(&srv.activeSessionBySharedID, ss.sharedID, ss)
|
||||
}
|
||||
|
||||
// endSession unregisters s from the list of active sessions.
|
||||
|
@ -1057,3 +1133,11 @@ func envEq(a, b string) bool {
|
|||
}
|
||||
return a == b
|
||||
}
|
||||
|
||||
// mapSet assigns m[k] = v, making m if necessary.
|
||||
func mapSet[K comparable, V any](m *map[K]V, k K, v V) {
|
||||
if *m == nil {
|
||||
*m = make(map[K]V)
|
||||
}
|
||||
(*m)[k] = v
|
||||
}
|
||||
|
|
|
@ -9,13 +9,19 @@ package tailssh
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -25,6 +31,7 @@ import (
|
|||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tempfork/gliderlabs/ssh"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/cibuild"
|
||||
"tailscale.com/util/lineread"
|
||||
|
@ -336,3 +343,63 @@ func parseEnv(out []byte) map[string]string {
|
|||
})
|
||||
return e
|
||||
}
|
||||
|
||||
func TestPublicKeyFetching(t *testing.T) {
|
||||
var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss int32
|
||||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32((&reqsTotal), 1)
|
||||
etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path)))
|
||||
w.Header().Set("Etag", etag)
|
||||
if v := r.Header.Get("If-None-Match"); v != "" {
|
||||
if v == etag {
|
||||
atomic.AddInt32(&reqsIfNoneMatchHit, 1)
|
||||
w.WriteHeader(304)
|
||||
return
|
||||
}
|
||||
atomic.AddInt32(&reqsIfNoneMatchMiss, 1)
|
||||
}
|
||||
io.WriteString(w, "foo\nbar\n"+string(r.URL.Path)+"\n")
|
||||
}))
|
||||
ts.StartTLS()
|
||||
defer ts.Close()
|
||||
keys := ts.URL
|
||||
|
||||
clock := &tstest.Clock{}
|
||||
srv := &server{
|
||||
pubKeyHTTPClient: ts.Client(),
|
||||
timeNow: clock.Now,
|
||||
}
|
||||
for i := 0; i < 2; i++ {
|
||||
got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
if got, want := atomic.LoadInt32(&reqsTotal), int32(1); got != want {
|
||||
t.Errorf("got %d requests; want %d", got, want)
|
||||
}
|
||||
if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(0); got != want {
|
||||
t.Errorf("got %d etag hits; want %d", got, want)
|
||||
}
|
||||
clock.Advance(5 * time.Minute)
|
||||
got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %q; want %q", got, want)
|
||||
}
|
||||
if got, want := atomic.LoadInt32(&reqsTotal), int32(2); got != want {
|
||||
t.Errorf("got %d requests; want %d", got, want)
|
||||
}
|
||||
if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(1); got != want {
|
||||
t.Errorf("got %d etag hits; want %d", got, want)
|
||||
}
|
||||
if got, want := atomic.LoadInt32(&reqsIfNoneMatchMiss), int32(0); got != want {
|
||||
t.Errorf("got %d etag misses; want %d", got, want)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue