243 lines
5.5 KiB
Go
243 lines
5.5 KiB
Go
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package wgcfg
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"bytes"
|
||
|
"io"
|
||
|
"os"
|
||
|
"sort"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/tailscale/wireguard-go/device"
|
||
|
"github.com/tailscale/wireguard-go/tun"
|
||
|
"inet.af/netaddr"
|
||
|
"tailscale.com/types/wgkey"
|
||
|
)
|
||
|
|
||
|
func TestDeviceConfig(t *testing.T) {
|
||
|
newPrivateKey := func() (Key, PrivateKey) {
|
||
|
t.Helper()
|
||
|
pk, err := wgkey.NewPrivate()
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
return Key(pk.Public()), PrivateKey(pk)
|
||
|
}
|
||
|
k1, pk1 := newPrivateKey()
|
||
|
ip1 := netaddr.MustParseIPPrefix("10.0.0.1/32")
|
||
|
|
||
|
k2, pk2 := newPrivateKey()
|
||
|
ip2 := netaddr.MustParseIPPrefix("10.0.0.2/32")
|
||
|
|
||
|
k3, _ := newPrivateKey()
|
||
|
ip3 := netaddr.MustParseIPPrefix("10.0.0.3/32")
|
||
|
|
||
|
cfg1 := &Config{
|
||
|
PrivateKey: PrivateKey(pk1),
|
||
|
Peers: []Peer{{
|
||
|
PublicKey: k2,
|
||
|
AllowedIPs: []netaddr.IPPrefix{ip2},
|
||
|
}},
|
||
|
}
|
||
|
|
||
|
cfg2 := &Config{
|
||
|
PrivateKey: PrivateKey(pk2),
|
||
|
Peers: []Peer{{
|
||
|
PublicKey: k1,
|
||
|
AllowedIPs: []netaddr.IPPrefix{ip1},
|
||
|
PersistentKeepalive: 5,
|
||
|
}},
|
||
|
}
|
||
|
|
||
|
device1 := device.NewDevice(newNilTun(), &device.DeviceOptions{
|
||
|
Logger: device.NewLogger(device.LogLevelError, "device1"),
|
||
|
})
|
||
|
device2 := device.NewDevice(newNilTun(), &device.DeviceOptions{
|
||
|
Logger: device.NewLogger(device.LogLevelError, "device2"),
|
||
|
})
|
||
|
defer device1.Close()
|
||
|
defer device2.Close()
|
||
|
|
||
|
cmp := func(t *testing.T, d *device.Device, want *Config) {
|
||
|
t.Helper()
|
||
|
got, err := DeviceConfig(d)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
prev := new(Config)
|
||
|
gotbuf := new(strings.Builder)
|
||
|
err = got.ToUAPI(gotbuf, prev)
|
||
|
gotStr := gotbuf.String()
|
||
|
if err != nil {
|
||
|
t.Errorf("got.ToUAPI(): error: %v", err)
|
||
|
return
|
||
|
}
|
||
|
wantbuf := new(strings.Builder)
|
||
|
err = want.ToUAPI(wantbuf, prev)
|
||
|
wantStr := wantbuf.String()
|
||
|
if err != nil {
|
||
|
t.Errorf("want.ToUAPI(): error: %v", err)
|
||
|
return
|
||
|
}
|
||
|
if gotStr != wantStr {
|
||
|
buf := new(bytes.Buffer)
|
||
|
w := bufio.NewWriter(buf)
|
||
|
if err := d.IpcGetOperation(w); err != nil {
|
||
|
t.Errorf("on error, could not IpcGetOperation: %v", err)
|
||
|
}
|
||
|
w.Flush()
|
||
|
t.Errorf("cfg:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
t.Run("device1 config", func(t *testing.T) {
|
||
|
if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
cmp(t, device1, cfg1)
|
||
|
})
|
||
|
|
||
|
t.Run("device2 config", func(t *testing.T) {
|
||
|
if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
cmp(t, device2, cfg2)
|
||
|
})
|
||
|
|
||
|
// This is only to test that Config and Reconfig are properly synchronized.
|
||
|
t.Run("device2 config/reconfig", func(t *testing.T) {
|
||
|
var wg sync.WaitGroup
|
||
|
wg.Add(2)
|
||
|
|
||
|
go func() {
|
||
|
ReconfigDevice(device2, cfg2, t.Logf)
|
||
|
wg.Done()
|
||
|
}()
|
||
|
|
||
|
go func() {
|
||
|
DeviceConfig(device2)
|
||
|
wg.Done()
|
||
|
}()
|
||
|
|
||
|
wg.Wait()
|
||
|
})
|
||
|
|
||
|
t.Run("device1 modify peer", func(t *testing.T) {
|
||
|
cfg1.Peers[0].Endpoints = "1.2.3.4:12345"
|
||
|
if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
cmp(t, device1, cfg1)
|
||
|
})
|
||
|
|
||
|
t.Run("device1 replace endpoint", func(t *testing.T) {
|
||
|
cfg1.Peers[0].Endpoints = "1.1.1.1:123"
|
||
|
if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
cmp(t, device1, cfg1)
|
||
|
})
|
||
|
|
||
|
t.Run("device1 add new peer", func(t *testing.T) {
|
||
|
cfg1.Peers = append(cfg1.Peers, Peer{
|
||
|
PublicKey: k3,
|
||
|
AllowedIPs: []netaddr.IPPrefix{ip3},
|
||
|
})
|
||
|
sort.Slice(cfg1.Peers, func(i, j int) bool {
|
||
|
return cfg1.Peers[i].PublicKey.LessThan(&cfg1.Peers[j].PublicKey)
|
||
|
})
|
||
|
|
||
|
origCfg, err := DeviceConfig(device1)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
cmp(t, device1, cfg1)
|
||
|
|
||
|
newCfg, err := DeviceConfig(device1)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
peer0 := func(cfg *Config) Peer {
|
||
|
p, ok := cfg.PeerWithKey(k2)
|
||
|
if !ok {
|
||
|
t.Helper()
|
||
|
t.Fatal("failed to look up peer 2")
|
||
|
}
|
||
|
return p
|
||
|
}
|
||
|
peersEqual := func(p, q Peer) bool {
|
||
|
return p.PublicKey == q.PublicKey && p.PersistentKeepalive == q.PersistentKeepalive &&
|
||
|
p.Endpoints == q.Endpoints && cidrsEqual(p.AllowedIPs, q.AllowedIPs)
|
||
|
}
|
||
|
if !peersEqual(peer0(origCfg), peer0(newCfg)) {
|
||
|
t.Error("reconfig modified old peer")
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("device1 remove peer", func(t *testing.T) {
|
||
|
removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey
|
||
|
cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1]
|
||
|
|
||
|
if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
cmp(t, device1, cfg1)
|
||
|
|
||
|
newCfg, err := DeviceConfig(device1)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
_, ok := newCfg.PeerWithKey(removeKey)
|
||
|
if ok {
|
||
|
t.Error("reconfig failed to remove peer")
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// TODO: replace with a loopback tunnel
|
||
|
type nilTun struct {
|
||
|
events chan tun.Event
|
||
|
closed chan struct{}
|
||
|
}
|
||
|
|
||
|
func newNilTun() tun.Device {
|
||
|
return &nilTun{
|
||
|
events: make(chan tun.Event),
|
||
|
closed: make(chan struct{}),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (t *nilTun) File() *os.File { return nil }
|
||
|
func (t *nilTun) Flush() error { return nil }
|
||
|
func (t *nilTun) MTU() (int, error) { return 1420, nil }
|
||
|
func (t *nilTun) Name() (string, error) { return "niltun", nil }
|
||
|
func (t *nilTun) Events() chan tun.Event { return t.events }
|
||
|
|
||
|
func (t *nilTun) Read(data []byte, offset int) (int, error) {
|
||
|
<-t.closed
|
||
|
return 0, io.EOF
|
||
|
}
|
||
|
|
||
|
func (t *nilTun) Write(data []byte, offset int) (int, error) {
|
||
|
<-t.closed
|
||
|
return 0, io.EOF
|
||
|
}
|
||
|
|
||
|
func (t *nilTun) Close() error {
|
||
|
close(t.events)
|
||
|
close(t.closed)
|
||
|
return nil
|
||
|
}
|