diff --git a/cmd/tailscale/cli/cli_test.go b/cmd/tailscale/cli/cli_test.go index b6d99f34b..3da21ba1d 100644 --- a/cmd/tailscale/cli/cli_test.go +++ b/cmd/tailscale/cli/cli_test.go @@ -5,11 +5,17 @@ package cli import ( + "bytes" + "encoding/json" "flag" + "fmt" + "strings" "testing" "inet.af/netaddr" "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/types/preftype" ) // Test that checkForAccidentalSettingReverts's updateMaskedPrefsFromUpFlag can handle @@ -129,3 +135,161 @@ func TestCheckForAccidentalSettingReverts(t *testing.T) { }) } } + +func TestPrefsFromUpArgs(t *testing.T) { + tests := []struct { + name string + args upArgsT + goos string // runtime.GOOS; empty means linux + st *ipnstate.Status // or nil + want *ipn.Prefs + wantErr string + wantWarn string + }{ + { + name: "zero", + goos: "windows", + args: upArgsT{}, + want: &ipn.Prefs{ + WantRunning: true, + NoSNAT: true, + NetfilterMode: preftype.NetfilterOn, // silly, but default from ipn.NewPref currently + }, + }, + { + name: "error_advertise_route_invalid_ip", + args: upArgsT{ + advertiseRoutes: "foo", + }, + wantErr: `"foo" is not a valid IP address or CIDR prefix`, + }, + { + name: "error_advertise_route_unmasked_bits", + args: upArgsT{ + advertiseRoutes: "1.2.3.4/16", + }, + wantErr: `1.2.3.4/16 has non-address bits set; expected 1.2.0.0/16`, + }, + { + name: "error_exit_node_bad_ip", + args: upArgsT{ + exitNodeIP: "foo", + }, + wantErr: `invalid IP address "foo" for --exit-node: unable to parse IP`, + }, + { + name: "error_exit_node_allow_lan_without_exit_node", + args: upArgsT{ + exitNodeAllowLANAccess: true, + }, + wantErr: `--exit-node-allow-lan-access can only be used with --exit-node`, + }, + { + name: "error_tag_prefix", + args: upArgsT{ + advertiseTags: "foo", + }, + wantErr: `tag: "foo": tags must start with 'tag:'`, + }, + { + name: "error_long_hostname", + args: upArgsT{ + hostname: strings.Repeat("a", 300), + }, + wantErr: `hostname too long: 300 bytes (max 256)`, + }, + { + name: "error_linux_netfilter_empty", + args: upArgsT{ + netfilterMode: "", + }, + wantErr: `invalid value --netfilter-mode=""`, + }, + { + name: "error_linux_netfilter_bogus", + args: upArgsT{ + netfilterMode: "bogus", + }, + wantErr: `invalid value --netfilter-mode="bogus"`, + }, + { + name: "error_exit_node_ip_is_self_ip", + args: upArgsT{ + exitNodeIP: "100.105.106.107", + }, + st: &ipnstate.Status{ + TailscaleIPs: []netaddr.IP{netaddr.MustParseIP("100.105.106.107")}, + }, + wantErr: `cannot use 100.105.106.107 as the exit node as it is a local IP address to this machine, did you mean --advertise-exit-node?`, + }, + { + name: "warn_linux_netfilter_nodivert", + goos: "linux", + args: upArgsT{ + netfilterMode: "nodivert", + }, + wantWarn: "netfilter=nodivert; add iptables calls to ts-* chains manually.", + want: &ipn.Prefs{ + WantRunning: true, + NetfilterMode: preftype.NetfilterNoDivert, + NoSNAT: true, + }, + }, + { + name: "warn_linux_netfilter_off", + goos: "linux", + args: upArgsT{ + netfilterMode: "off", + }, + wantWarn: "netfilter=off; configure iptables yourself.", + want: &ipn.Prefs{ + WantRunning: true, + NetfilterMode: preftype.NetfilterOff, + NoSNAT: true, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var warnBuf bytes.Buffer + warnf := func(format string, a ...interface{}) { + fmt.Fprintf(&warnBuf, format, a...) + } + goos := tt.goos + if goos == "" { + goos = "linux" + } + st := tt.st + if st == nil { + st = new(ipnstate.Status) + } + got, err := prefsFromUpArgs(tt.args, warnf, st, goos) + gotErr := fmt.Sprint(err) + if tt.wantErr != "" { + if tt.wantErr != gotErr { + t.Errorf("wrong error.\n got error: %v\nwant error: %v\n", gotErr, tt.wantErr) + } + return + } + if err != nil { + t.Fatal(err) + } + if tt.want == nil { + t.Fatal("tt.want is nil") + } + if !got.Equals(tt.want) { + jgot, _ := json.MarshalIndent(got, "", "\t") + jwant, _ := json.MarshalIndent(tt.want, "", "\t") + if bytes.Equal(jgot, jwant) { + t.Logf("prefs differ only in non-JSON-visible ways (nil/non-nil zero-length arrays)") + } + t.Errorf("wrong prefs\n got: %s\nwant: %s\n\ngot: %s\nwant: %s\n", + got.Pretty(), tt.want.Pretty(), + jgot, jwant, + ) + + } + }) + } + +} diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index b472fea41..a3aaa3cbe 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -22,7 +22,9 @@ import ( "inet.af/netaddr" "tailscale.com/client/tailscale" "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" + "tailscale.com/types/logger" "tailscale.com/types/preftype" "tailscale.com/version/distro" ) @@ -84,7 +86,7 @@ func defaultNetfilterMode() string { return "on" } -var upArgs struct { +type upArgsT struct { reset bool server string acceptRoutes bool @@ -104,6 +106,8 @@ var upArgs struct { hostname string } +var upArgs upArgsT + func warnf(format string, args ...interface{}) { fmt.Printf("Warning: "+format+"\n", args...) } @@ -113,6 +117,119 @@ var ( ipv6default = netaddr.MustParseIPPrefix("::/0") ) +// prefsFromUpArgs returns the ipn.Prefs for the provided args. +// +// Note that the parameters upArgs and warnf are named intentionally +// to shadow the globals to prevent accidental misuse of them. This +// function exists for testing and should have no side effects or +// outside interactions (e.g. no making Tailscale local API calls). +func prefsFromUpArgs(upArgs upArgsT, warnf logger.Logf, st *ipnstate.Status, goos string) (*ipn.Prefs, error) { + routeMap := map[netaddr.IPPrefix]bool{} + var default4, default6 bool + if upArgs.advertiseRoutes != "" { + advroutes := strings.Split(upArgs.advertiseRoutes, ",") + for _, s := range advroutes { + ipp, err := netaddr.ParseIPPrefix(s) + if err != nil { + return nil, fmt.Errorf("%q is not a valid IP address or CIDR prefix", s) + } + if ipp != ipp.Masked() { + return nil, fmt.Errorf("%s has non-address bits set; expected %s", ipp, ipp.Masked()) + } + if ipp == ipv4default { + default4 = true + } else if ipp == ipv6default { + default6 = true + } + routeMap[ipp] = true + } + if default4 && !default6 { + return nil, fmt.Errorf("%s advertised without its IPv6 counterpart, please also advertise %s", ipv4default, ipv6default) + } else if default6 && !default4 { + return nil, fmt.Errorf("%s advertised without its IPv6 counterpart, please also advertise %s", ipv6default, ipv4default) + } + } + if upArgs.advertiseDefaultRoute { + routeMap[netaddr.MustParseIPPrefix("0.0.0.0/0")] = true + routeMap[netaddr.MustParseIPPrefix("::/0")] = true + } + routes := make([]netaddr.IPPrefix, 0, len(routeMap)) + for r := range routeMap { + routes = append(routes, r) + } + sort.Slice(routes, func(i, j int) bool { + if routes[i].Bits != routes[j].Bits { + return routes[i].Bits < routes[j].Bits + } + return routes[i].IP.Less(routes[j].IP) + }) + + var exitNodeIP netaddr.IP + if upArgs.exitNodeIP != "" { + var err error + exitNodeIP, err = netaddr.ParseIP(upArgs.exitNodeIP) + if err != nil { + return nil, fmt.Errorf("invalid IP address %q for --exit-node: %v", upArgs.exitNodeIP, err) + } + } else if upArgs.exitNodeAllowLANAccess { + return nil, fmt.Errorf("--exit-node-allow-lan-access can only be used with --exit-node") + } + + if upArgs.exitNodeIP != "" { + for _, ip := range st.TailscaleIPs { + if exitNodeIP == ip { + return nil, fmt.Errorf("cannot use %s as the exit node as it is a local IP address to this machine, did you mean --advertise-exit-node?", upArgs.exitNodeIP) + } + } + } + + var tags []string + if upArgs.advertiseTags != "" { + tags = strings.Split(upArgs.advertiseTags, ",") + for _, tag := range tags { + err := tailcfg.CheckTag(tag) + if err != nil { + return nil, fmt.Errorf("tag: %q: %s", tag, err) + } + } + } + + if len(upArgs.hostname) > 256 { + return nil, fmt.Errorf("hostname too long: %d bytes (max 256)", len(upArgs.hostname)) + } + + prefs := ipn.NewPrefs() + prefs.ControlURL = upArgs.server + prefs.WantRunning = true + prefs.RouteAll = upArgs.acceptRoutes + prefs.ExitNodeIP = exitNodeIP + prefs.ExitNodeAllowLANAccess = upArgs.exitNodeAllowLANAccess + prefs.CorpDNS = upArgs.acceptDNS + prefs.AllowSingleHosts = upArgs.singleRoutes + prefs.ShieldsUp = upArgs.shieldsUp + prefs.AdvertiseRoutes = routes + prefs.AdvertiseTags = tags + prefs.NoSNAT = !upArgs.snat + prefs.Hostname = upArgs.hostname + prefs.ForceDaemon = upArgs.forceDaemon + + if goos == "linux" { + switch upArgs.netfilterMode { + case "on": + prefs.NetfilterMode = preftype.NetfilterOn + case "nodivert": + prefs.NetfilterMode = preftype.NetfilterNoDivert + warnf("netfilter=nodivert; add iptables calls to ts-* chains manually.") + case "off": + prefs.NetfilterMode = preftype.NetfilterOff + warnf("netfilter=off; configure iptables yourself.") + default: + return nil, fmt.Errorf("invalid value --netfilter-mode=%q", upArgs.netfilterMode) + } + } + return prefs, nil +} + func runUp(ctx context.Context, args []string) error { if len(args) > 0 { fatalf("too many non-flag arguments: %q", args) @@ -136,114 +253,16 @@ func runUp(ctx context.Context, args []string) error { } } - routeMap := map[netaddr.IPPrefix]bool{} - var default4, default6 bool - if upArgs.advertiseRoutes != "" { - advroutes := strings.Split(upArgs.advertiseRoutes, ",") - for _, s := range advroutes { - ipp, err := netaddr.ParseIPPrefix(s) - if err != nil { - fatalf("%q is not a valid IP address or CIDR prefix", s) - } - if ipp != ipp.Masked() { - fatalf("%s has non-address bits set; expected %s", ipp, ipp.Masked()) - } - if ipp == ipv4default { - default4 = true - } else if ipp == ipv6default { - default6 = true - } - routeMap[ipp] = true - } - if default4 && !default6 { - fatalf("%s advertised without its IPv6 counterpart, please also advertise %s", ipv4default, ipv6default) - } else if default6 && !default4 { - fatalf("%s advertised without its IPv6 counterpart, please also advertise %s", ipv6default, ipv4default) - } + prefs, err := prefsFromUpArgs(upArgs, warnf, st, runtime.GOOS) + if err != nil { + fatalf("%s", err) } - if upArgs.advertiseDefaultRoute { - routeMap[netaddr.MustParseIPPrefix("0.0.0.0/0")] = true - routeMap[netaddr.MustParseIPPrefix("::/0")] = true - } - if len(routeMap) > 0 { + + if len(prefs.AdvertiseRoutes) > 0 { if err := tailscale.CheckIPForwarding(context.Background()); err != nil { warnf("%v", err) } } - routes := make([]netaddr.IPPrefix, 0, len(routeMap)) - for r := range routeMap { - routes = append(routes, r) - } - sort.Slice(routes, func(i, j int) bool { - if routes[i].Bits != routes[j].Bits { - return routes[i].Bits < routes[j].Bits - } - return routes[i].IP.Less(routes[j].IP) - }) - - var exitNodeIP netaddr.IP - if upArgs.exitNodeIP != "" { - var err error - exitNodeIP, err = netaddr.ParseIP(upArgs.exitNodeIP) - if err != nil { - fatalf("invalid IP address %q for --exit-node: %v", upArgs.exitNodeIP, err) - } - } else if upArgs.exitNodeAllowLANAccess { - fatalf("--exit-node-allow-lan-access can only be used with --exit-node") - } - - if !exitNodeIP.IsZero() { - for _, ip := range st.TailscaleIPs { - if exitNodeIP == ip { - fatalf("cannot use %s as the exit node as it is a local IP address to this machine, did you mean --advertise-exit-node?", exitNodeIP) - } - } - } - - var tags []string - if upArgs.advertiseTags != "" { - tags = strings.Split(upArgs.advertiseTags, ",") - for _, tag := range tags { - err := tailcfg.CheckTag(tag) - if err != nil { - fatalf("tag: %q: %s", tag, err) - } - } - } - - if len(upArgs.hostname) > 256 { - fatalf("hostname too long: %d bytes (max 256)", len(upArgs.hostname)) - } - - prefs := ipn.NewPrefs() - prefs.ControlURL = upArgs.server - prefs.WantRunning = true - prefs.RouteAll = upArgs.acceptRoutes - prefs.ExitNodeIP = exitNodeIP - prefs.ExitNodeAllowLANAccess = upArgs.exitNodeAllowLANAccess - prefs.CorpDNS = upArgs.acceptDNS - prefs.AllowSingleHosts = upArgs.singleRoutes - prefs.ShieldsUp = upArgs.shieldsUp - prefs.AdvertiseRoutes = routes - prefs.AdvertiseTags = tags - prefs.NoSNAT = !upArgs.snat - prefs.Hostname = upArgs.hostname - prefs.ForceDaemon = upArgs.forceDaemon - - if runtime.GOOS == "linux" { - switch upArgs.netfilterMode { - case "on": - prefs.NetfilterMode = preftype.NetfilterOn - case "nodivert": - prefs.NetfilterMode = preftype.NetfilterNoDivert - warnf("netfilter=nodivert; add iptables calls to ts-* chains manually.") - case "off": - prefs.NetfilterMode = preftype.NetfilterOff - warnf("netfilter=off; configure iptables yourself.") - default: - fatalf("invalid value --netfilter-mode: %q", upArgs.netfilterMode) - } - } curPrefs, err := tailscale.GetPrefs(ctx) if err != nil {