wgengine/monitor: fix memory corruption in Windows implementation

I used the Windows APIs incorrectly previously, but it had worked just
well enough.

Updates #921

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2020-11-18 12:32:38 -08:00 committed by Brad Fitzpatrick
parent 8f76548fd9
commit eccc167733
1 changed file with 115 additions and 57 deletions

View File

@ -7,6 +7,8 @@ package monitor
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"runtime"
"sync" "sync"
"syscall" "syscall"
"time" "time"
@ -24,9 +26,15 @@ const (
) )
var ( var (
iphlpapi = syscall.NewLazyDLL("iphlpapi.dll") iphlpapi = syscall.NewLazyDLL("iphlpapi.dll")
notifyAddrChangeProc = iphlpapi.NewProc("NotifyAddrChange") notifyAddrChangeProc = iphlpapi.NewProc("NotifyAddrChange")
notifyRouteChangeProc = iphlpapi.NewProc("NotifyRouteChange") notifyRouteChangeProc = iphlpapi.NewProc("NotifyRouteChange")
cancelIPChangeNotifyProc = iphlpapi.NewProc("CancelIPChangeNotify")
)
const (
_STATUS_PENDING = 0x00000103 // 259
_STATUS_WAIT_0 = 0
) )
type unspecifiedMessage struct{} type unspecifiedMessage struct{}
@ -43,27 +51,33 @@ type messageOrError struct {
} }
type winMon struct { type winMon struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
messagec chan messageOrError messagec chan messageOrError
logf logger.Logf logf logger.Logf
pollTicker *time.Ticker pollTicker *time.Ticker
lastState *interfaces.State lastState *interfaces.State
closeHandle windows.Handle // signaled upon close
mu sync.Mutex mu sync.Mutex
event windows.Handle
lastNetChange time.Time lastNetChange time.Time
inFastPoll bool // recent net change event made us go into fast polling mode (to detect proxy changes) inFastPoll bool // recent net change event made us go into fast polling mode (to detect proxy changes)
} }
func newOSMon(logf logger.Logf) (osMon, error) { func newOSMon(logf logger.Logf) (osMon, error) {
closeHandle, err := windows.CreateEvent(nil, 1 /* manual reset */, 0 /* unsignaled */, nil /* no name */)
if err != nil {
return nil, fmt.Errorf("CreateEvent: %w", err)
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
m := &winMon{ m := &winMon{
logf: logf, logf: logf,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
messagec: make(chan messageOrError, 1), messagec: make(chan messageOrError, 1),
pollTicker: time.NewTicker(pollIntervalSlow), pollTicker: time.NewTicker(pollIntervalSlow),
closeHandle: closeHandle,
} }
go m.awaitIPAndRouteChanges() go m.awaitIPAndRouteChanges()
return m, nil return m, nil
@ -72,14 +86,7 @@ func newOSMon(logf logger.Logf) (osMon, error) {
func (m *winMon) Close() error { func (m *winMon) Close() error {
m.cancel() m.cancel()
m.pollTicker.Stop() m.pollTicker.Stop()
windows.SetEvent(m.closeHandle) // wakes up any reader blocked in Receive
m.mu.Lock()
defer m.mu.Unlock()
if h := m.event; h != 0 {
// Wake up any reader blocked in Receive.
windows.SetEvent(h)
}
return nil return nil
} }
@ -136,52 +143,80 @@ func (m *winMon) getIPOrRouteChangeMessage() (message, error) {
return nil, errClosed return nil, errClosed
} }
var o windows.Overlapped // TODO(bradfitz): locking ourselves to an OS thread here
h, err := windows.CreateEvent(nil, 1 /* true*/, 0 /* unsignaled */, nil /* no name */) // likely isn't necessary, but also can't really hurt.
if err != nil { // We'll be blocked in windows.WaitForMultipleObjects below
m.logf("CreateEvent: %v", err) // anyway, so might as well stay on this thread during the
return nil, err // notify calls and cancel funcs.
} // Given the past memory corruption from misuse of these APIs,
defer windows.CloseHandle(h) // and my continued lack of understanding of Windows APIs,
// I'll be paranoid. But perhaps we can remove this once
// we understand more.
runtime.LockOSThread()
defer runtime.UnlockOSThread()
m.mu.Lock() addrHandle, oaddr, cancel, err := notifyAddrChange()
m.event = h
m.mu.Unlock()
o.HEvent = h
err = notifyAddrChange(&h, &o)
if err != nil { if err != nil {
m.logf("notifyAddrChange: %v", err) m.logf("notifyAddrChange: %v", err)
return nil, err return nil, err
} }
err = notifyRouteChange(&h, &o) defer cancel()
routeHandle, oroute, cancel, err := notifyRouteChange()
if err != nil { if err != nil {
m.logf("notifyRouteChange: %v", err) m.logf("notifyRouteChange: %v", err)
return nil, err return nil, err
} }
defer cancel()
t0 := time.Now() t0 := time.Now()
_, err = windows.WaitForSingleObject(o.HEvent, windows.INFINITE) eventNum, err := windows.WaitForMultipleObjects([]windows.Handle{
if m.ctx.Err() != nil { m.closeHandle, // eventNum 0
addrHandle, // eventNum 1
routeHandle, // eventNum 2
}, false, windows.INFINITE)
if m.ctx.Err() != nil || (err == nil && eventNum == 0) {
return nil, errClosed return nil, errClosed
} }
if err != nil { if err != nil {
m.logf("waitForSingleObject: %v", err) m.logf("waitForMultipleObjects: %v", err)
return nil, err return nil, err
} }
d := time.Since(t0) d := time.Since(t0)
m.logf("got windows change event after %v", d) var eventStr string
// notifyAddrChange and notifyRouteChange both seem to return the same
// handle value. Determine which fired by looking at the "Internal" (sic)
// field of the Overlapped instead.
// TODO(bradfitz): maybe clean this up; see TODO in callNotifyProc.
if (eventNum == 1 || eventNum == 2) && addrHandle == routeHandle {
if oaddr.Internal == _STATUS_WAIT_0 && oroute.Internal == _STATUS_PENDING {
eventStr = "addr-o" // "-o" overlapped suffix to distinguish from "addr" below
} else if oroute.Internal == _STATUS_WAIT_0 && oaddr.Internal == _STATUS_PENDING {
eventStr = "route-o"
} else {
eventStr = fmt.Sprintf("[unexpected] addr.internal=%d; route.internal=%d", oaddr.Internal, oroute.Internal)
}
} else {
switch eventNum {
case 1:
eventStr = "addr"
case 2:
eventStr = "route"
default:
eventStr = fmt.Sprintf("%d [unexpected]", eventNum)
}
}
m.logf("got windows change event after %v: evt=%s", d, eventStr)
m.mu.Lock() m.mu.Lock()
{ {
m.lastNetChange = time.Now() m.lastNetChange = time.Now()
m.event = 0
// Something changed, so assume Windows is about to // Something changed, so assume Windows is about to
// discover its new proxy settings from WPAD, which // discover its new proxy settings from WPAD, which
// seems to take a bit. Poll heavily for awhile. // seems to take a bit. Poll heavily for awhile.
m.logf("starting quick poll, waiting for WPAD change")
m.inFastPoll = true m.inFastPoll = true
m.pollTicker.Reset(pollIntervalFast) m.pollTicker.Reset(pollIntervalFast)
} }
@ -190,23 +225,46 @@ func (m *winMon) getIPOrRouteChangeMessage() (message, error) {
return unspecifiedMessage{}, nil return unspecifiedMessage{}, nil
} }
func notifyAddrChange(h *windows.Handle, o *windows.Overlapped) error { func notifyAddrChange() (h windows.Handle, o *windows.Overlapped, cancel func(), err error) {
return callNotifyProc(notifyAddrChangeProc, h, o) return callNotifyProc(notifyAddrChangeProc)
} }
func notifyRouteChange(h *windows.Handle, o *windows.Overlapped) error { func notifyRouteChange() (h windows.Handle, o *windows.Overlapped, cancel func(), err error) {
return callNotifyProc(notifyRouteChangeProc, h, o) return callNotifyProc(notifyRouteChangeProc)
} }
func callNotifyProc(p *syscall.LazyProc, h *windows.Handle, o *windows.Overlapped) error { func callNotifyProc(p *syscall.LazyProc) (h windows.Handle, o *windows.Overlapped, cancel func(), err error) {
r1, _, e1 := p.Call(uintptr(unsafe.Pointer(h)), uintptr(unsafe.Pointer(o))) o = new(windows.Overlapped)
expect := uintptr(0)
if h != nil || o != nil { // TODO(bradfitz): understand why this if-false code doesn't
const ERROR_IO_PENDING = 997 // work, even though the docs online suggest we should pass an
expect = ERROR_IO_PENDING // event in the overlapped.Hevent field.
// The docs at
// https://docs.microsoft.com/en-us/windows/win32/api/minwinbase/ns-minwinbase-overlapped
// say that o.HEvent can be zero, though, which seems to work.
// Note that the returned windows.Handle has the same value for both
// notifyAddrChange and notifyRouteChange, which is why our caller needs
// to look at the returned Overlapped's Internal field to see which case
// fired. That's also worth understanding more.
// See crawshaw's comment at https://github.com/tailscale/tailscale/pull/944#discussion_r526469186
// too.
if false {
evt, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return 0, nil, nil, err
}
o.HEvent = evt
} }
if r1 == expect {
return nil r1, _, e1 := syscall.Syscall(p.Addr(), 2, uintptr(unsafe.Pointer(&h)), uintptr(unsafe.Pointer(o)), 0)
// We expect ERROR_IO_PENDING.
if syscall.Errno(r1) != windows.ERROR_IO_PENDING {
return 0, nil, nil, e1
} }
return e1
cancel = func() {
cancelIPChangeNotifyProc.Call(uintptr(unsafe.Pointer(o)))
}
return h, o, cancel, nil
} }