parent
d756622432
commit
707acd18d0
|
@ -0,0 +1,107 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/inetaf/tcpproxy"
|
||||
"tailscale.com/client/tailscale"
|
||||
"tailscale.com/net/netutil"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tsnet"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/util/must"
|
||||
)
|
||||
|
||||
type proxyGrantRule struct {
|
||||
AllowedHosts []dnsname.FQDN
|
||||
}
|
||||
|
||||
func handleConn(ctx context.Context, c net.Conn, lc *tailscale.LocalClient, dialCtx func(context.Context, string, string) (net.Conn, error)) {
|
||||
addrPortStr := c.LocalAddr().String()
|
||||
_, port, err := net.SplitHostPort(addrPortStr)
|
||||
if err != nil {
|
||||
log.Printf("tcpSNIHandler.Handle: bogus addrPort %q", addrPortStr)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
who, err := lc.WhoIs(ctx, c.RemoteAddr().String())
|
||||
if err != nil {
|
||||
c.Close()
|
||||
log.Printf("tcpSNIHandler.Handle: WhoIs: %v", err)
|
||||
return
|
||||
}
|
||||
rules, err := tailcfg.UnmarshalCapJSON[proxyGrantRule](who.CapMap, "maisem.com/tailproxy")
|
||||
if err != nil {
|
||||
c.Close()
|
||||
log.Printf("tcpSNIHandler.Handle: UnmarshalCapJSON: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var p tcpproxy.Proxy
|
||||
p.ListenFunc = func(net, laddr string) (net.Listener, error) {
|
||||
return netutil.NewOneConnListener(c, nil), nil
|
||||
}
|
||||
p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) {
|
||||
sniFQDN, err := dnsname.ToFQDN(sniName)
|
||||
if err != nil {
|
||||
log.Printf("tcpSNIHandler.Handle: ToFQDN: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if slices.ContainsFunc(rule.AllowedHosts, func(fqdn dnsname.FQDN) bool {
|
||||
return fqdn == "*" || fqdn.Contains(sniFQDN)
|
||||
}) {
|
||||
log.Printf("tcpSNIHandler.Handle: %s is allowed", sniName)
|
||||
return &tcpproxy.DialProxy{
|
||||
Addr: net.JoinHostPort(sniName, port),
|
||||
DialContext: dialCtx,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
log.Printf("tcpSNIHandler.Handle: %s is not allowed", sniName)
|
||||
return nil, false
|
||||
})
|
||||
p.Start()
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
ports = flag.String("ports", "443", "comma-separated list of ports to proxy")
|
||||
hostname = flag.String("hostname", "", "Hostname to register the service under")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
ctx := context.Background()
|
||||
s := &tsnet.Server{
|
||||
Hostname: *hostname,
|
||||
Logf: logger.Discard,
|
||||
}
|
||||
must.Get(s.Up(ctx))
|
||||
var wg sync.WaitGroup
|
||||
log.Printf("Listening on ports: %s", *ports)
|
||||
for _, port := range strings.Split(*ports, ",") {
|
||||
wg.Add(1)
|
||||
ln := must.Get(s.Listen("tcp", ":"+port))
|
||||
lc := must.Get(s.LocalClient())
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fmt.Println("Accepted connection")
|
||||
go handleConn(ctx, c, lc, s.Dial)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
Loading…
Reference in New Issue