diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index 047fefb2b..6b138414c 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -150,7 +150,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "server has no local backend", http.StatusInternalServerError) return } - if r.Referer() != "" || r.Header.Get("Origin") != "" || !validHost(r.Host) { + if r.Referer() != "" || r.Header.Get("Origin") != "" || !h.validHost(r.Host) { metricInvalidRequests.Add(1) http.Error(w, "invalid localapi request", http.StatusForbidden) return @@ -180,21 +180,20 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -// validLocalHost allows either localhost or loopback IP hosts on platforms -// that use token security. -var validLocalHost = runtime.GOOS == "darwin" || runtime.GOOS == "ios" || runtime.GOOS == "android" +// validLocalHostForTesting allows loopback handlers without RequiredPassword for testing. +var validLocalHostForTesting = false // validHost reports whether h is a valid Host header value for a LocalAPI request. -func validHost(h string) bool { +func (h *Handler) validHost(hostname string) bool { // The client code sends a hostname of "local-tailscaled.sock". - switch h { + switch hostname { case "", apitype.LocalAPIHost: return true } - if !validLocalHost { - return false + if !validLocalHostForTesting && h.RequiredPassword == "" { + return false // only allow localhost with basic auth or in tests } - host, _, err := net.SplitHostPort(h) + host, _, err := net.SplitHostPort(hostname) if err != nil { return false } diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index 1fdc3874d..8d7d317b0 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -23,9 +23,9 @@ func TestValidHost(t *testing.T) { }{ {"", true}, {apitype.LocalAPIHost, true}, - {"localhost:9109", validLocalHost}, - {"127.0.0.1:9110", validLocalHost}, - {"[::1]:9111", validLocalHost}, + {"localhost:9109", false}, + {"127.0.0.1:9110", false}, + {"[::1]:9111", false}, {"100.100.100.100:41112", false}, {"10.0.0.1:41112", false}, {"37.16.9.210:41112", false}, @@ -33,7 +33,8 @@ func TestValidHost(t *testing.T) { for _, test := range tests { t.Run(test.host, func(t *testing.T) { - if got := validHost(test.host); got != test.valid { + h := &Handler{} + if got := h.validHost(test.host); got != test.valid { t.Errorf("validHost(%q)=%v, want %v", test.host, got, test.valid) } }) @@ -41,10 +42,9 @@ func TestValidHost(t *testing.T) { } func TestSetPushDeviceToken(t *testing.T) { - origValidLocalHost := validLocalHost - validLocalHost = true + validLocalHostForTesting = true defer func() { - validLocalHost = origValidLocalHost + validLocalHostForTesting = false }() h := &Handler{ diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 0c33ee1c3..92e595d6c 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -8,6 +8,8 @@ package tsnet import ( "context" + crand "crypto/rand" + "encoding/hex" "errors" "fmt" "io" @@ -88,19 +90,22 @@ type Server struct { // If empty, the Tailscale default is used. ControlURL string - initOnce sync.Once - initErr error - lb *ipnlocal.LocalBackend - netstack *netstack.Impl - linkMon *monitor.Mon - localAPIListener net.Listener - rootPath string // the state directory - hostname string - shutdownCtx context.Context - shutdownCancel context.CancelFunc - localClient *tailscale.LocalClient - logbuffer *filch.Filch - logtail *logtail.Logger + initOnce sync.Once + initErr error + lb *ipnlocal.LocalBackend + netstack *netstack.Impl + linkMon *monitor.Mon + rootPath string // the state directory + hostname string + shutdownCtx context.Context + shutdownCancel context.CancelFunc + localAPICred string // basic auth password for localAPITCPListener + localAPITCPListener net.Listener // optional loopback, restricted to PID + localAPIListener net.Listener // in-memory, used by localClient + localClient *tailscale.LocalClient // in-memory + logbuffer *filch.Filch + logtail *logtail.Logger + logid string mu sync.Mutex listeners map[listenKey]*listener @@ -139,6 +144,64 @@ func (s *Server) LocalClient() (*tailscale.LocalClient, error) { return s.localClient, nil } +// LoopbackLocalAPI returns a loopback ip:port listening for the "LocalAPI". +// +// As the LocalAPI is powerful, access to endpoints requires BOTH passing a +// "Sec-Tailscale: localapi" HTTP header and passing cred as a basic auth. +// +// It will start the server and the local client listener if they have not +// been started yet. +// +// If you only need to use the LocalAPI from Go, then prefer LocalClient +// as it does not require communication via TCP. +func (s *Server) LoopbackLocalAPI() (addr string, cred string, err error) { + if err := s.Start(); err != nil { + return "", "", err + } + + if s.localAPITCPListener == nil { + var cred [16]byte + if _, err := crand.Read(cred[:]); err != nil { + return "", "", err + } + s.localAPICred = hex.EncodeToString(cred[:]) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", "", err + } + s.localAPITCPListener = ln + + go func() { + lah := localapi.NewHandler(s.lb, s.logf, s.logid) + lah.PermitWrite = true + lah.PermitRead = true + lah.RequiredPassword = s.localAPICred + h := &localSecHandler{h: lah, cred: s.localAPICred} + + if err := http.Serve(s.localAPITCPListener, h); err != nil { + s.logf("localapi tcp serve error: %v", err) + } + }() + } + + return s.localAPITCPListener.Addr().String(), s.localAPICred, nil +} + +type localSecHandler struct { + h http.Handler + cred string +} + +func (h *localSecHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Sec-Tailscale") != "localapi" { + w.WriteHeader(403) + io.WriteString(w, "missing 'Sec-Tailscale: localapi' header") + return + } + h.h.ServeHTTP(w, r) +} + // Start connects the server to the tailnet. // Optional: any calls to Dial/Listen will also call Start. func (s *Server) Start() error { @@ -240,6 +303,9 @@ func (s *Server) Close() error { if s.localAPIListener != nil { s.localAPIListener.Close() } + if s.localAPITCPListener != nil { + s.localAPITCPListener.Close() + } s.mu.Lock() defer s.mu.Unlock() @@ -325,7 +391,7 @@ func (s *Server) start() (reterr error) { if err := lpc.Validate(logtail.CollectionNode); err != nil { return fmt.Errorf("logpolicy.Config.Validate for %v: %w", cfgPath, err) } - logid := lpc.PublicID.String() + s.logid = lpc.PublicID.String() s.logbuffer, err = filch.New(filepath.Join(s.rootPath, "tailscaled"), filch.Options{ReplaceStderr: false}) if err != nil { @@ -399,7 +465,7 @@ func (s *Server) start() (reterr error) { if s.Ephemeral { loginFlags = controlclient.LoginEphemeral } - lb, err := ipnlocal.NewLocalBackend(logf, logid, s.Store, s.dialer, eng, loginFlags) + lb, err := ipnlocal.NewLocalBackend(logf, s.logid, s.Store, s.dialer, eng, loginFlags) if err != nil { return fmt.Errorf("NewLocalBackend: %v", err) } @@ -435,7 +501,7 @@ func (s *Server) start() (reterr error) { go s.printAuthURLLoop() // Run the localapi handler, to allow fetching LetsEncrypt certs. - lah := localapi.NewHandler(lb, logf, logid) + lah := localapi.NewHandler(lb, logf, s.logid) lah.PermitWrite = true lah.PermitRead = true diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 9d0fa8700..3847b4358 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -9,16 +9,17 @@ import ( "flag" "fmt" "io" - "path/filepath" - "os" + "net/http" "net/http/httptest" + "os" + "path/filepath" "testing" "time" "tailscale.com/ipn/store/mem" + "tailscale.com/net/netns" "tailscale.com/tailcfg" "tailscale.com/tstest/integration" - "tailscale.com/net/netns" "tailscale.com/tstest/integration/testcontrol" "tailscale.com/types/logger" ) @@ -63,7 +64,7 @@ func TestListenerPort(t *testing.T) { var verboseDERP = flag.Bool("verbose-derp", false, "if set, print DERP and STUN logs") var verboseNodes = flag.Bool("verbose-nodes", false, "if set, print tsnet.Server logs") -func TestConn(t *testing.T) { +func startControl(t *testing.T) (controlURL string) { // Corp#4520: don't use netns for tests. netns.SetEnabled(false) t.Cleanup(func() { @@ -81,14 +82,19 @@ func TestConn(t *testing.T) { control.HTTPTestServer = httptest.NewUnstartedServer(control) control.HTTPTestServer.Start() t.Cleanup(control.HTTPTestServer.Close) - controlURL := control.HTTPTestServer.URL + controlURL = control.HTTPTestServer.URL t.Logf("testcontrol listening on %s", controlURL) + return controlURL +} + +func TestConn(t *testing.T) { + controlURL := startControl(t) tmp := t.TempDir() tmps1 := filepath.Join(tmp, "s1") os.MkdirAll(tmps1, 0755) s1 := &Server{ - Dir: tmps1, + Dir: tmps1, ControlURL: controlURL, Hostname: "s1", Store: new(mem.Store), @@ -99,7 +105,7 @@ func TestConn(t *testing.T) { tmps2 := filepath.Join(tmp, "s1") os.MkdirAll(tmps2, 0755) s2 := &Server{ - Dir: tmps2, + Dir: tmps2, ControlURL: controlURL, Hostname: "s2", Store: new(mem.Store), @@ -167,3 +173,88 @@ func TestConn(t *testing.T) { t.Errorf("got %q, want %q", got, want) } } + +func TestLoopbackLocalAPI(t *testing.T) { + controlURL := startControl(t) + + tmp := t.TempDir() + tmps1 := filepath.Join(tmp, "s1") + os.MkdirAll(tmps1, 0755) + s1 := &Server{ + Dir: tmps1, + ControlURL: controlURL, + Hostname: "s1", + Store: new(mem.Store), + Ephemeral: true, + } + defer s1.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if _, err := s1.Up(ctx); err != nil { + t.Fatal(err) + } + + addr, cred, err := s1.LoopbackLocalAPI() + if err != nil { + t.Fatal(err) + } + + url := "http://" + addr + "/localapi/v0/status" + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + t.Fatal(err) + } + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != 403 { + t.Errorf("GET %s returned %d, want 403 without Sec- header", url, res.StatusCode) + } + + req, err = http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Sec-Tailscale", "localapi") + res, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != 401 { + t.Errorf("GET %s returned %d, want 401 without basic auth", url, res.StatusCode) + } + + req, err = http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + t.Fatal(err) + } + req.SetBasicAuth("", cred) + res, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != 403 { + t.Errorf("GET %s returned %d, want 403 without Sec- header", url, res.StatusCode) + } + + req, err = http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Sec-Tailscale", "localapi") + req.SetBasicAuth("", cred) + res, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != 200 { + t.Errorf("GET /status returned %d, want 200", res.StatusCode) + } +}