diff --git a/ipn/local.go b/ipn/local.go index 40e9943e6..99878fdd4 100644 --- a/ipn/local.go +++ b/ipn/local.go @@ -5,6 +5,7 @@ package ipn import ( + "context" "errors" "fmt" "log" @@ -27,6 +28,8 @@ import ( // plane and the local network stack, wiring up NetworkMap updates // from the cloud to the local WireGuard engine. type LocalBackend struct { + ctx context.Context // valid until Close + ctxCancel context.CancelFunc // closes ctx logf logger.Logf e wgengine.Engine store StateStore @@ -66,12 +69,15 @@ func NewLocalBackend(logf logger.Logf, logid string, store StateStore, e wgengin // Default filter blocks everything, until Start() is called. e.SetFilter(filter.NewAllowNone()) + ctx, cancel := context.WithCancel(context.Background()) portpoll, err := portlist.NewPoller() if err != nil { logf("skipping portlist: %s\n", err) } b := &LocalBackend{ + ctx: ctx, + ctxCancel: cancel, logf: logf, e: e, store: store, @@ -84,7 +90,7 @@ func NewLocalBackend(logf logger.Logf, logid string, store StateStore, e wgengin e.SetNetInfoCallback(b.SetNetInfo) if b.portpoll != nil { - go b.portpoll.Run() + go b.portpoll.Run(ctx) go b.runPoller() } @@ -92,9 +98,7 @@ func NewLocalBackend(logf logger.Logf, logid string, store StateStore, e wgengin } func (b *LocalBackend) Shutdown() { - if b.portpoll != nil { - b.portpoll.Close() - } + b.ctxCancel() b.c.Shutdown() b.e.Close() b.e.Wait() @@ -313,9 +317,9 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap) { func (b *LocalBackend) runPoller() { for { - ports := <-b.portpoll.C - if ports == nil { - break + ports, ok := <-b.portpoll.C + if !ok { + return } sl := []tailcfg.Service{} for _, p := range ports { diff --git a/portlist/netstat_test.go b/portlist/netstat_test.go index e39909e7f..3125e360b 100644 --- a/portlist/netstat_test.go +++ b/portlist/netstat_test.go @@ -5,7 +5,6 @@ package portlist import ( - "fmt" "testing" ) @@ -62,28 +61,26 @@ udp4 0 0 *.5553 *.* ` func TestParsePortsNetstat(t *testing.T) { - expect := List{ + want := List{ Port{"tcp", 22, "", ""}, Port{"tcp", 23, "", ""}, Port{"tcp", 24, "", ""}, - Port{"tcp", 32, "", "sshd"}, - Port{"udp", 53, "", "chrome"}, - Port{"udp", 53, "", "funball"}, - Port{"udp", 5050, "", "CDPSvc"}, + Port{"tcp", 32, "sshd", ""}, + Port{"udp", 53, "chrome", ""}, + Port{"udp", 53, "funball", ""}, + Port{"udp", 5050, "CDPSvc", ""}, Port{"udp", 5353, "", ""}, Port{"udp", 5354, "", ""}, Port{"udp", 5453, "", ""}, Port{"udp", 5553, "", ""}, - Port{"udp", 9353, "", "iTunes"}, + Port{"udp", 9353, "iTunes", ""}, } pl := parsePortsNetstat(netstat_output) - fmt.Printf("--- expect:\n%v\n", expect) - fmt.Printf("--- got:\n%v\n", pl) for i := range pl { - if expect[i] != pl[i] { - t.Fatalf("row#%d\n expect=%v\n got=%v\n", - i, expect[i], pl[i]) + if pl[i] != want[i] { + t.Errorf("row#%d\n got: %#v\n\nwant: %#v\n", + i, pl[i], want[i]) } } } diff --git a/portlist/poller.go b/portlist/poller.go index d6fd7036e..5dd21476e 100644 --- a/portlist/poller.go +++ b/portlist/poller.go @@ -5,40 +5,67 @@ package portlist import ( + "context" "time" ) +// Poller scans the systems for listening ports periodically and sends +// the results to C. type Poller struct { - C chan List // new data when it arrives; closed when done + // C received the list of ports periodically. It's closed when + // Run completes, after which Err can be checked. + C <-chan List + + c chan List + + // Err is the error from the final GetList call. It is only + // valid to read once C has been closed. Err is nil if Close + // is called or the context is canceled. + Err error + quitCh chan struct{} // close this to force exit - Err error // last returned error code, if any prev List // most recent data } +// NewPoller returns a new portlist Poller. It returns an error +// if the portlist couldn't be obtained. Subsequent func NewPoller() (*Poller, error) { p := &Poller{ - C: make(chan List), + c: make(chan List), quitCh: make(chan struct{}), } - // Do one initial poll synchronously, so the caller can react - // to any obvious errors. - p.prev, p.Err = GetList(nil) - return p, p.Err + p.C = p.c + + // Do one initial poll synchronously so we can return an error + // early. + var err error + p.prev, err = GetList(nil) + if err != nil { + return nil, err + } + return p, nil } -func (p *Poller) Close() { +func (p *Poller) Close() error { + select { + case <-p.quitCh: + return nil + default: + } close(p.quitCh) <-p.C + return nil } -// Poll periodically. Run this in a goroutine if you want. -func (p *Poller) Run() error { - defer close(p.C) - tick := time.NewTicker(POLL_SECONDS * time.Second) +// Run runs the Poller periodically until either the context +// is done, or the Close is called. +func (p *Poller) Run(ctx context.Context) error { + defer close(p.c) + tick := time.NewTicker(pollInterval) defer tick.Stop() // Send out the pre-generated initial value - p.C <- p.prev + p.c <- p.prev for { select { @@ -46,12 +73,21 @@ func (p *Poller) Run() error { pl, err := GetList(p.prev) if err != nil { p.Err = err - return p.Err + return err } - if !pl.SameInodes(p.prev) { - p.prev = pl - p.C <- pl + if pl.SameInodes(p.prev) { + continue } + p.prev = pl + select { + case p.c <- pl: + case <-ctx.Done(): + return ctx.Err() + case <-p.quitCh: + return nil + } + case <-ctx.Done(): + return ctx.Err() case <-p.quitCh: return nil } diff --git a/portlist/portlist.go b/portlist/portlist.go index 99e04891a..883f373f7 100644 --- a/portlist/portlist.go +++ b/portlist/portlist.go @@ -9,13 +9,16 @@ import ( "strings" ) +// Port is a listening port on the machine. type Port struct { - Proto string - Port uint16 - inode string - Process string + Proto string // "tcp" or "udp" + Port uint16 // port number + Process string // optional process name, if found + + inode string // OS-specific; "socket:[165614651]" on Linux } +// List is a list of Ports. type List []Port var protos = []string{"tcp", "udp"} @@ -62,12 +65,12 @@ func (a List) SameInodes(b List) bool { } func (pl List) String() string { - out := []string{} + var sb strings.Builder for _, v := range pl { - out = append(out, fmt.Sprintf("%-3s %5d %-17s %#v", - v.Proto, v.Port, v.inode, v.Process)) + fmt.Fprintf(&sb, "%-3s %5d %-17s %#v\n", + v.Proto, v.Port, v.inode, v.Process) } - return strings.Join(out, "\n") + return strings.TrimRight(sb.String(), "\n") } func GetList(prev List) (List, error) { diff --git a/portlist/portlist_darwin.go b/portlist/portlist_darwin.go index 767f4f7e4..e1e385d0c 100644 --- a/portlist/portlist_darwin.go +++ b/portlist/portlist_darwin.go @@ -13,12 +13,13 @@ import ( "log" "os" "strings" + "time" exec "tailscale.com/tempfork/osexec" ) // We have to run netstat, which is a bit expensive, so don't do it too often. -const POLL_SECONDS = 5 +const pollInterval = 5 * time.Second func listPorts() (List, error) { return listPortsNetstat("-na") diff --git a/portlist/portlist_linux.go b/portlist/portlist_linux.go index 53e03ee56..959a65ed9 100644 --- a/portlist/portlist_linux.go +++ b/portlist/portlist_linux.go @@ -13,10 +13,13 @@ import ( "sort" "strconv" "strings" + "time" + + "golang.org/x/sys/unix" ) // Reading the sockfiles on Linux is very fast, so we can do it often. -const POLL_SECONDS = 1 +const pollInterval = 1 * time.Second // TODO(apenwarr): Include IPv6 ports eventually. // Right now we don't route IPv6 anyway so it's better to exclude them. @@ -82,24 +85,73 @@ func listPorts() (List, error) { } func addProcesses(pl []Port) ([]Port, error) { - pm := map[string]*Port{} - for k := range pl { - pm[pl[k].inode] = &pl[k] + pm := map[string]*Port{} // by Port.inode + for i := range pl { + pm[pl[i].inode] = &pl[i] } + err := foreachPID(func(pid string) error { + fdDir, err := os.Open(fmt.Sprintf("/proc/%s/fd", pid)) + if err != nil { + // Can't open fd list for this pid. Maybe + // don't have access. Ignore it. + return nil + } + defer fdDir.Close() + + targetBuf := make([]byte, 64) // plenty big for "socket:[165614651]" + for { + fds, err := fdDir.Readdirnames(100) + if err == io.EOF { + return nil + } + if err != nil { + return fmt.Errorf("readdir: %w", err) + } + for _, fd := range fds { + n, err := unix.Readlink(fmt.Sprintf("/proc/%s/fd/%s", pid, fd), targetBuf) + if err != nil { + // Not a symlink or no permission. + // Skip it. + continue + } + + // TODO(apenwarr): use /proc/*/cmdline instead of /comm? + // Unsure right now whether users will want the extra detail + // or not. + pe := pm[string(targetBuf[:n])] // m[string([]byte)] avoids alloc + if pe != nil { + comm, err := ioutil.ReadFile(fmt.Sprintf("/proc/%s/comm", pid)) + if err != nil { + // Usually shouldn't happen. One possibility is + // the process has gone away, so let's skip it. + continue + } + pe.Process = strings.TrimSpace(string(comm)) + } + } + } + }) + if err != nil { + return nil, err + } + return pl, nil +} + +func foreachPID(fn func(pidStr string) error) error { pdir, err := os.Open("/proc") if err != nil { - return nil, fmt.Errorf("/proc: %s", err) + return err } defer pdir.Close() for { pids, err := pdir.Readdirnames(100) if err == io.EOF { - break + return nil } if err != nil { - return nil, fmt.Errorf("/proc: %s", err) + return err } for _, pid := range pids { @@ -109,47 +161,9 @@ func addProcesses(pl []Port) ([]Port, error) { // /proc has lots of non-pid stuff in it. continue } - fddir, err := os.Open(fmt.Sprintf("/proc/%s/fd", pid)) - if err != nil { - // Can't open fd list for this pid. Maybe - // don't have access. Ignore it. - continue - } - defer fddir.Close() - - for { - fds, err := fddir.Readdirnames(100) - if err == io.EOF { - break - } - if err != nil { - return nil, fmt.Errorf("readdir: %s", err) - } - for _, fd := range fds { - target, err := os.Readlink(fmt.Sprintf("/proc/%s/fd/%s", pid, fd)) - if err != nil { - // Not a symlink or no permission. - // Skip it. - continue - } - - // TODO(apenwarr): use /proc/*/cmdline instead of /comm? - // Unsure right now whether users will want the extra detail - // or not. - pe := pm[target] - if pe != nil { - comm, err := ioutil.ReadFile(fmt.Sprintf("/proc/%s/comm", pid)) - if err != nil { - // Usually shouldn't happen. One possibility is - // the process has gone away, so let's skip it. - continue - } - pe.Process = strings.TrimSpace(string(comm)) - } - } + if err := fn(pid); err != nil { + return err } } } - - return pl, nil } diff --git a/portlist/portlist_other.go b/portlist/portlist_other.go index ffcbecb0b..2ac61f0df 100644 --- a/portlist/portlist_other.go +++ b/portlist/portlist_other.go @@ -6,8 +6,10 @@ package portlist +import "time" + // We have to run netstat, which is a bit expensive, so don't do it too often. -const POLL_SECONDS = 5 +const pollInterval = 5 * time.Second func listPorts() (List, error) { return listPortsNetstat("-na") diff --git a/portlist/portlist_test.go b/portlist/portlist_test.go new file mode 100644 index 000000000..4b95f8d6f --- /dev/null +++ b/portlist/portlist_test.go @@ -0,0 +1,28 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package portlist + +import "testing" + +func TestGetList(t *testing.T) { + pl, err := GetList(nil) + if err != nil { + t.Fatal(err) + } + for i, p := range pl { + t.Logf("[%d] %+v", i, p) + } + t.Logf("As String: %v", pl.String()) +} + +func BenchmarkGetList(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, err := GetList(nil) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/portlist/portlist_windows.go b/portlist/portlist_windows.go index 8e2885d25..2e77025c6 100644 --- a/portlist/portlist_windows.go +++ b/portlist/portlist_windows.go @@ -4,8 +4,10 @@ package portlist +import "time" + // Forking on Windows is insanely expensive, so don't do it too often. -const POLL_SECONDS = 5 +const pollInterval = 5 * time.Second func listPorts() (List, error) { return listPortsNetstat("-na")