diff --git a/cmd/sniproxy/snipproxy.go b/cmd/sniproxy/snipproxy.go index 0ef212678..d465c690a 100644 --- a/cmd/sniproxy/snipproxy.go +++ b/cmd/sniproxy/snipproxy.go @@ -18,6 +18,7 @@ import ( "tailscale.com/client/tailscale" "tailscale.com/net/netutil" "tailscale.com/tsnet" + "tailscale.com/types/nettype" ) var ports = flag.String("ports", "443", "comma-separated list of ports to proxy") @@ -45,6 +46,13 @@ func main() { log.Printf("Serving on port %v ...", portStr) go s.serve(ln) } + + ln, err := s.ts.Listen("udp", ":53") + if err != nil { + log.Fatal(err) + } + go s.serveDNS(ln) + select {} } @@ -63,6 +71,25 @@ func (s *server) serve(ln net.Listener) { } } +func (s *server) serveDNS(ln net.Listener) { + for { + c, err := ln.Accept() + if err != nil { + log.Fatal(err) + } + go s.serveDNSConn(c.(nettype.ConnPacketConn)) + } +} + +func (s *server) serveDNSConn(c nettype.ConnPacketConn) { + defer c.Close() + c.SetReadDeadline(time.Now().Add(5 * time.Second)) + buf := make([]byte, 1500) + n, err := c.Read(buf) + log.Printf("got DNS packet: %q, %v", buf[:n], err) + // TODO: rest of the owl +} + func (s *server) serveConn(c net.Conn) { addrPortStr := c.LocalAddr().String() _, port, err := net.SplitHostPort(addrPortStr) diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 96238e6a9..80746e053 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -564,27 +564,73 @@ func (s *Server) printAuthURLLoop() { } } -func (s *Server) forwardTCP(c net.Conn, port uint16) { +// networkForFamily returns one of "tcp4", "tcp6", "udp4", or "udp6". +// +// netBase is "tcp" or "udp" (without any '4' or '6' suffix). +func networkForFamily(netBase string, is6 bool) string { + switch netBase { + case "tcp": + if is6 { + return "tcp6" + } + return "tcp4" + case "udp": + if is6 { + return "udp6" + } + return "udp4" + } + panic("unexpected") +} + +// listenerForDstAddr returns a listener for the provided network and +// destination IP/port. It matches from most specific to least specific. +// For example: +// +// - ("tcp4", IP, port) +// - ("tcp", IP, port) +// - ("tcp4", "", port) +// - ("tcp", "", port) +// +// The netBase is "tcp" or "udp" (without any '4' or '6' suffix). +func (s *Server) listenerForDstAddr(netBase string, dst netip.AddrPort) (_ *listener, ok bool) { s.mu.Lock() - ln, ok := s.listeners[listenKey{"tcp", "", port}] - s.mu.Unlock() + defer s.mu.Unlock() + for _, a := range [2]netip.Addr{0: dst.Addr()} { + for _, net := range [2]string{ + networkForFamily(netBase, dst.Addr().Is6()), + netBase, + } { + if ln, ok := s.listeners[listenKey{net, a, dst.Port()}]; ok { + return ln, true + } + } + } + return nil, false +} + +func (s *Server) forwardTCP(c net.Conn, port uint16) { + dstStr := c.LocalAddr().String() + ap, err := netip.ParseAddrPort(dstStr) + if err != nil { + s.logf("unexpected dst addr %q", dstStr) + c.Close() + return + } + ln, ok := s.listenerForDstAddr("tcp", ap) if !ok { c.Close() return } - t := time.NewTimer(time.Second) - defer t.Stop() - select { - case ln.conn <- c: - case <-t.C: - c.Close() - } + ln.handle(c) } func (s *Server) getUDPHandlerForFlow(src, dst netip.AddrPort) (handler func(nettype.ConnPacketConn), intercept bool) { - s.logf("rejecting incoming UDP flow: (%v, %v)", src, dst) - // TODO(bradfitz): hook up to Listen("udp", dst) so users of tsnet can hook into this. - return nil, true + ln, ok := s.listenerForDstAddr("udp", dst) + if !ok { + return nil, true // don't handle, don't forward to localhost + } + return func(c nettype.ConnPacketConn) { ln.handle(c) }, true } // getTSNetDir usually just returns filepath.Join(confDir, "tsnet-"+prog) @@ -650,7 +696,7 @@ func (s *Server) APIClient() (*tailscale.Client, error) { // It will start the server if it has not been started yet. func (s *Server) Listen(network, addr string) (net.Listener, error) { switch network { - case "", "tcp", "tcp4", "tcp6": + case "", "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": default: return nil, errors.New("unsupported network type") } @@ -660,13 +706,30 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) { } port, err := net.LookupPort(network, portStr) if err != nil || port < 0 || port > math.MaxUint16 { + // LookupPort returns an error on out of range values so the bounds + // checks on port should be unnecessary, but harmless. If they do + // match, worst case this error message says "invalid port: ". return nil, fmt.Errorf("invalid port: %w", err) } + var bindHostOrZero netip.Addr + if host != "" { + bindHostOrZero, err = netip.ParseAddr(host) + if err != nil { + return nil, fmt.Errorf("invalid Listen addr %q; host part must be empty or IP literal", host) + } + if strings.HasSuffix(network, "4") && !bindHostOrZero.Is4() { + return nil, fmt.Errorf("invalid non-IPv4 addr %v for network %q", host, network) + } + if strings.HasSuffix(network, "6") && !bindHostOrZero.Is6() { + return nil, fmt.Errorf("invalid non-IPv6 addr %v for network %q", host, network) + } + } + if err := s.Start(); err != nil { return nil, err } - key := listenKey{network, host, uint16(port)} + key := listenKey{network, bindHostOrZero, uint16(port)} ln := &listener{ s: s, key: key, @@ -686,7 +749,7 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) { type listenKey struct { network string - host string + host netip.Addr // or zero value for unspecified port uint16 } @@ -716,6 +779,18 @@ func (ln *listener) Close() error { return nil } +func (ln *listener) handle(c net.Conn) { + t := time.NewTimer(time.Second) + defer t.Stop() + select { + case ln.conn <- c: + case <-t.C: + // TODO(bradfitz): this isn't ideal. Think about how + // we how we want to do pushback. + c.Close() + } +} + // Server returns the tsnet Server associated with the listener. func (ln *listener) Server() *Server { return ln.s } diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 3847b4358..2058162f1 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -49,6 +49,18 @@ func TestListenerPort(t *testing.T) { {"tcp", ":https", false}, // built-in name to Go; doesn't require cgo, /etc/services {"tcp", ":gibberishsdlkfj", true}, {"tcp", ":%!d(string=80)", true}, // issue 6201 + {"udp", ":80", false}, + {"udp", "100.102.104.108:80", false}, + {"udp", "not-an-ip:80", true}, + {"udp4", ":80", false}, + {"udp4", "100.102.104.108:80", false}, + {"udp4", "not-an-ip:80", true}, + + // Verify network type matches IP + {"tcp4", "1.2.3.4:80", false}, + {"tcp6", "1.2.3.4:80", true}, + {"tcp4", "[12::34]:80", true}, + {"tcp6", "[12::34]:80", false}, } for _, tt := range tests { s := &Server{}