tailscale/wgengine/bench/wg.go

233 lines
5.5 KiB
Go

// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"errors"
"io"
"log"
"os"
"sync"
"testing"
"github.com/tailscale/wireguard-go/tun"
"inet.af/netaddr"
"tailscale.com/net/dns"
"tailscale.com/tailcfg"
"tailscale.com/types/logger"
"tailscale.com/types/netmap"
"tailscale.com/types/wgkey"
"tailscale.com/wgengine"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/router"
"tailscale.com/wgengine/wgcfg"
)
func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netaddr.IPPrefix) {
l1 := logger.WithPrefix(logf, "e1: ")
k1, err := wgkey.NewPrivate()
if err != nil {
log.Fatalf("e1 NewPrivateKey: %v", err)
}
c1 := wgcfg.Config{
Name: "e1",
PrivateKey: k1,
Addresses: []netaddr.IPPrefix{a1},
}
t1 := &sourceTun{
logf: logger.WithPrefix(logf, "tun1: "),
traf: traf,
}
e1, err := wgengine.NewUserspaceEngine(l1, wgengine.Config{
Router: router.NewFake(l1),
LinkMonitor: nil,
ListenPort: 0,
Tun: t1,
})
if err != nil {
log.Fatalf("e1 init: %v", err)
}
if b != nil {
b.Cleanup(e1.Close)
}
l2 := logger.WithPrefix(logf, "e2: ")
k2, err := wgkey.NewPrivate()
if err != nil {
log.Fatalf("e2 NewPrivateKey: %v", err)
}
c2 := wgcfg.Config{
Name: "e2",
PrivateKey: k2,
Addresses: []netaddr.IPPrefix{a2},
}
t2 := &sinkTun{
logf: logger.WithPrefix(logf, "tun2: "),
traf: traf,
}
e2, err := wgengine.NewUserspaceEngine(l2, wgengine.Config{
Router: router.NewFake(l2),
LinkMonitor: nil,
ListenPort: 0,
Tun: t2,
})
if err != nil {
log.Fatalf("e2 init: %v", err)
}
if b != nil {
b.Cleanup(e2.Close)
}
e1.SetFilter(filter.NewAllowAllForTest(l1))
e2.SetFilter(filter.NewAllowAllForTest(l2))
var wait sync.WaitGroup
wait.Add(2)
var e1waitDoneOnce sync.Once
e1.SetStatusCallback(func(st *wgengine.Status, err error) {
if errors.Is(err, wgengine.ErrEngineClosing) {
return
}
if err != nil {
log.Fatalf("e1 status err: %v", err)
}
logf("e1 status: %v", *st)
var eps []string
var ipps []netaddr.IPPort
for _, ep := range st.LocalAddrs {
eps = append(eps, ep.Addr.String())
ipps = append(ipps, ep.Addr)
}
endpoint := wgcfg.Endpoints{
PublicKey: c1.PrivateKey.Public(),
IPPorts: wgcfg.NewIPPortSet(ipps...),
}
n := tailcfg.Node{
ID: tailcfg.NodeID(0),
Name: "n1",
Addresses: []netaddr.IPPrefix{a1},
AllowedIPs: []netaddr.IPPrefix{a1},
Endpoints: eps,
}
e2.SetNetworkMap(&netmap.NetworkMap{
NodeKey: tailcfg.NodeKey(k2),
PrivateKey: wgkey.Private(k2),
Peers: []*tailcfg.Node{&n},
})
p := wgcfg.Peer{
PublicKey: c1.PrivateKey.Public(),
AllowedIPs: []netaddr.IPPrefix{a1},
Endpoints: endpoint,
}
c2.Peers = []wgcfg.Peer{p}
e2.Reconfig(&c2, &router.Config{}, new(dns.Config))
e1waitDoneOnce.Do(wait.Done)
})
var e2waitDoneOnce sync.Once
e2.SetStatusCallback(func(st *wgengine.Status, err error) {
if errors.Is(err, wgengine.ErrEngineClosing) {
return
}
if err != nil {
log.Fatalf("e2 status err: %v", err)
}
logf("e2 status: %v", *st)
var eps []string
var ipps []netaddr.IPPort
for _, ep := range st.LocalAddrs {
eps = append(eps, ep.Addr.String())
ipps = append(ipps, ep.Addr)
}
endpoint := wgcfg.Endpoints{
PublicKey: c2.PrivateKey.Public(),
IPPorts: wgcfg.NewIPPortSet(ipps...),
}
n := tailcfg.Node{
ID: tailcfg.NodeID(0),
Name: "n2",
Addresses: []netaddr.IPPrefix{a2},
AllowedIPs: []netaddr.IPPrefix{a2},
Endpoints: eps,
}
e1.SetNetworkMap(&netmap.NetworkMap{
NodeKey: tailcfg.NodeKey(k1),
PrivateKey: wgkey.Private(k1),
Peers: []*tailcfg.Node{&n},
})
p := wgcfg.Peer{
PublicKey: c2.PrivateKey.Public(),
AllowedIPs: []netaddr.IPPrefix{a2},
Endpoints: endpoint,
}
c1.Peers = []wgcfg.Peer{p}
e1.Reconfig(&c1, &router.Config{}, new(dns.Config))
e2waitDoneOnce.Do(wait.Done)
})
// Not using DERP in this test (for now?).
e1.SetDERPMap(&tailcfg.DERPMap{})
e2.SetDERPMap(&tailcfg.DERPMap{})
wait.Wait()
}
type sourceTun struct {
logf logger.Logf
traf *TrafficGen
}
func (t *sourceTun) Close() error { return nil }
func (t *sourceTun) Events() chan tun.Event { return nil }
func (t *sourceTun) File() *os.File { return nil }
func (t *sourceTun) Flush() error { return nil }
func (t *sourceTun) MTU() (int, error) { return 1500, nil }
func (t *sourceTun) Name() (string, error) { return "source", nil }
func (t *sourceTun) Write(b []byte, ofs int) (int, error) {
// Discard all writes
return len(b) - ofs, nil
}
func (t *sourceTun) Read(b []byte, ofs int) (int, error) {
// Continually generate "input" packets
n := t.traf.Generate(b, ofs)
if n == 0 {
return 0, io.EOF
}
return n, nil
}
type sinkTun struct {
logf logger.Logf
traf *TrafficGen
}
func (t *sinkTun) Close() error { return nil }
func (t *sinkTun) Events() chan tun.Event { return nil }
func (t *sinkTun) File() *os.File { return nil }
func (t *sinkTun) Flush() error { return nil }
func (t *sinkTun) MTU() (int, error) { return 1500, nil }
func (t *sinkTun) Name() (string, error) { return "sink", nil }
func (t *sinkTun) Read(b []byte, ofs int) (int, error) {
// Never returns
select {}
}
func (t *sinkTun) Write(b []byte, ofs int) (int, error) {
// Count packets, but discard them
t.traf.GotPacket(b, ofs)
return len(b) - ofs, nil
}