2023-03-02 22:02:37 +00:00
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// The sniproxy is an outbound SNI proxy. It receives TLS connections over
// Tailscale on one or more TCP ports and sends them out to the same SNI
// hostname & port on the internet. It can optionally forward one or more
// TCP ports to a specific destination. Only TCP traffic is proxied; the
// sniproxy also answers DNS queries on UDP port 53 with its own Tailscale
// addresses.
2023-03-02 22:02:37 +00:00
package main
import (
"context"
2023-08-07 15:51:47 +01:00
"errors"
"expvar"
2023-03-02 22:02:37 +00:00
"flag"
2023-08-07 15:51:47 +01:00
"fmt"
2023-03-02 22:02:37 +00:00
"log"
"net"
2023-03-07 16:46:02 +00:00
"net/http"
2023-08-25 15:27:15 +01:00
"os"
2023-08-07 15:51:47 +01:00
"strconv"
2023-03-02 22:02:37 +00:00
"strings"
"time"
2023-08-25 15:27:15 +01:00
"github.com/peterbourgon/ff/v3"
2023-03-06 00:40:15 +00:00
"golang.org/x/net/dns/dnsmessage"
2023-03-02 22:02:37 +00:00
"inet.af/tcpproxy"
"tailscale.com/client/tailscale"
2023-04-29 05:28:52 +01:00
"tailscale.com/hostinfo"
2023-08-07 15:51:47 +01:00
"tailscale.com/metrics"
2023-03-02 22:02:37 +00:00
"tailscale.com/net/netutil"
"tailscale.com/tsnet"
2023-08-07 15:51:47 +01:00
"tailscale.com/tsweb"
2023-03-05 20:13:36 +00:00
"tailscale.com/types/nettype"
2023-07-29 08:11:19 +01:00
"tailscale.com/util/clientmetric"
2023-03-02 22:02:37 +00:00
)
2023-03-06 00:40:15 +00:00
var (
	// tsMBox is the domain name dnsResponse uses in synthesized answers:
	// the MBox (and NS) field of SOA replies and the target of NS replies.
	tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
)
2023-08-07 15:51:47 +01:00
// portForward is the state for a single port forwarding entry, as passed to
// the --forwards flag.
type portForward struct {
	Port        int    // port listened on locally and dialed at Destination
	Proto       string // always "tcp"; parseForward rejects anything else
	Destination string // host (name or IP literal) to forward connections to
}

