Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2024-03-05 09:28:24 -08:00
parent d756622432
commit 707acd18d0
1 changed files with 107 additions and 0 deletions

107
cmd/tailproxy/proxy.go Normal file
View File

@ -0,0 +1,107 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"net"
"slices"
"strings"
"sync"
"github.com/inetaf/tcpproxy"
"tailscale.com/client/tailscale"
"tailscale.com/net/netutil"
"tailscale.com/tailcfg"
"tailscale.com/tsnet"
"tailscale.com/types/logger"
"tailscale.com/util/dnsname"
"tailscale.com/util/must"
)
type proxyGrantRule struct {
AllowedHosts []dnsname.FQDN
}
func handleConn(ctx context.Context, c net.Conn, lc *tailscale.LocalClient, dialCtx func(context.Context, string, string) (net.Conn, error)) {
addrPortStr := c.LocalAddr().String()
_, port, err := net.SplitHostPort(addrPortStr)
if err != nil {
log.Printf("tcpSNIHandler.Handle: bogus addrPort %q", addrPortStr)
c.Close()
return
}
who, err := lc.WhoIs(ctx, c.RemoteAddr().String())
if err != nil {
c.Close()
log.Printf("tcpSNIHandler.Handle: WhoIs: %v", err)
return
}
rules, err := tailcfg.UnmarshalCapJSON[proxyGrantRule](who.CapMap, "maisem.com/tailproxy")
if err != nil {
c.Close()
log.Printf("tcpSNIHandler.Handle: UnmarshalCapJSON: %v", err)
return
}
var p tcpproxy.Proxy
p.ListenFunc = func(net, laddr string) (net.Listener, error) {
return netutil.NewOneConnListener(c, nil), nil
}
p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) {
sniFQDN, err := dnsname.ToFQDN(sniName)
if err != nil {
log.Printf("tcpSNIHandler.Handle: ToFQDN: %v", err)
return nil, false
}
for _, rule := range rules {
if slices.ContainsFunc(rule.AllowedHosts, func(fqdn dnsname.FQDN) bool {
return fqdn == "*" || fqdn.Contains(sniFQDN)
}) {
log.Printf("tcpSNIHandler.Handle: %s is allowed", sniName)
return &tcpproxy.DialProxy{
Addr: net.JoinHostPort(sniName, port),
DialContext: dialCtx,
}, true
}
}
log.Printf("tcpSNIHandler.Handle: %s is not allowed", sniName)
return nil, false
})
p.Start()
}
func main() {
var (
ports = flag.String("ports", "443", "comma-separated list of ports to proxy")
hostname = flag.String("hostname", "", "Hostname to register the service under")
)
flag.Parse()
ctx := context.Background()
s := &tsnet.Server{
Hostname: *hostname,
Logf: logger.Discard,
}
must.Get(s.Up(ctx))
var wg sync.WaitGroup
log.Printf("Listening on ports: %s", *ports)
for _, port := range strings.Split(*ports, ",") {
wg.Add(1)
ln := must.Get(s.Listen("tcp", ":"+port))
lc := must.Get(s.LocalClient())
go func() {
defer wg.Done()
for {
c, err := ln.Accept()
if err != nil {
continue
}
fmt.Println("Accepted connection")
go handleConn(ctx, c, lc, s.Dial)
}
}()
}
wg.Wait()
}