496 lines
16 KiB
Go
496 lines
16 KiB
Go
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package controlbase
|
|
|
|
import (
|
|
"context"
|
|
"crypto/cipher"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"hash"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"time"
|
|
|
|
"go4.org/mem"
|
|
"golang.org/x/crypto/blake2s"
|
|
chp "golang.org/x/crypto/chacha20poly1305"
|
|
"golang.org/x/crypto/curve25519"
|
|
"golang.org/x/crypto/hkdf"
|
|
"tailscale.com/types/key"
|
|
)
|
|
|
|
const (
	// protocolName identifies the exact Noise protocol instance the
	// control protocol runs: the IK handshake pattern over Curve25519,
	// with ChaCha20-Poly1305 and BLAKE2s. The string's form is dictated by
	// the Noise spec's naming rules, so it must only change if the control
	// protocol moves to a different Noise instance.
	protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"

	// protocolVersionPrefix is the fixed leading part of the
	// "name+version" string that both peers mix into the handshake as its
	// prologue. Because the prologue feeds the handshake hash, client and
	// server only complete the handshake if they agree on the control
	// protocol version advertised in the cleartext message header.
	//
	// The version number itself is not a constant: the client's caller
	// supplies it, and the server reads it from the initiation header.
	protocolVersionPrefix = "Tailscale Control Protocol v"

	// invalidNonce is the all-ones uint64 value.
	// NOTE(review): not referenced in this part of the file; presumably
	// a sentinel used by the transport code — confirm at its use sites.
	invalidNonce uint64 = 1<<64 - 1
)
|
|
|
|
func protocolVersionPrologue(version uint16) []byte {
|
|
ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers.
|
|
ret = append(ret, protocolVersionPrefix...)
|
|
return strconv.AppendUint(ret, uint64(version), 10)
|
|
}
|
|
|
|
// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn
|
|
// is assumed to have already sent the client>server handshake
|
|
// initiation message.
|
|
type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
|
|
|
|
// ClientDeferred initiates a control client handshake, returning the
|
|
// initial message to send to the server and a continuation to
|
|
// finalize the handshake.
|
|
//
|
|
// ClientDeferred is split in this way for RTT reduction: we run this
|
|
// protocol after negotiating a protocol switch from HTTP/HTTPS. If we
|
|
// completely serialized the negotiation followed by the handshake,
|
|
// we'd pay an extra RTT to transmit the handshake initiation after
|
|
// protocol switching. By splitting the handshake into an initial
|
|
// message and a continuation, we can embed the handshake initiation
|
|
// into the HTTP protocol switching request and avoid a bit of delay.
|
|
func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
|
|
var s symmetricState
|
|
s.Initialize()
|
|
|
|
// prologue
|
|
s.MixHash(protocolVersionPrologue(protocolVersion))
|
|
|
|
// <- s
|
|
// ...
|
|
s.MixHash(controlKey.UntypedBytes())
|
|
|
|
// -> e, es, s, ss
|
|
init := mkInitiationMessage(protocolVersion)
|
|
machineEphemeral := key.NewMachine()
|
|
machineEphemeralPub := machineEphemeral.Public()
|
|
copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
|
|
s.MixHash(machineEphemeralPub.UntypedBytes())
|
|
cipher, err := s.MixDH(machineEphemeral, controlKey)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("computing es: %w", err)
|
|
}
|
|
machineKeyPub := machineKey.Public()
|
|
s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes())
|
|
cipher, err = s.MixDH(machineKey, controlKey)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("computing ss: %w", err)
|
|
}
|
|
s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload
|
|
|
|
cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
|
|
return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion)
|
|
}
|
|
return init[:], cont, nil
|
|
}
|
|
|
|
// Client wraps ClientDeferred and immediately invokes the returned
|
|
// continuation with conn.
|
|
//
|
|
// This is a helper for when you don't need the fancy
|
|
// continuation-style handshake, and just want to synchronously
|
|
// upgrade a net.Conn to a secure transport.
|
|
func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
|
|
init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if _, err := conn.Write(init); err != nil {
|
|
return nil, err
|
|
}
|
|
return cont(ctx, conn)
|
|
}
|
|
|
|
func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
|
|
// No matter what, this function can only run once per s. Ensure
|
|
// attempted reuse causes a panic.
|
|
defer func() {
|
|
s.finished = true
|
|
}()
|
|
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
if err := conn.SetDeadline(deadline); err != nil {
|
|
return nil, fmt.Errorf("setting conn deadline: %w", err)
|
|
}
|
|
defer func() {
|
|
conn.SetDeadline(time.Time{})
|
|
}()
|
|
}
|
|
|
|
// Read in the payload and look for errors/protocol violations from the server.
|
|
var resp responseMessage
|
|
if _, err := io.ReadFull(conn, resp.Header()); err != nil {
|
|
return nil, fmt.Errorf("reading response header: %w", err)
|
|
}
|
|
if resp.Type() != msgTypeResponse {
|
|
if resp.Type() != msgTypeError {
|
|
return nil, fmt.Errorf("unexpected response message type %d", resp.Type())
|
|
}
|
|
msg := make([]byte, resp.Length())
|
|
if _, err := io.ReadFull(conn, msg); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, fmt.Errorf("server error: %q", msg)
|
|
}
|
|
if resp.Length() != len(resp.Payload()) {
|
|
return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length())
|
|
}
|
|
if _, err := io.ReadFull(conn, resp.Payload()); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// <- e, ee, se
|
|
controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub()))
|
|
s.MixHash(controlEphemeralPub.UntypedBytes())
|
|
if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil {
|
|
return nil, fmt.Errorf("computing ee: %w", err)
|
|
}
|
|
cipher, err := s.MixDH(machineKey, controlEphemeralPub)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("computing se: %w", err)
|
|
}
|
|
if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil {
|
|
return nil, fmt.Errorf("decrypting payload: %w", err)
|
|
}
|
|
|
|
c1, c2, err := s.Split()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("finalizing handshake: %w", err)
|
|
}
|
|
|
|
c := &Conn{
|
|
conn: conn,
|
|
version: protocolVersion,
|
|
peer: controlKey,
|
|
handshakeHash: s.h,
|
|
tx: txState{
|
|
cipher: c1,
|
|
},
|
|
rx: rxState{
|
|
cipher: c2,
|
|
},
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
// Server initiates a control server handshake, returning the resulting
|
|
// control connection.
|
|
//
|
|
// optionalInit can be the client's initial handshake message as
|
|
// returned by ClientDeferred, or nil in which case the initial
|
|
// message is read from conn.
|
|
//
|
|
// The context deadline, if any, covers the entire handshaking
|
|
// process.
|
|
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
if err := conn.SetDeadline(deadline); err != nil {
|
|
return nil, fmt.Errorf("setting conn deadline: %w", err)
|
|
}
|
|
defer func() {
|
|
conn.SetDeadline(time.Time{})
|
|
}()
|
|
}
|
|
|
|
// Deliberately does not support formatting, so that we don't echo
|
|
// attacker-controlled input back to them.
|
|
sendErr := func(msg string) error {
|
|
if len(msg) >= 1<<16 {
|
|
msg = msg[:1<<16]
|
|
}
|
|
var hdr [headerLen]byte
|
|
hdr[0] = msgTypeError
|
|
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg)))
|
|
if _, err := conn.Write(hdr[:]); err != nil {
|
|
return fmt.Errorf("sending %q error to client: %w", msg, err)
|
|
}
|
|
if _, err := io.WriteString(conn, msg); err != nil {
|
|
return fmt.Errorf("sending %q error to client: %w", msg, err)
|
|
}
|
|
return fmt.Errorf("refused client handshake: %q", msg)
|
|
}
|
|
|
|
var s symmetricState
|
|
s.Initialize()
|
|
|
|
var init initiationMessage
|
|
if optionalInit != nil {
|
|
if len(optionalInit) != len(init) {
|
|
return nil, sendErr("wrong handshake initiation size")
|
|
}
|
|
copy(init[:], optionalInit)
|
|
} else if _, err := io.ReadFull(conn, init.Header()); err != nil {
|
|
return nil, err
|
|
}
|
|
// Just a rename to make it more obvious what the value is. In the
|
|
// current implementation we don't need to block any protocol
|
|
// versions at this layer, it's safe to let the handshake proceed
|
|
// and then let the caller make decisions based on the agreed-upon
|
|
// protocol version.
|
|
clientVersion := init.Version()
|
|
if init.Type() != msgTypeInitiation {
|
|
return nil, sendErr("unexpected handshake message type")
|
|
}
|
|
if init.Length() != len(init.Payload()) {
|
|
return nil, sendErr("wrong handshake initiation length")
|
|
}
|
|
// if optionalInit was provided, we have the payload already.
|
|
if optionalInit == nil {
|
|
if _, err := io.ReadFull(conn, init.Payload()); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// prologue. Can only do this once we at least think the client is
|
|
// handshaking using a supported version.
|
|
s.MixHash(protocolVersionPrologue(clientVersion))
|
|
|
|
// <- s
|
|
// ...
|
|
controlKeyPub := controlKey.Public()
|
|
s.MixHash(controlKeyPub.UntypedBytes())
|
|
|
|
// -> e, es, s, ss
|
|
machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub()))
|
|
s.MixHash(machineEphemeralPub.UntypedBytes())
|
|
cipher, err := s.MixDH(controlKey, machineEphemeralPub)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("computing es: %w", err)
|
|
}
|
|
var machineKeyBytes [32]byte
|
|
if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil {
|
|
return nil, fmt.Errorf("decrypting machine key: %w", err)
|
|
}
|
|
machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:]))
|
|
cipher, err = s.MixDH(controlKey, machineKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("computing ss: %w", err)
|
|
}
|
|
if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil {
|
|
return nil, fmt.Errorf("decrypting initiation tag: %w", err)
|
|
}
|
|
|
|
// <- e, ee, se
|
|
resp := mkResponseMessage()
|
|
controlEphemeral := key.NewMachine()
|
|
controlEphemeralPub := controlEphemeral.Public()
|
|
copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes())
|
|
s.MixHash(controlEphemeralPub.UntypedBytes())
|
|
if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil {
|
|
return nil, fmt.Errorf("computing ee: %w", err)
|
|
}
|
|
cipher, err = s.MixDH(controlEphemeral, machineKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("computing se: %w", err)
|
|
}
|
|
s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload
|
|
|
|
c1, c2, err := s.Split()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("finalizing handshake: %w", err)
|
|
}
|
|
|
|
if _, err := conn.Write(resp[:]); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c := &Conn{
|
|
conn: conn,
|
|
version: clientVersion,
|
|
peer: machineKey,
|
|
handshakeHash: s.h,
|
|
tx: txState{
|
|
cipher: c2,
|
|
},
|
|
rx: rxState{
|
|
cipher: c1,
|
|
},
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
// symmetricState contains the state of an in-flight handshake.
|
|
type symmetricState struct {
|
|
finished bool
|
|
|
|
h [blake2s.Size]byte // hash of currently-processed handshake state
|
|
ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake
|
|
}
|
|
|
|
func (s *symmetricState) checkFinished() {
|
|
if s.finished {
|
|
panic("attempted to use symmetricState after Split was called")
|
|
}
|
|
}
|
|
|
|
// Initialize sets s to the initial handshake state, prior to
|
|
// processing any handshake messages.
|
|
func (s *symmetricState) Initialize() {
|
|
s.checkFinished()
|
|
s.h = blake2s.Sum256([]byte(protocolName))
|
|
s.ck = s.h
|
|
}
|
|
|
|
// MixHash updates s.h to be BLAKE2s(s.h || data), where || is
|
|
// concatenation.
|
|
func (s *symmetricState) MixHash(data []byte) {
|
|
s.checkFinished()
|
|
h := newBLAKE2s()
|
|
h.Write(s.h[:])
|
|
h.Write(data)
|
|
h.Sum(s.h[:0])
|
|
}
|
|
|
|
// MixDH updates s.ck with the result of X25519(priv, pub) and returns
|
|
// a singleUseCHP that can be used to encrypt or decrypt handshake
|
|
// data.
|
|
//
|
|
// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing
|
|
// it as a single function allows for strongly-typed arguments that
|
|
// reduce the risk of error in the caller (e.g. invoking X25519 with
|
|
// two private keys, or two public keys), and thus producing the wrong
|
|
// calculation.
|
|
func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) {
|
|
s.checkFinished()
|
|
keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("computing X25519: %w", err)
|
|
}
|
|
|
|
r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil)
|
|
if _, err := io.ReadFull(r, s.ck[:]); err != nil {
|
|
return nil, fmt.Errorf("extracting ck: %w", err)
|
|
}
|
|
var k [chp.KeySize]byte
|
|
if _, err := io.ReadFull(r, k[:]); err != nil {
|
|
return nil, fmt.Errorf("extracting k: %w", err)
|
|
}
|
|
return newSingleUseCHP(k), nil
|
|
}
|
|
|
|
// EncryptAndHash encrypts plaintext into ciphertext (which must be
|
|
// the correct size to hold the encrypted plaintext) using cipher,
|
|
// mixes the ciphertext into s.h, and returns the ciphertext.
|
|
func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) {
|
|
s.checkFinished()
|
|
if len(ciphertext) != len(plaintext)+chp.Overhead {
|
|
panic("ciphertext is wrong size for given plaintext")
|
|
}
|
|
ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:])
|
|
s.MixHash(ret)
|
|
}
|
|
|
|
// DecryptAndHash decrypts the given ciphertext into plaintext (which
|
|
// must be the correct size to hold the decrypted ciphertext) using
|
|
// cipher. If decryption is successful, it mixes the ciphertext into
|
|
// s.h.
|
|
func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error {
|
|
s.checkFinished()
|
|
if len(ciphertext) != len(plaintext)+chp.Overhead {
|
|
return errors.New("plaintext is wrong size for given ciphertext")
|
|
}
|
|
if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil {
|
|
return err
|
|
}
|
|
s.MixHash(ciphertext)
|
|
return nil
|
|
}
|
|
|
|
// Split returns two ChaCha20Poly1305 ciphers with keys derived from
|
|
// the current handshake state. Methods on s cannot be used again
|
|
// after calling Split.
|
|
func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) {
|
|
s.finished = true
|
|
|
|
var k1, k2 [chp.KeySize]byte
|
|
r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil)
|
|
if _, err := io.ReadFull(r, k1[:]); err != nil {
|
|
return nil, nil, fmt.Errorf("extracting k1: %w", err)
|
|
}
|
|
if _, err := io.ReadFull(r, k2[:]); err != nil {
|
|
return nil, nil, fmt.Errorf("extracting k2: %w", err)
|
|
}
|
|
c1, err = chp.New(k1[:])
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err)
|
|
}
|
|
c2, err = chp.New(k2[:])
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err)
|
|
}
|
|
return c1, c2, nil
|
|
}
|
|
|
|
// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on
|
|
// error.
|
|
func newBLAKE2s() hash.Hash {
|
|
h, err := blake2s.New256(nil)
|
|
if err != nil {
|
|
// Should never happen, errors only happen when using BLAKE2s
|
|
// in MAC mode with a key.
|
|
panic(err)
|
|
}
|
|
return h
|
|
}
|
|
|
|
// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or
|
|
// panics on error.
|
|
func newCHP(key [chp.KeySize]byte) cipher.AEAD {
|
|
aead, err := chp.New(key[:])
|
|
if err != nil {
|
|
// Can only happen if we passed a key of the wrong length. The
|
|
// function signature prevents that.
|
|
panic(err)
|
|
}
|
|
return aead
|
|
}
|
|
|
|
// singleUseCHP wraps a ChaCha20Poly1305 AEAD so that it performs exactly
// one operation, a single Seal or a single Open, always with the all-zero
// nonce. Every call after the first panics. Limiting each key to one use
// is what makes the fixed zero nonce safe.
type singleUseCHP struct {
	c cipher.AEAD // set to nil once the single permitted operation is used
}
|
|
|
|
func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP {
|
|
return &singleUseCHP{newCHP(key)}
|
|
}
|
|
|
|
func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte {
|
|
if c.c == nil {
|
|
panic("Attempted reuse of singleUseAEAD")
|
|
}
|
|
cipher := c.c
|
|
c.c = nil
|
|
var nonce [chp.NonceSize]byte
|
|
return cipher.Seal(dst, nonce[:], plaintext, additionalData)
|
|
}
|
|
|
|
func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) {
|
|
if c.c == nil {
|
|
panic("Attempted reuse of singleUseAEAD")
|
|
}
|
|
cipher := c.c
|
|
c.c = nil
|
|
var nonce [chp.NonceSize]byte
|
|
return cipher.Open(dst, nonce[:], ciphertext, additionalData)
|
|
}
|