193 lines
3.8 KiB
Go
193 lines
3.8 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
package chirp
|
|
|
|
import (
|
|
"bufio"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
type fakeBIRD struct {
|
|
net.Listener
|
|
protocolsEnabled map[string]bool
|
|
sock string
|
|
}
|
|
|
|
func newFakeBIRD(t *testing.T, protocols ...string) *fakeBIRD {
|
|
sock := filepath.Join(t.TempDir(), "sock")
|
|
l, err := net.Listen("unix", sock)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pe := make(map[string]bool)
|
|
for _, p := range protocols {
|
|
pe[p] = false
|
|
}
|
|
return &fakeBIRD{
|
|
Listener: l,
|
|
protocolsEnabled: pe,
|
|
sock: sock,
|
|
}
|
|
}
|
|
|
|
func (fb *fakeBIRD) listen() error {
|
|
for {
|
|
c, err := fb.Accept()
|
|
if err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
go fb.handle(c)
|
|
}
|
|
}
|
|
|
|
func (fb *fakeBIRD) handle(c net.Conn) {
|
|
fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.")
|
|
sc := bufio.NewScanner(c)
|
|
for sc.Scan() {
|
|
cmd := sc.Text()
|
|
args := strings.Split(cmd, " ")
|
|
switch args[0] {
|
|
case "enable":
|
|
en, ok := fb.protocolsEnabled[args[1]]
|
|
if !ok {
|
|
fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL")
|
|
} else if en {
|
|
fmt.Fprintf(c, "0010-%s: already enabled\n", args[1])
|
|
} else {
|
|
fmt.Fprintf(c, "0011-%s: enabled\n", args[1])
|
|
}
|
|
fmt.Fprintln(c, "0000 ")
|
|
fb.protocolsEnabled[args[1]] = true
|
|
case "disable":
|
|
en, ok := fb.protocolsEnabled[args[1]]
|
|
if !ok {
|
|
fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL")
|
|
} else if !en {
|
|
fmt.Fprintf(c, "0008-%s: already disabled\n", args[1])
|
|
} else {
|
|
fmt.Fprintf(c, "0009-%s: disabled\n", args[1])
|
|
}
|
|
fmt.Fprintln(c, "0000 ")
|
|
fb.protocolsEnabled[args[1]] = false
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestChirp(t *testing.T) {
|
|
fb := newFakeBIRD(t, "tailscale")
|
|
defer fb.Close()
|
|
go fb.listen()
|
|
c, err := New(fb.sock)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := c.EnableProtocol("tailscale"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := c.EnableProtocol("tailscale"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := c.DisableProtocol("tailscale"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := c.DisableProtocol("tailscale"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := c.EnableProtocol("rando"); err == nil {
|
|
t.Fatalf("enabling %q succeeded", "rando")
|
|
}
|
|
if err := c.DisableProtocol("rando"); err == nil {
|
|
t.Fatalf("disabling %q succeeded", "rando")
|
|
}
|
|
}
|
|
|
|
type hangingListener struct {
|
|
net.Listener
|
|
t *testing.T
|
|
done chan struct{}
|
|
wg sync.WaitGroup
|
|
sock string
|
|
}
|
|
|
|
func newHangingListener(t *testing.T) *hangingListener {
|
|
sock := filepath.Join(t.TempDir(), "sock")
|
|
l, err := net.Listen("unix", sock)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return &hangingListener{
|
|
Listener: l,
|
|
t: t,
|
|
done: make(chan struct{}),
|
|
sock: sock,
|
|
}
|
|
}
|
|
|
|
func (hl *hangingListener) Stop() {
|
|
hl.Close()
|
|
close(hl.done)
|
|
hl.wg.Wait()
|
|
}
|
|
|
|
func (hl *hangingListener) listen() error {
|
|
for {
|
|
c, err := hl.Accept()
|
|
if err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
hl.wg.Add(1)
|
|
go hl.handle(c)
|
|
}
|
|
}
|
|
|
|
func (hl *hangingListener) handle(c net.Conn) {
|
|
defer hl.wg.Done()
|
|
|
|
// Write our fake first line of response so that we get into the read loop
|
|
fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.")
|
|
|
|
ticker := time.NewTicker(2 * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
hl.t.Logf("connection still hanging")
|
|
case <-hl.done:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestChirpTimeout(t *testing.T) {
|
|
fb := newHangingListener(t)
|
|
defer fb.Stop()
|
|
go fb.listen()
|
|
|
|
c, err := newWithTimeout(fb.sock, 500*time.Millisecond)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err = c.EnableProtocol("tailscale")
|
|
if err == nil {
|
|
t.Fatal("got err=nil, want timeout")
|
|
}
|
|
if !os.IsTimeout(err) {
|
|
t.Fatalf("got err=%v, want os.IsTimeout(err)=true", err)
|
|
}
|
|
}
|