363 lines
10 KiB
Go
363 lines
10 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
// The pgproxy server is a proxy for the Postgres wire protocol.
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
crand "crypto/rand"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"expvar"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"tailscale.com/client/tailscale"
|
|
"tailscale.com/metrics"
|
|
"tailscale.com/tsnet"
|
|
"tailscale.com/tsweb"
|
|
)
|
|
|
|
// Command-line flags. --hostname, --upstream-addr, --upstream-ca-file and
// --state-dir are required: main exits if any of them is left empty.
var (
	// Tailnet-facing listeners. A --debug-port of 0 disables the
	// debug/metrics listener.
	hostname  = flag.String("hostname", "", "Tailscale hostname to serve on")
	port      = flag.Int("port", 5432, "Listening port for client connections")
	debugPort = flag.Int("debug-port", 80, "Listening port for debug/metrics endpoint")

	// The upstream Postgres server and the CA used to verify it.
	upstreamAddr = flag.String("upstream-addr", "", "Address of the upstream Postgres server, in host:port format")
	upstreamCA   = flag.String("upstream-ca-file", "", "File containing the PEM-encoded CA certificate for the upstream server")

	// Persistent tsnet state.
	tailscaleDir = flag.String("state-dir", "", "Directory in which to store the Tailscale auth state")
)
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
if *hostname == "" {
|
|
log.Fatal("missing --hostname")
|
|
}
|
|
if *upstreamAddr == "" {
|
|
log.Fatal("missing --upstream-addr")
|
|
}
|
|
if *upstreamCA == "" {
|
|
log.Fatal("missing --upstream-ca-file")
|
|
}
|
|
if *tailscaleDir == "" {
|
|
log.Fatal("missing --state-dir")
|
|
}
|
|
|
|
ts := &tsnet.Server{
|
|
Dir: *tailscaleDir,
|
|
Hostname: *hostname,
|
|
}
|
|
|
|
if os.Getenv("TS_AUTHKEY") == "" {
|
|
log.Print("Note: you need to run this with TS_AUTHKEY=... the first time, to join your tailnet of choice.")
|
|
}
|
|
|
|
tsclient, err := ts.LocalClient()
|
|
if err != nil {
|
|
log.Fatalf("getting tsnet API client: %v", err)
|
|
}
|
|
|
|
p, err := newProxy(*upstreamAddr, *upstreamCA, tsclient)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
expvar.Publish("pgproxy", p.Expvar())
|
|
|
|
if *debugPort != 0 {
|
|
mux := http.NewServeMux()
|
|
tsweb.Debugger(mux)
|
|
srv := &http.Server{
|
|
Handler: mux,
|
|
}
|
|
dln, err := ts.Listen("tcp", fmt.Sprintf(":%d", *debugPort))
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
go func() {
|
|
log.Fatal(srv.Serve(dln))
|
|
}()
|
|
}
|
|
|
|
ln, err := ts.Listen("tcp", fmt.Sprintf(":%d", *port))
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
log.Printf("serving access to %s on port %d", *upstreamAddr, *port)
|
|
log.Fatal(p.Serve(ln))
|
|
}
|
|
|
|
// proxy is a postgres wire protocol proxy, which strictly enforces
|
|
// the security of the TLS connection to its upstream regardless of
|
|
// what the client's TLS configuration is.
|
|
type proxy struct {
|
|
upstreamAddr string // "my.database.com:5432"
|
|
upstreamHost string // "my.database.com"
|
|
upstreamCertPool *x509.CertPool
|
|
downstreamCert []tls.Certificate
|
|
client *tailscale.LocalClient
|
|
|
|
activeSessions expvar.Int
|
|
startedSessions expvar.Int
|
|
errors metrics.LabelMap
|
|
}
|
|
|
|
// newProxy returns a proxy that forwards connections to
|
|
// upstreamAddr. The upstream's TLS session is verified using the CA
|
|
// cert(s) in upstreamCAPath.
|
|
func newProxy(upstreamAddr, upstreamCAPath string, client *tailscale.LocalClient) (*proxy, error) {
|
|
bs, err := os.ReadFile(upstreamCAPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
upstreamCertPool := x509.NewCertPool()
|
|
if !upstreamCertPool.AppendCertsFromPEM(bs) {
|
|
return nil, fmt.Errorf("invalid CA cert in %q", upstreamCAPath)
|
|
}
|
|
|
|
h, _, err := net.SplitHostPort(upstreamAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
downstreamCert, err := mkSelfSigned(h)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &proxy{
|
|
upstreamAddr: upstreamAddr,
|
|
upstreamHost: h,
|
|
upstreamCertPool: upstreamCertPool,
|
|
downstreamCert: []tls.Certificate{downstreamCert},
|
|
client: client,
|
|
errors: metrics.LabelMap{Label: "kind"},
|
|
}, nil
|
|
}
|
|
|
|
// Expvar returns p's monitoring metrics.
|
|
func (p *proxy) Expvar() expvar.Var {
|
|
ret := &metrics.Set{}
|
|
ret.Set("sessions_active", &p.activeSessions)
|
|
ret.Set("sessions_started", &p.startedSessions)
|
|
ret.Set("session_errors", &p.errors)
|
|
return ret
|
|
}
|
|
|
|
// Serve accepts postgres client connections on ln and proxies them to
|
|
// the configured upstream. ln can be any net.Listener, but all client
|
|
// connections must originate from tailscale IPs that can be verified
|
|
// with WhoIs.
|
|
func (p *proxy) Serve(ln net.Listener) error {
|
|
var lastSessionID int64
|
|
for {
|
|
c, err := ln.Accept()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
id := time.Now().UnixNano()
|
|
if id == lastSessionID {
|
|
// Bluntly enforce SID uniqueness, even if collisions are
|
|
// fantastically unlikely (but OSes vary in how much timer
|
|
// precision they expose to the OS, so id might be rounded
|
|
// e.g. to the same millisecond)
|
|
id++
|
|
}
|
|
lastSessionID = id
|
|
go func(sessionID int64) {
|
|
if err := p.serve(sessionID, c); err != nil {
|
|
log.Printf("%d: session ended with error: %v", sessionID, err)
|
|
}
|
|
}(id)
|
|
}
|
|
}
|
|
|
|
// Fixed 8-byte prefixes of the first message a postgres client sends.
// Each is a big-endian int32 message length followed by a big-endian
// int32 request code or protocol version.
var (
	// sslStart is a complete SSLRequest message: length 8, request code
	// 80877103. A server willing to speak TLS replies with the single
	// byte 'S', after which an ordinary TLS handshake starts.
	sslStart = [8]byte{
		0x00, 0x00, 0x00, 0x08, // length: 8
		0x04, 0xd2, 0x16, 0x2f, // SSLRequest code 80877103
	}

	// plaintextStart is the header of a StartupMessage for protocol
	// version 3.0, which begins a plaintext authentication handshake.
	// The length recorded here is 86. A real StartupMessage's length
	// depends on the parameters the client sends.
	plaintextStart = [8]byte{
		0x00, 0x00, 0x00, 0x56, // length: 86
		0x00, 0x03, 0x00, 0x00, // protocol version 3.0
	}
)
|
|
|
|
// serve proxies the postgres client on c to the proxy's upstream,
|
|
// enforcing strict TLS to the upstream.
|
|
func (p *proxy) serve(sessionID int64, c net.Conn) error {
|
|
defer c.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
whois, err := p.client.WhoIs(ctx, c.RemoteAddr().String())
|
|
if err != nil {
|
|
p.errors.Add("whois-failed", 1)
|
|
return fmt.Errorf("getting client identity: %v", err)
|
|
}
|
|
|
|
// Before anything else, log the connection attempt.
|
|
user, machine := "", ""
|
|
if whois.Node != nil {
|
|
if whois.Node.Hostinfo.ShareeNode() {
|
|
machine = "external-device"
|
|
} else {
|
|
machine = strings.TrimSuffix(whois.Node.Name, ".")
|
|
}
|
|
}
|
|
if whois.UserProfile != nil {
|
|
user = whois.UserProfile.LoginName
|
|
if user == "tagged-devices" && whois.Node != nil {
|
|
user = strings.Join(whois.Node.Tags, ",")
|
|
}
|
|
}
|
|
if user == "" || machine == "" {
|
|
p.errors.Add("no-ts-identity", 1)
|
|
return fmt.Errorf("couldn't identify source user and machine (user %q, machine %q)", user, machine)
|
|
}
|
|
log.Printf("%d: session start, from %s (machine %s, user %s)", sessionID, c.RemoteAddr(), machine, user)
|
|
start := time.Now()
|
|
defer func() {
|
|
elapsed := time.Since(start)
|
|
log.Printf("%d: session end, from %s (machine %s, user %s), lasted %s", sessionID, c.RemoteAddr(), machine, user, elapsed.Round(time.Millisecond))
|
|
}()
|
|
|
|
// Read the client's opening message, to figure out if it's trying
|
|
// to TLS or not.
|
|
var buf [8]byte
|
|
if _, err := io.ReadFull(c, buf[:len(sslStart)]); err != nil {
|
|
p.errors.Add("network-error", 1)
|
|
return fmt.Errorf("initial magic read: %v", err)
|
|
}
|
|
var clientIsTLS bool
|
|
switch {
|
|
case buf == sslStart:
|
|
clientIsTLS = true
|
|
case buf == plaintextStart:
|
|
clientIsTLS = false
|
|
default:
|
|
p.errors.Add("client-bad-protocol", 1)
|
|
return fmt.Errorf("unrecognized initial packet = % 02x", buf)
|
|
}
|
|
|
|
// Dial & verify upstream connection.
|
|
var d net.Dialer
|
|
d.Timeout = 10 * time.Second
|
|
upc, err := d.Dial("tcp", p.upstreamAddr)
|
|
if err != nil {
|
|
p.errors.Add("network-error", 1)
|
|
return fmt.Errorf("upstream dial: %v", err)
|
|
}
|
|
defer upc.Close()
|
|
if _, err := upc.Write(sslStart[:]); err != nil {
|
|
p.errors.Add("network-error", 1)
|
|
return fmt.Errorf("upstream write of start-ssl magic: %v", err)
|
|
}
|
|
if _, err := io.ReadFull(upc, buf[:1]); err != nil {
|
|
p.errors.Add("network-error", 1)
|
|
return fmt.Errorf("reading upstream start-ssl response: %v", err)
|
|
}
|
|
if buf[0] != 'S' {
|
|
p.errors.Add("upstream-bad-protocol", 1)
|
|
return fmt.Errorf("upstream didn't acknowledge start-ssl, said %q", buf[0])
|
|
}
|
|
tlsConf := &tls.Config{
|
|
ServerName: p.upstreamHost,
|
|
RootCAs: p.upstreamCertPool,
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
uptc := tls.Client(upc, tlsConf)
|
|
if err = uptc.HandshakeContext(ctx); err != nil {
|
|
p.errors.Add("upstream-tls", 1)
|
|
return fmt.Errorf("upstream TLS handshake: %v", err)
|
|
}
|
|
|
|
// Accept the client conn and set it up the way the client wants.
|
|
var clientConn net.Conn
|
|
if clientIsTLS {
|
|
io.WriteString(c, "S") // yeah, we're good to speak TLS
|
|
s := tls.Server(c, &tls.Config{
|
|
ServerName: p.upstreamHost,
|
|
Certificates: p.downstreamCert,
|
|
MinVersion: tls.VersionTLS12,
|
|
})
|
|
if err = uptc.HandshakeContext(ctx); err != nil {
|
|
p.errors.Add("client-tls", 1)
|
|
return fmt.Errorf("client TLS handshake: %v", err)
|
|
}
|
|
clientConn = s
|
|
} else {
|
|
// Repeat the header we read earlier up to the server.
|
|
if _, err := uptc.Write(plaintextStart[:]); err != nil {
|
|
p.errors.Add("network-error", 1)
|
|
return fmt.Errorf("sending initial client bytes to upstream: %v", err)
|
|
}
|
|
clientConn = c
|
|
}
|
|
|
|
// Finally, proxy the client to the upstream.
|
|
errc := make(chan error, 1)
|
|
go func() {
|
|
_, err := io.Copy(uptc, clientConn)
|
|
errc <- err
|
|
}()
|
|
go func() {
|
|
_, err := io.Copy(clientConn, uptc)
|
|
errc <- err
|
|
}()
|
|
if err := <-errc; err != nil {
|
|
// Don't increment error counts here, because the most common
|
|
// cause of termination is client or server closing the
|
|
// connection normally, and it'll obscure "interesting"
|
|
// handshake errors.
|
|
return fmt.Errorf("session terminated with error: %v", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// mkSelfSigned creates and returns a self-signed TLS server certificate
// for hostname. It is backed by a freshly generated ECDSA P-256 key and
// is valid for ten years. If hostname is an IP address literal it goes
// into the certificate's IP SANs, otherwise into its DNS SANs. The
// returned tls.Certificate has Leaf populated.
func mkSelfSigned(hostname string) (tls.Certificate, error) {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
	if err != nil {
		return tls.Certificate{}, err
	}
	pub := priv.Public()

	// Serial numbers must be unique per issuer (RFC 5280). Every
	// certificate made here carries the same issuer name, and some TLS
	// clients reject certificates that reuse an issuer/serial pair, so
	// use a random 128-bit serial rather than a constant.
	serial, err := crand.Int(crand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return tls.Certificate{}, err
	}

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Organization: []string{"pgproxy"},
		},
		// Backdated slightly so that clients whose clocks lag ours
		// don't see a certificate that is not yet valid.
		NotBefore:             now.Add(-time.Hour),
		NotAfter:              now.Add(10 * 365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}
	// An IP literal is not a valid DNS SAN; verifiers match IP addresses
	// against the IP SANs only.
	if ip := net.ParseIP(hostname); ip != nil {
		template.IPAddresses = []net.IP{ip}
	} else {
		template.DNSNames = []string{hostname}
	}

	derBytes, err := x509.CreateCertificate(crand.Reader, &template, &template, pub, priv)
	if err != nil {
		return tls.Certificate{}, err
	}
	cert, err := x509.ParseCertificate(derBytes)
	if err != nil {
		return tls.Certificate{}, err
	}

	return tls.Certificate{
		Certificate: [][]byte{derBytes},
		PrivateKey:  priv,
		Leaf:        cert,
	}, nil
}
|