// parseForward takes a proto/port/destination tuple such as
// "tcp/22/github.com", one entry of the comma-separated --forwards flag, and
// returns the corresponding *portForward.
//
// Only the "tcp" protocol is supported. The port must be in 1-65535: port 0
// would make the listener pick an arbitrary port and cannot be dialed on the
// destination side. The destination must be non-empty.
func parseForward(value string) (*portForward, error) {
	parts := strings.Split(value, "/")
	if len(parts) != 3 {
		return nil, errors.New("cannot parse: " + value)
	}
	proto, portStr, host := parts[0], parts[1], parts[2]
	if proto != "tcp" {
		return nil, errors.New("unsupported forwarding protocol: " + proto)
	}
	port, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil || port == 0 {
		return nil, errors.New("bad forwarding port: " + portStr)
	}
	if host == "" {
		return nil, errors.New("bad destination: " + value)
	}
	return &portForward{Port: int(port), Proto: proto, Destination: host}, nil
}
2023-07-29 08:11:19 +01:00
2023-03-02 22:02:37 +00:00
func main ( ) {
2023-08-25 15:27:15 +01:00
fs := flag . NewFlagSet ( "sniproxy" , flag . ContinueOnError )
var (
ports = fs . String ( "ports" , "443" , "comma-separated list of ports to proxy" )
forwards = fs . String ( "forwards" , "" , "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com" )
wgPort = fs . Int ( "wg-listen-port" , 0 , "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select" )
promoteHTTPS = fs . Bool ( "promote-https" , true , "promote HTTP to HTTPS" )
debugPort = fs . Int ( "debug-port" , 8080 , "Listening port for debug/metrics endpoint" )
)
err := ff . Parse ( fs , os . Args [ 1 : ] , ff . WithEnvVarPrefix ( "TS_APPC" ) )
if err != nil {
log . Fatal ( "ff.Parse" )
}
2023-03-02 22:02:37 +00:00
if * ports == "" {
log . Fatal ( "no ports" )
}
2023-04-29 05:28:52 +01:00
hostinfo . SetApp ( "sniproxy" )
2023-03-02 22:02:37 +00:00
var s server
2023-06-26 21:57:44 +01:00
s . ts . Port = uint16 ( * wgPort )
2023-03-02 22:02:37 +00:00
defer s . ts . Close ( )
lc , err := s . ts . LocalClient ( )
if err != nil {
log . Fatal ( err )
}
s . lc = lc
2023-08-07 15:51:47 +01:00
s . initMetrics ( )
2023-03-02 22:02:37 +00:00
for _ , portStr := range strings . Split ( * ports , "," ) {
ln , err := s . ts . Listen ( "tcp" , ":" + portStr )
if err != nil {
log . Fatal ( err )
}
log . Printf ( "Serving on port %v ..." , portStr )
go s . serve ( ln )
}
2023-03-05 20:13:36 +00:00
2023-08-07 15:51:47 +01:00
for _ , forwStr := range strings . Split ( * forwards , "," ) {
if forwStr == "" {
continue
}
forw , err := parseForward ( forwStr )
if err != nil {
log . Fatal ( err )
}
ln , err := s . ts . Listen ( "tcp" , ":" + strconv . Itoa ( forw . Port ) )
if err != nil {
log . Fatal ( err )
}
log . Printf ( "Serving on port %d to %s..." , forw . Port , forw . Destination )
// Add an entry to the expvar LabelMap for Prometheus metrics,
// and create a clientmetric to report that same value.
service := portNumberToName ( forw )
s . numTCPsessions . SetInt64 ( service , 0 )
metric := fmt . Sprintf ( "sniproxy_tcp_sessions_%s" , service )
clientmetric . NewCounterFunc ( metric , func ( ) int64 {
return s . numTCPsessions . Get ( service ) . Value ( )
} )
go s . forward ( ln , forw )
}
2023-03-05 20:13:36 +00:00
ln , err := s . ts . Listen ( "udp" , ":53" )
if err != nil {
log . Fatal ( err )
}
go s . serveDNS ( ln )
2023-03-07 16:46:02 +00:00
if * promoteHTTPS {
ln , err := s . ts . Listen ( "tcp" , ":80" )
if err != nil {
log . Fatal ( err )
}
log . Printf ( "Promoting HTTP to HTTPS ..." )
go s . promoteHTTPS ( ln )
}
2023-08-07 15:51:47 +01:00
if * debugPort != 0 {
mux := http . NewServeMux ( )
tsweb . Debugger ( mux )
dln , err := s . ts . Listen ( "tcp" , fmt . Sprintf ( ":%d" , * debugPort ) )
if err != nil {
log . Fatal ( err )
}
go func ( ) {
log . Fatal ( http . Serve ( dln , mux ) )
} ( )
}
2023-03-02 22:02:37 +00:00
select { }
}
type server struct {
ts tsnet . Server
lc * tailscale . LocalClient
2023-08-07 15:51:47 +01:00
numTLSsessions expvar . Int
numTCPsessions * metrics . LabelMap
numBadAddrPort expvar . Int
dnsResponses expvar . Int
dnsFailures expvar . Int
httpPromoted expvar . Int
2023-03-02 22:02:37 +00:00
}
// serve accepts connections from ln forever, handing each to serveConn on
// its own goroutine. An Accept error terminates the process.
func (s *server) serve(ln net.Listener) {
	for {
		conn, err := ln.Accept()
		if err == nil {
			go s.serveConn(conn)
			continue
		}
		log.Fatal(err)
	}
}
2023-08-07 15:51:47 +01:00
// forward accepts connections from ln forever and passes each, together with
// the forwarding entry forw the listener was opened for, to forwardConn on
// its own goroutine. An Accept error terminates the process.
func (s *server) forward(ln net.Listener, forw *portForward) {
	for {
		conn, err := ln.Accept()
		if err == nil {
			go s.forwardConn(conn, forw)
			continue
		}
		log.Fatal(err)
	}
}
2023-03-05 20:13:36 +00:00
// serveDNS accepts conns from the DNS listener ln forever and answers each on
// its own goroutine via serveDNSConn. Every accepted conn must be a
// nettype.ConnPacketConn; the type assertion panics otherwise. An Accept
// error terminates the process.
func (s *server) serveDNS(ln net.Listener) {
	for {
		conn, err := ln.Accept()
		if err == nil {
			pc := conn.(nettype.ConnPacketConn)
			go s.serveDNSConn(pc)
			continue
		}
		log.Fatal(err)
	}
}
func ( s * server ) serveDNSConn ( c nettype . ConnPacketConn ) {
defer c . Close ( )
c . SetReadDeadline ( time . Now ( ) . Add ( 5 * time . Second ) )
buf := make ( [ ] byte , 1500 )
n , err := c . Read ( buf )
2023-03-06 00:40:15 +00:00
if err != nil {
log . Printf ( "c.Read failed: %v\n " , err )
2023-08-07 15:51:47 +01:00
s . dnsFailures . Add ( 1 )
2023-03-06 00:40:15 +00:00
return
}
var msg dnsmessage . Message
err = msg . Unpack ( buf [ : n ] )
if err != nil {
log . Printf ( "dnsmessage unpack failed: %v\n " , err )
2023-08-07 15:51:47 +01:00
s . dnsFailures . Add ( 1 )
2023-03-06 00:40:15 +00:00
return
}
buf , err = s . dnsResponse ( & msg )
if err != nil {
log . Printf ( "s.dnsResponse failed: %v\n" , err )
2023-08-07 15:51:47 +01:00
s . dnsFailures . Add ( 1 )
2023-03-06 00:40:15 +00:00
return
}
_ , err = c . Write ( buf )
if err != nil {
log . Printf ( "c.Write failed: %v\n" , err )
2023-08-07 15:51:47 +01:00
s . dnsFailures . Add ( 1 )
2023-03-06 00:40:15 +00:00
return
}
2023-07-29 08:11:19 +01:00
2023-08-07 15:51:47 +01:00
s . dnsResponses . Add ( 1 )
2023-03-05 20:13:36 +00:00
}
2023-03-02 22:02:37 +00:00
func ( s * server ) serveConn ( c net . Conn ) {
addrPortStr := c . LocalAddr ( ) . String ( )
_ , port , err := net . SplitHostPort ( addrPortStr )
if err != nil {
log . Printf ( "bogus addrPort %q" , addrPortStr )
2023-08-07 15:51:47 +01:00
s . numBadAddrPort . Add ( 1 )
2023-03-02 22:02:37 +00:00
c . Close ( )
return
}
var dialer net . Dialer
dialer . Timeout = 5 * time . Second
var p tcpproxy . Proxy
p . ListenFunc = func ( net , laddr string ) ( net . Listener , error ) {
return netutil . NewOneConnListener ( c , nil ) , nil
}
p . AddSNIRouteFunc ( addrPortStr , func ( ctx context . Context , sniName string ) ( t tcpproxy . Target , ok bool ) {
2023-08-07 15:51:47 +01:00
s . numTLSsessions . Add ( 1 )
2023-03-02 22:02:37 +00:00
return & tcpproxy . DialProxy {
Addr : net . JoinHostPort ( sniName , port ) ,
DialContext : dialer . DialContext ,
} , true
} )
p . Start ( )
}
2023-03-06 00:40:15 +00:00
2023-08-07 15:51:47 +01:00
// portNumberToName returns the metric label for a forwarded port. A few
// commonly forwarded ports get a service name (22 is "ssh", 5432 is
// "postgres", and so on). Any other port gets the protocol followed by the
// port number, e.g. "tcp8080".
func portNumberToName(forw *portForward) string {
	name := ""
	switch forw.Port {
	case 22:
		name = "ssh"
	case 1433:
		name = "sqlserver"
	case 3306:
		name = "mysql"
	case 3389:
		name = "rdp"
	case 5432:
		name = "postgres"
	}
	if name != "" {
		return name
	}
	return fmt.Sprintf("%s%d", forw.Proto, forw.Port)
}
// forwardConn proxies the single TCP connection c to forw.Destination on
// forw.Port. Unlike the SNI proxying in serveConn, it does not inspect the
// data; it forwards everything as configured by an entry such as
// --forwards=tcp/22/github.com. The dial to the destination times out after
// 30 seconds. Each call counts one session in s.numTCPsessions under the
// label portNumberToName(forw).
func (s *server) forwardConn(c net.Conn, forw *portForward) {
	addrPortStr := c.LocalAddr().String()
	var dialer net.Dialer
	dialer.Timeout = 30 * time.Second
	var p tcpproxy.Proxy
	// Hand the already-accepted conn to the proxy as a one-conn listener.
	p.ListenFunc = func(network, laddr string) (net.Listener, error) {
		return netutil.NewOneConnListener(c, nil), nil
	}
	dial := &tcpproxy.DialProxy{
		// JoinHostPort brackets IPv6 literals ("[::1]:22"); "%s:%d"
		// formatting produced an undialable "::1:22" for them.
		Addr:        net.JoinHostPort(forw.Destination, strconv.Itoa(forw.Port)),
		DialContext: dialer.DialContext,
	}
	p.AddRoute(addrPortStr, dial)
	s.numTCPsessions.Add(portNumberToName(forw), 1)
	p.Start()
}
2023-03-06 00:40:15 +00:00
// dnsResponse builds an authoritative reply to req that resolves every name
// to this node's own Tailscale addresses, and returns the packed message.
// Only the first question is answered:
//   - A and AAAA get the node's IPv4/IPv6 address with a 120s TTL, or an
//     empty answer section if the node has no address of that family;
//   - SOA and NS get synthetic records that reference tsMBox;
//   - any other type gets an empty answer section.
//
// A request with no questions gets a reply that carries only the header.
func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) {
	resp := dnsmessage.NewBuilder(buf,
		dnsmessage.Header{
			ID:            req.Header.ID,
			Response:      true,
			Authoritative: true,
		})
	resp.EnableCompression()

	if len(req.Questions) == 0 {
		return resp.Finish()
	}

	q := req.Questions[0]
	if err = resp.StartQuestions(); err != nil {
		return nil, err
	}
	// Echo the question; Question reports errors like every other Builder call.
	if err = resp.Question(q); err != nil {
		return nil, err
	}

	ip4, ip6 := s.ts.TailscaleIPs()
	if err = resp.StartAnswers(); err != nil {
		return nil, err
	}

	// Passed by value, so each Resource call works on its own copy.
	hdr := dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}
	switch q.Type {
	case dnsmessage.TypeAAAA:
		// TailscaleIPs returns the zero Addr when the node has no IPv6
		// address; don't answer with a bogus "::".
		if ip6.IsValid() {
			err = resp.AAAAResource(hdr, dnsmessage.AAAAResource{AAAA: ip6.As16()})
		}
	case dnsmessage.TypeA:
		// netip.Addr.As4 panics on the zero Addr, which TailscaleIPs
		// returns before the node has an IPv4 address.
		if ip4.IsValid() {
			err = resp.AResource(hdr, dnsmessage.AResource{A: ip4.As4()})
		}
	case dnsmessage.TypeSOA:
		err = resp.SOAResource(hdr, dnsmessage.SOAResource{
			NS: q.Name, MBox: tsMBox, Serial: 2023030600,
			Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60,
		})
	case dnsmessage.TypeNS:
		err = resp.NSResource(hdr, dnsmessage.NSResource{NS: tsMBox})
	}
	if err != nil {
		return nil, err
	}
	return resp.Finish()
}
2023-03-07 16:46:02 +00:00
func ( s * server ) promoteHTTPS ( ln net . Listener ) {
err := http . Serve ( ln , http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
2023-08-07 15:51:47 +01:00
s . httpPromoted . Add ( 1 )
2023-03-07 16:46:02 +00:00
http . Redirect ( w , r , "https://" + r . Host + r . RequestURI , http . StatusFound )
} ) )
log . Fatalf ( "promoteHTTPS http.Serve: %v" , err )
}
2023-08-07 15:51:47 +01:00
// initMetrics registers the server's counters in a metrics.Set published
// through expvar under the name "sniproxy". Each plain counter is also
// mirrored as a clientmetric named "sniproxy_<key>".
//
// numTCPsessions is a LabelMap (label "proto") keyed by service name.
// clientmetric has no map type, so main creates one clientmetric per
// --forwards entry instead of registering one here.
func (s *server) initMetrics() {
	stats := new(metrics.Set)

	counters := []struct {
		key    string // key inside the "sniproxy" expvar map
		metric string // clientmetric name
		v      *expvar.Int
	}{
		{"tls_sessions", "sniproxy_tls_sessions", &s.numTLSsessions},
		{"bad_addrport", "sniproxy_bad_addrport", &s.numBadAddrPort},
		{"dns_responses", "sniproxy_dns_responses", &s.dnsResponses},
		{"dns_failed", "sniproxy_dns_failed", &s.dnsFailures},
		{"http_promoted", "sniproxy_http_promoted", &s.httpPromoted},
	}
	for _, c := range counters {
		stats.Set(c.key, c.v)
		clientmetric.NewCounterFunc(c.metric, c.v.Value)
	}

	s.numTCPsessions = &metrics.LabelMap{Label: "proto"}
	stats.Set("tcp_sessions", s.numTCPsessions)

	expvar.Publish("sniproxy", stats)
}