188 lines
5.0 KiB
Go
188 lines
5.0 KiB
Go
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
|
|
|
package wgcfg
|
|
|
|
import (
|
|
"bufio"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"go4.org/mem"
|
|
"tailscale.com/net/netaddr"
|
|
"tailscale.com/types/key"
|
|
)
|
|
|
|
// ParseError is the error type used to report input that could not be
// parsed. It pairs a human-readable reason with the text that caused it.
type ParseError struct {
	// why is the reason parsing failed; offender is the input at fault.
	why, offender string
}
|
|
|
|
func (e *ParseError) Error() string {
|
|
return fmt.Sprintf("%s: %q", e.why, e.offender)
|
|
}
|
|
|
|
func parseEndpoint(s string) (host string, port uint16, err error) {
|
|
i := strings.LastIndexByte(s, ':')
|
|
if i < 0 {
|
|
return "", 0, &ParseError{"Missing port from endpoint", s}
|
|
}
|
|
host, portStr := s[:i], s[i+1:]
|
|
if len(host) < 1 {
|
|
return "", 0, &ParseError{"Invalid endpoint host", host}
|
|
}
|
|
uport, err := strconv.ParseUint(portStr, 10, 16)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
hostColon := strings.IndexByte(host, ':')
|
|
if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 {
|
|
err := &ParseError{"Brackets must contain an IPv6 address", host}
|
|
if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 {
|
|
maybeV6 := net.ParseIP(host[1 : len(host)-1])
|
|
if maybeV6 == nil || len(maybeV6) != net.IPv6len {
|
|
return "", 0, err
|
|
}
|
|
} else {
|
|
return "", 0, err
|
|
}
|
|
host = host[1 : len(host)-1]
|
|
}
|
|
return host, uint16(uport), nil
|
|
}
|
|
|
|
// memROCut separates a mem.RO at the separator if it exists, otherwise
|
|
// it returns two empty ROs and reports that it was not found.
|
|
func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) {
|
|
if i := mem.IndexByte(s, sep); i >= 0 {
|
|
return s.SliceTo(i), s.SliceFrom(i + 1), true
|
|
}
|
|
found = false
|
|
return
|
|
}
|
|
|
|
// FromUAPI generates a Config from r.
|
|
// r should be generated by calling device.IpcGetOperation;
|
|
// it is not compatible with other uapi streams.
|
|
func FromUAPI(r io.Reader) (*Config, error) {
|
|
cfg := new(Config)
|
|
var peer *Peer // current peer being operated on
|
|
deviceConfig := true
|
|
|
|
scanner := bufio.NewScanner(r)
|
|
for scanner.Scan() {
|
|
line := mem.B(scanner.Bytes())
|
|
if line.Len() == 0 {
|
|
continue
|
|
}
|
|
key, value, ok := memROCut(line, '=')
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy())
|
|
}
|
|
valueBytes := scanner.Bytes()[key.Len()+1:]
|
|
|
|
if key.EqualString("public_key") {
|
|
if deviceConfig {
|
|
deviceConfig = false
|
|
}
|
|
// Load/create the peer we are now configuring.
|
|
var err error
|
|
peer, err = cfg.handlePublicKeyLine(valueBytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
continue
|
|
}
|
|
|
|
var err error
|
|
if deviceConfig {
|
|
err = cfg.handleDeviceLine(key, value, valueBytes)
|
|
} else {
|
|
err = cfg.handlePeerLine(peer, key, value, valueBytes)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return cfg, nil
|
|
}
|
|
|
|
func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error {
|
|
switch {
|
|
case k.EqualString("private_key"):
|
|
// wireguard-go guarantees not to send zero value; private keys are already clamped.
|
|
var err error
|
|
cfg.PrivateKey, err = key.ParseNodePrivateUntyped(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case k.EqualString("listen_port") || k.EqualString("fwmark"):
|
|
// ignore
|
|
default:
|
|
return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) {
|
|
p := Peer{}
|
|
var err error
|
|
p.PublicKey, err = key.ParseNodePublicUntyped(mem.B(valueBytes))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cfg.Peers = append(cfg.Peers, p)
|
|
return &cfg.Peers[len(cfg.Peers)-1], nil
|
|
}
|
|
|
|
func (cfg *Config) handlePeerLine(peer *Peer, k, value mem.RO, valueBytes []byte) error {
|
|
switch {
|
|
case k.EqualString("endpoint"):
|
|
nk, err := key.ParseNodePublicUntyped(value)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid endpoint %q for peer %q, expected a hex public key", value.StringCopy(), peer.PublicKey.ShortString())
|
|
}
|
|
// nk ought to equal peer.PublicKey.
|
|
// Under some rare circumstances, it might not. See corp issue #3016.
|
|
// Even if that happens, don't stop early, so that we can recover from it.
|
|
// Instead, note the value of nk so we can fix as needed.
|
|
peer.WGEndpoint = nk
|
|
case k.EqualString("persistent_keepalive_interval"):
|
|
n, err := mem.ParseUint(value, 10, 16)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
peer.PersistentKeepalive = uint16(n)
|
|
case k.EqualString("allowed_ip"):
|
|
ipp := netaddr.IPPrefix{}
|
|
err := ipp.UnmarshalText(valueBytes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
peer.AllowedIPs = append(peer.AllowedIPs, ipp)
|
|
case k.EqualString("protocol_version"):
|
|
if !value.EqualString("1") {
|
|
return fmt.Errorf("invalid protocol version: %q", value.StringCopy())
|
|
}
|
|
case k.EqualString("replace_allowed_ips") ||
|
|
k.EqualString("preshared_key") ||
|
|
k.EqualString("last_handshake_time_sec") ||
|
|
k.EqualString("last_handshake_time_nsec") ||
|
|
k.EqualString("tx_bytes") ||
|
|
k.EqualString("rx_bytes"):
|
|
// ignore
|
|
default:
|
|
return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy())
|
|
}
|
|
return nil
|
|
}
|