tailscale/wgengine/wgcfg/parser.go

199 lines
5.3 KiB
Go
Raw Permalink Normal View History

// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgcfg
import (
"bufio"
"encoding/hex"
"fmt"
"io"
"net"
"strconv"
"strings"
"go4.org/mem"
"inet.af/netaddr"
"tailscale.com/types/wgkey"
)
// ParseError reports why a piece of input text was rejected by the parser.
type ParseError struct {
	why      string // short description of what was wrong
	offender string // the input text that was rejected
}

// Error implements the error interface. The result is the reason, a colon
// and a space, then the offending text in Go-quoted form.
func (e *ParseError) Error() string {
	quoted := fmt.Sprintf("%q", e.offender)
	return e.why + ": " + quoted
}
// parseEndpoint splits an endpoint of the form "host:port" or
// "[ipv6]:port" into its host and port.
//
// The port is everything after the last ':' in s and must parse as a
// base-10 unsigned 16-bit integer. The host must be non-empty. A host that
// starts with '[', ends with ']', or contains a ':' anywhere must be a
// bracketed IPv6 literal; the brackets are stripped from the returned host.
// Any other host is returned unchanged.
//
// Errors are *ParseError values, except for a malformed port, which is
// reported with the error returned by strconv.ParseUint.
func parseEndpoint(s string) (host string, port uint16, err error) {
	i := strings.LastIndexByte(s, ':')
	if i < 0 {
		return "", 0, &ParseError{"Missing port from endpoint", s}
	}
	host, portStr := s[:i], s[i+1:]
	if len(host) < 1 {
		return "", 0, &ParseError{"Invalid endpoint host", host}
	}
	uport, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil {
		return "", 0, err
	}

	opens := host[0] == '['
	closes := host[len(host)-1] == ']'
	// Look for a colon at any position, including index 0. Testing only
	// for an index greater than zero would let an unbracketed host such as
	// "::1" slip past the bracket validation below.
	hasColon := strings.IndexByte(host, ':') >= 0
	if !opens && !closes && !hasColon {
		// Plain hostname or IPv4 address: nothing more to validate.
		return host, uint16(uport), nil
	}

	bad := &ParseError{"Brackets must contain an IPv6 address", host}
	if !opens || !closes || !hasColon {
		return "", 0, bad
	}
	// Both brackets are present, so len(host) >= 2 and the slice is safe.
	inner := host[1 : len(host)-1]
	ip := net.ParseIP(inner)
	if ip == nil || len(ip) != net.IPv6len {
		return "", 0, bad
	}
	return inner, uint16(uport), nil
}
// memROCut splits s around the first occurrence of sep.
// When sep is present it returns the text before it, the text after it,
// and true. When sep is absent it returns two zero-value ROs and false.
func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) {
	idx := mem.IndexByte(s, sep)
	if idx < 0 {
		var none mem.RO
		return none, none, false
	}
	return s.SliceTo(idx), s.SliceFrom(idx + 1), true
}
// FromUAPI builds a Config by reading "key=value" lines from r.
// r should be the output of device.IpcGetOperation; other uapi streams
// are not supported.
//
// Empty lines are skipped. Lines seen before the first public_key line are
// treated as device settings; every public_key line starts a new peer, and
// the lines after it are applied to that peer. The first line that cannot
// be handled aborts parsing and its error is returned.
func FromUAPI(r io.Reader) (*Config, error) {
	var (
		cfg      = new(Config)
		cur      *Peer  // peer that subsequent lines apply to
		inDevice = true // true until the first public_key line is seen
	)
	sc := bufio.NewScanner(r)
	for sc.Scan() {
		raw := sc.Bytes()
		line := mem.B(raw)
		if line.Len() == 0 {
			continue
		}
		key, value, ok := memROCut(line, '=')
		if !ok {
			return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy())
		}
		// Everything after "key=" as a plain byte slice, for the handlers
		// that need []byte rather than mem.RO.
		valueBytes := raw[key.Len()+1:]

		var err error
		switch {
		case key.EqualString("public_key"):
			// Load/create the peer we are now configuring; from here on
			// lines are peer lines.
			inDevice = false
			cur, err = cfg.handlePublicKeyLine(valueBytes)
		case inDevice:
			err = cfg.handleDeviceLine(key, value, valueBytes)
		default:
			err = cfg.handlePeerLine(cur, key, value, valueBytes)
		}
		if err != nil {
			return nil, err
		}
	}
	if err := sc.Err(); err != nil {
		return nil, err
	}
	return cfg, nil
}
// parseKeyHex decodes the hex-encoded key in s into dst.
//
// It returns a *ParseError when s is too long to fit in dst, is not valid
// hex, or does not decode to exactly wgkey.Size bytes. When decoding fails
// part-way through, dst may already hold the bytes decoded before the
// error.
func parseKeyHex(s []byte, dst []byte) error {
	// hex.Decode writes into dst without checking its length, so input
	// that would decode past the end of dst must be rejected up front;
	// otherwise it panics with an index out of range.
	if hex.DecodedLen(len(s)) > len(dst) {
		return &ParseError{"Keys must decode to exactly 32 bytes", string(s)}
	}
	n, err := hex.Decode(dst, s)
	if err != nil {
		return &ParseError{"Invalid key: " + err.Error(), string(s)}
	}
	if n != wgkey.Size {
		return &ParseError{"Keys must decode to exactly 32 bytes", string(s)}
	}
	return nil
}
// handleDeviceLine applies one device-level key/value pair to cfg.
//
// private_key is hex-decoded from valueBytes into cfg.PrivateKey.
// listen_port and fwmark are accepted and ignored. Any other key is an
// error. The value parameter is not used by this function.
func (cfg *Config) handleDeviceLine(key, value mem.RO, valueBytes []byte) error {
	if key.EqualString("private_key") {
		// wireguard-go guarantees not to send zero value; private keys are already clamped.
		return parseKeyHex(valueBytes, cfg.PrivateKey[:])
	}
	if key.EqualString("listen_port") || key.EqualString("fwmark") {
		// Recognised, but nothing is stored for these.
		return nil
	}
	return fmt.Errorf("unexpected IpcGetOperation key: %q", key.StringCopy())
}
// handlePublicKeyLine starts a new peer whose public key is the hex text
// in valueBytes. The peer is appended to cfg.Peers and a pointer to the
// appended element is returned. If the key does not parse, cfg is left
// unchanged and the error is returned.
func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) {
	var fresh Peer
	err := parseKeyHex(valueBytes, fresh.PublicKey[:])
	if err != nil {
		return nil, err
	}
	cfg.Peers = append(cfg.Peers, fresh)
	last := len(cfg.Peers) - 1
	// Point at the slice element, not the local copy, so later writes
	// through the pointer land in cfg.Peers.
	return &cfg.Peers[last], nil
}
// handlePeerLine applies one peer-level key/value pair to peer.
//
//   - endpoint must be a hex key equal to peer.PublicKey; nothing is stored.
//   - persistent_keepalive_interval is parsed as a base-10 16-bit unsigned
//     integer and stored in peer.PersistentKeepalive.
//   - allowed_ip is parsed as an IP prefix and appended to peer.AllowedIPs.
//   - protocol_version must be "1".
//   - a fixed set of other known keys is accepted and ignored.
//
// Any other key is an error. The cfg receiver is not used by this function.
func (cfg *Config) handlePeerLine(peer *Peer, key, value mem.RO, valueBytes []byte) error {
	if key.EqualString("endpoint") {
		// TODO: our key types are all over the place, and wgkey.ParseHex
		// cannot take a mem.RO or a []byte without allocating. WireGuard
		// is not reconfigured often, so the copy is acceptable.
		endpoint := value.StringCopy()
		parsed, err := wgkey.ParseHex(endpoint)
		if err != nil {
			return fmt.Errorf("invalid endpoint %q for peer %q, expected a hex public key", endpoint, peer.PublicKey.ShortString())
		}
		if parsed != peer.PublicKey {
			return fmt.Errorf("unexpected endpoint %q for peer %q, expected the peer's public key", endpoint, peer.PublicKey.ShortString())
		}
		return nil
	}

	if key.EqualString("persistent_keepalive_interval") {
		interval, err := mem.ParseUint(value, 10, 16)
		if err != nil {
			return err
		}
		peer.PersistentKeepalive = uint16(interval)
		return nil
	}

	if key.EqualString("allowed_ip") {
		var prefix netaddr.IPPrefix
		if err := prefix.UnmarshalText(valueBytes); err != nil {
			return err
		}
		peer.AllowedIPs = append(peer.AllowedIPs, prefix)
		return nil
	}

	if key.EqualString("protocol_version") {
		if value.EqualString("1") {
			return nil
		}
		return fmt.Errorf("invalid protocol version: %q", value.StringCopy())
	}

	// Keys that are recognised but for which nothing is stored.
	ignored := [...]string{
		"replace_allowed_ips",
		"preshared_key",
		"last_handshake_time_sec",
		"last_handshake_time_nsec",
		"tx_bytes",
		"rx_bytes",
	}
	for _, name := range ignored {
		if key.EqualString(name) {
			return nil
		}
	}
	return fmt.Errorf("unexpected IpcGetOperation key: %q", key.StringCopy())
}