390 lines
9.4 KiB
Go
390 lines
9.4 KiB
Go
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package router
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"sort"
|
|
"time"
|
|
"unsafe"
|
|
|
|
ole "github.com/go-ole/go-ole"
|
|
winipcfg "github.com/tailscale/winipcfg-go"
|
|
"github.com/tailscale/wireguard-go/conn"
|
|
"github.com/tailscale/wireguard-go/device"
|
|
"github.com/tailscale/wireguard-go/tun"
|
|
"golang.org/x/sys/windows"
|
|
"tailscale.com/wgengine/winnet"
|
|
)
|
|
|
|
// Numeric identifiers for the IP_UNICAST_IF / IPV6_UNICAST_IF socket
// options. Both share the same number; presumably the option level
// (IPv4 vs. IPv6) distinguishes them -- verify against the Windows SDK
// headers before relying on that.
const (
	sockoptIP_UNICAST_IF   = 31 // IPv4 variant
	sockoptIPV6_UNICAST_IF = 31 // IPv6 variant
)
|
|
|
|
// htonl returns val rearranged so that its in-memory byte layout is
// big-endian (network byte order). On a big-endian host the value is
// returned unchanged; on a little-endian host the bytes are swapped.
func htonl(val uint32) uint32 {
	// Serialize in big-endian order, then reinterpret those four bytes
	// as a native uint32.
	buf := make([]byte, 4)
	binary.BigEndian.PutUint32(buf, val)
	return *(*uint32)(unsafe.Pointer(&buf[0]))
}
|
|
|
|
func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLuid uint64, lastLuid *uint64) error {
|
|
routes, err := winipcfg.GetRoutes(family)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
lowestMetric := ^uint32(0)
|
|
index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want.
|
|
luid := uint64(0) // Hopefully luid zero is unspecified, but hard to find docs saying so.
|
|
for _, route := range routes {
|
|
if route.DestinationPrefix.PrefixLength != 0 || route.InterfaceLuid == ourLuid {
|
|
continue
|
|
}
|
|
if route.Metric < lowestMetric {
|
|
lowestMetric = route.Metric
|
|
index = route.InterfaceIndex
|
|
luid = route.InterfaceLuid
|
|
}
|
|
}
|
|
if luid == *lastLuid {
|
|
return nil
|
|
}
|
|
*lastLuid = luid
|
|
if false {
|
|
bind, ok := device.Bind().(conn.BindSocketToInterface)
|
|
if !ok {
|
|
return fmt.Errorf("unexpected device.Bind type %T", device.Bind())
|
|
}
|
|
// TODO(apenwarr): doesn't work with magic socket yet.
|
|
if family == winipcfg.AF_INET {
|
|
return bind.BindSocketToInterface4(index, false)
|
|
} else if family == winipcfg.AF_INET6 {
|
|
return bind.BindSocketToInterface6(index, false)
|
|
}
|
|
} else {
|
|
log.Printf("WARNING: skipping windows socket binding.\n")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) {
|
|
guid := tun.GUID()
|
|
ourLuid, err := winipcfg.InterfaceGuidToLuid(&guid)
|
|
lastLuid4 := uint64(0)
|
|
lastLuid6 := uint64(0)
|
|
lastMtu := uint32(0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
doIt := func() error {
|
|
err = bindSocketRoute(winipcfg.AF_INET, device, ourLuid, &lastLuid4)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = bindSocketRoute(winipcfg.AF_INET6, device, ourLuid, &lastLuid6)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !autoMTU {
|
|
return nil
|
|
}
|
|
mtu := uint32(0)
|
|
if lastLuid4 != 0 {
|
|
iface, err := winipcfg.InterfaceFromLUID(lastLuid4)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if iface.Mtu > 0 {
|
|
mtu = iface.Mtu
|
|
}
|
|
}
|
|
if lastLuid6 != 0 {
|
|
iface, err := winipcfg.InterfaceFromLUID(lastLuid6)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if iface.Mtu > 0 && iface.Mtu < mtu {
|
|
mtu = iface.Mtu
|
|
}
|
|
}
|
|
if mtu > 0 && (lastMtu == 0 || lastMtu != mtu) {
|
|
iface, err := winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
iface.NlMtu = mtu - 80
|
|
// If the TUN device was created with a smaller MTU,
|
|
// though, such as 1280, we don't want to go bigger than
|
|
// configured. (See the comment on minimalMTU in the
|
|
// wgengine package.)
|
|
if min, err := tun.MTU(); err == nil && min < int(iface.NlMtu) {
|
|
iface.NlMtu = uint32(min)
|
|
}
|
|
if iface.NlMtu < 576 {
|
|
iface.NlMtu = 576
|
|
}
|
|
err = iface.Set()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
tun.ForceMTU(int(iface.NlMtu)) //TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now.
|
|
iface, err = winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET6)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
iface.NlMtu = mtu - 80
|
|
if iface.NlMtu < 1280 {
|
|
iface.NlMtu = 1280
|
|
}
|
|
err = iface.Set()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
lastMtu = mtu
|
|
}
|
|
return nil
|
|
}
|
|
err = doIt()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) {
|
|
//fmt.Printf("MonitorDefaultRoutes: changed: %v\n", route.DestinationPrefix)
|
|
if route.DestinationPrefix.PrefixLength == 0 {
|
|
_ = doIt()
|
|
}
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return cb, nil
|
|
}
|
|
|
|
func setFirewall(ifcGUID *windows.GUID) (bool, error) {
|
|
c := ole.Connection{}
|
|
err := c.Initialize()
|
|
if err != nil {
|
|
return false, fmt.Errorf("c.Initialize: %v", err)
|
|
}
|
|
defer c.Uninitialize()
|
|
|
|
m, err := winnet.NewNetworkListManager(&c)
|
|
if err != nil {
|
|
return false, fmt.Errorf("winnet.NewNetworkListManager: %v", err)
|
|
}
|
|
defer m.Release()
|
|
|
|
cl, err := m.GetNetworkConnections()
|
|
if err != nil {
|
|
return false, fmt.Errorf("m.GetNetworkConnections: %v", err)
|
|
}
|
|
defer cl.Release()
|
|
|
|
for _, nco := range cl {
|
|
aid, err := nco.GetAdapterId()
|
|
if err != nil {
|
|
return false, fmt.Errorf("nco.GetAdapterId: %v", err)
|
|
}
|
|
if aid != ifcGUID.String() {
|
|
log.Printf("skipping adapter id: %v\n", aid)
|
|
continue
|
|
}
|
|
log.Printf("found! adapter id: %v\n", aid)
|
|
|
|
n, err := nco.GetNetwork()
|
|
if err != nil {
|
|
return false, fmt.Errorf("GetNetwork: %v", err)
|
|
}
|
|
defer n.Release()
|
|
|
|
cat, err := n.GetCategory()
|
|
if err != nil {
|
|
return false, fmt.Errorf("GetCategory: %v", err)
|
|
}
|
|
|
|
if cat == 0 {
|
|
err = n.SetCategory(1)
|
|
if err != nil {
|
|
return false, fmt.Errorf("SetCategory: %v", err)
|
|
}
|
|
} else {
|
|
log.Printf("setFirewall: already category %v\n", cat)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
func configureInterface(cfg *Config, tun *tun.NativeTun) error {
|
|
const mtu = 0
|
|
guid := tun.GUID()
|
|
log.Printf("wintun GUID is %v\n", guid)
|
|
iface, err := winipcfg.InterfaceFromGUID(&guid)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go func() {
|
|
// It takes a weirdly long time for Windows to notice the
|
|
// new interface has come up. Poll periodically until it
|
|
// does.
|
|
for i := 0; i < 20; i++ {
|
|
found, err := setFirewall(&guid)
|
|
if err != nil {
|
|
log.Printf("setFirewall: %v\n", err)
|
|
// fall through anyway, this isn't fatal.
|
|
}
|
|
if found {
|
|
break
|
|
}
|
|
time.Sleep(1 * time.Second)
|
|
}
|
|
}()
|
|
|
|
routes := []winipcfg.RouteData{}
|
|
var firstGateway4 *net.IP
|
|
var firstGateway6 *net.IP
|
|
addresses := make([]*net.IPNet, len(cfg.LocalAddrs))
|
|
for i, addr := range cfg.LocalAddrs {
|
|
ipnet := addr.IPNet()
|
|
addresses[i] = ipnet
|
|
gateway := ipnet.IP
|
|
if addr.IP.Is4() && firstGateway4 == nil {
|
|
firstGateway4 = &gateway
|
|
} else if addr.IP.Is6() && firstGateway6 == nil {
|
|
firstGateway6 = &gateway
|
|
}
|
|
}
|
|
|
|
foundDefault4 := false
|
|
foundDefault6 := false
|
|
for _, route := range cfg.Routes {
|
|
if (route.IP.Is4() && firstGateway4 == nil) || (route.IP.Is6() && firstGateway6 == nil) {
|
|
return errors.New("Due to a Windows limitation, one cannot have interface routes without an interface address")
|
|
}
|
|
|
|
ipn := route.IPNet()
|
|
var gateway net.IP
|
|
if route.IP.Is4() {
|
|
gateway = *firstGateway4
|
|
} else if route.IP.Is6() {
|
|
gateway = *firstGateway6
|
|
}
|
|
r := winipcfg.RouteData{
|
|
Destination: net.IPNet{
|
|
IP: ipn.IP.Mask(ipn.Mask),
|
|
Mask: ipn.Mask,
|
|
},
|
|
NextHop: gateway,
|
|
Metric: 0,
|
|
}
|
|
if bytes.Compare(r.Destination.IP, gateway) == 0 {
|
|
// no need to add a route for the interface's
|
|
// own IP. The kernel does that for us.
|
|
// If we try to replace it, we'll fail to
|
|
// add the route unless NextHop is set, but
|
|
// then the interface's IP won't be pingable.
|
|
continue
|
|
}
|
|
if route.IP.Is4() {
|
|
if route.Bits == 0 {
|
|
foundDefault4 = true
|
|
}
|
|
r.NextHop = *firstGateway4
|
|
} else if route.IP.Is6() {
|
|
if route.Bits == 0 {
|
|
foundDefault6 = true
|
|
}
|
|
r.NextHop = *firstGateway6
|
|
}
|
|
routes = append(routes, r)
|
|
}
|
|
|
|
err = iface.SyncAddresses(addresses)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sort.Slice(routes, func(i, j int) bool {
|
|
return (bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 ||
|
|
// Narrower masks first
|
|
bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == 1 ||
|
|
// No nexthop before non-empty nexthop
|
|
bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 ||
|
|
// Lower metrics first
|
|
routes[i].Metric < routes[j].Metric)
|
|
})
|
|
|
|
deduplicatedRoutes := []*winipcfg.RouteData{}
|
|
for i := 0; i < len(routes); i++ {
|
|
// There's only one way to get to a given IP+Mask, so delete
|
|
// all matches after the first.
|
|
if i > 0 &&
|
|
bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
|
|
bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
|
|
continue
|
|
}
|
|
deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
|
|
}
|
|
log.Printf("routes: %v\n", routes)
|
|
|
|
var errAcc error
|
|
err = iface.SyncRoutes(deduplicatedRoutes)
|
|
if err != nil && errAcc == nil {
|
|
log.Printf("setroutes: %v\n", err)
|
|
errAcc = err
|
|
}
|
|
|
|
ipif, err := iface.GetIpInterface(winipcfg.AF_INET)
|
|
if err != nil {
|
|
log.Printf("getipif: %v\n", err)
|
|
return err
|
|
}
|
|
log.Printf("foundDefault4: %v\n", foundDefault4)
|
|
if foundDefault4 {
|
|
ipif.UseAutomaticMetric = false
|
|
ipif.Metric = 0
|
|
}
|
|
if mtu > 0 {
|
|
ipif.NlMtu = uint32(mtu)
|
|
tun.ForceMTU(int(ipif.NlMtu))
|
|
}
|
|
err = ipif.Set()
|
|
if err != nil && errAcc == nil {
|
|
errAcc = err
|
|
}
|
|
|
|
ipif, err = iface.GetIpInterface(winipcfg.AF_INET6)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err != nil && errAcc == nil {
|
|
errAcc = err
|
|
}
|
|
if foundDefault6 {
|
|
ipif.UseAutomaticMetric = false
|
|
ipif.Metric = 0
|
|
}
|
|
if mtu > 0 {
|
|
ipif.NlMtu = uint32(mtu)
|
|
}
|
|
ipif.DadTransmits = 0
|
|
ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
|
|
err = ipif.Set()
|
|
if err != nil && errAcc == nil {
|
|
errAcc = err
|
|
}
|
|
|
|
return errAcc
|
|
}
|