// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
|
|
|
package wgcfg
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"go4.org/mem"
|
|
"inet.af/netaddr"
|
|
"tailscale.com/types/wgkey"
|
|
)
|
|
|
|
// ParseError is the error returned when a piece of configuration text
// cannot be parsed. It records the reason for the failure together with
// the text that caused it.
type ParseError struct {
	why, offender string // reason for the failure; the text that triggered it
}

// Error implements the error interface. The result is the reason, a colon
// and a space, and then the offending text in Go-quoted form.
func (e *ParseError) Error() string {
	quoted := fmt.Sprintf("%q", e.offender)
	return e.why + ": " + quoted
}
|
|
|
|
func parseEndpoint(s string) (host string, port uint16, err error) {
|
|
i := strings.LastIndexByte(s, ':')
|
|
if i < 0 {
|
|
return "", 0, &ParseError{"Missing port from endpoint", s}
|
|
}
|
|
host, portStr := s[:i], s[i+1:]
|
|
if len(host) < 1 {
|
|
return "", 0, &ParseError{"Invalid endpoint host", host}
|
|
}
|
|
uport, err := strconv.ParseUint(portStr, 10, 16)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
hostColon := strings.IndexByte(host, ':')
|
|
if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 {
|
|
err := &ParseError{"Brackets must contain an IPv6 address", host}
|
|
if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 {
|
|
maybeV6 := net.ParseIP(host[1 : len(host)-1])
|
|
if maybeV6 == nil || len(maybeV6) != net.IPv6len {
|
|
return "", 0, err
|
|
}
|
|
} else {
|
|
return "", 0, err
|
|
}
|
|
host = host[1 : len(host)-1]
|
|
}
|
|
return host, uint16(uport), nil
|
|
}
|
|
|
|
// memROCut separates a mem.RO at the separator if it exists, otherwise
|
|
// it returns two empty ROs and reports that it was not found.
|
|
func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) {
|
|
if i := mem.IndexByte(s, sep); i >= 0 {
|
|
return s.SliceTo(i), s.SliceFrom(i + 1), true
|
|
}
|
|
found = false
|
|
return
|
|
}
|
|
|
|
// FromUAPI generates a Config from r.
|
|
// r should be generated by calling device.IpcGetOperation;
|
|
// it is not compatible with other uapi streams.
|
|
func FromUAPI(r io.Reader) (*Config, error) {
|
|
cfg := new(Config)
|
|
var peer *Peer // current peer being operated on
|
|
deviceConfig := true
|
|
|
|
scanner := bufio.NewScanner(r)
|
|
for scanner.Scan() {
|
|
line := mem.B(scanner.Bytes())
|
|
if line.Len() == 0 {
|
|
continue
|
|
}
|
|
key, value, ok := memROCut(line, '=')
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy())
|
|
}
|
|
valueBytes := scanner.Bytes()[key.Len()+1:]
|
|
|
|
if key.EqualString("public_key") {
|
|
if deviceConfig {
|
|
deviceConfig = false
|
|
}
|
|
// Load/create the peer we are now configuring.
|
|
var err error
|
|
peer, err = cfg.handlePublicKeyLine(valueBytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
continue
|
|
}
|
|
|
|
var err error
|
|
if deviceConfig {
|
|
err = cfg.handleDeviceLine(key, value, valueBytes)
|
|
} else {
|
|
err = cfg.handlePeerLine(peer, key, value, valueBytes)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return cfg, nil
|
|
}
|
|
|
|
func parseKeyHex(s []byte, dst []byte) error {
|
|
n, err := hex.Decode(dst, s)
|
|
if err != nil {
|
|
return &ParseError{"Invalid key: " + err.Error(), string(s)}
|
|
}
|
|
if n != wgkey.Size {
|
|
return &ParseError{"Keys must decode to exactly 32 bytes", string(s)}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (cfg *Config) handleDeviceLine(key, value mem.RO, valueBytes []byte) error {
|
|
switch {
|
|
case key.EqualString("private_key"):
|
|
// wireguard-go guarantees not to send zero value; private keys are already clamped.
|
|
if err := parseKeyHex(valueBytes, cfg.PrivateKey[:]); err != nil {
|
|
return err
|
|
}
|
|
case key.EqualString("listen_port") || key.EqualString("fwmark"):
|
|
// ignore
|
|
default:
|
|
return fmt.Errorf("unexpected IpcGetOperation key: %q", key.StringCopy())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) {
|
|
p := Peer{}
|
|
if err := parseKeyHex(valueBytes, p.PublicKey[:]); err != nil {
|
|
return nil, err
|
|
}
|
|
cfg.Peers = append(cfg.Peers, p)
|
|
return &cfg.Peers[len(cfg.Peers)-1], nil
|
|
}
|
|
|
|
func (cfg *Config) handlePeerLine(peer *Peer, key, value mem.RO, valueBytes []byte) error {
|
|
switch {
|
|
case key.EqualString("endpoint"):
|
|
if err := json.Unmarshal(valueBytes, &peer.Endpoints); err != nil {
|
|
return err
|
|
}
|
|
case key.EqualString("persistent_keepalive_interval"):
|
|
n, err := mem.ParseUint(value, 10, 16)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
peer.PersistentKeepalive = uint16(n)
|
|
case key.EqualString("allowed_ip"):
|
|
ipp := netaddr.IPPrefix{}
|
|
err := ipp.UnmarshalText(valueBytes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
peer.AllowedIPs = append(peer.AllowedIPs, ipp)
|
|
case key.EqualString("protocol_version"):
|
|
if !value.EqualString("1") {
|
|
return fmt.Errorf("invalid protocol version: %q", value.StringCopy())
|
|
}
|
|
case key.EqualString("replace_allowed_ips") ||
|
|
key.EqualString("preshared_key") ||
|
|
key.EqualString("last_handshake_time_sec") ||
|
|
key.EqualString("last_handshake_time_nsec") ||
|
|
key.EqualString("tx_bytes") ||
|
|
key.EqualString("rx_bytes"):
|
|
// ignore
|
|
default:
|
|
return fmt.Errorf("unexpected IpcGetOperation key: %q", key.StringCopy())
|
|
}
|
|
return nil
|
|
}
|