diff --git a/control/controlclient/auto_test.go b/control/controlclient/auto_test.go index e86995e1d..43d5f8180 100644 --- a/control/controlclient/auto_test.go +++ b/control/controlclient/auto_test.go @@ -7,12 +7,14 @@ package controlclient import ( + "bytes" "context" "encoding/json" "fmt" "io/ioutil" "log" "net/http" + "net/http/cookiejar" "net/http/httptest" "net/url" "os" @@ -310,25 +312,7 @@ func TestControl(t *testing.T) { c1.Login(nil, LoginInteractive) status := c1.waitStatus(t, stateURLVisitRequired) - authURL := status.New.URL - - resp, err := c1.httpc.Get(authURL) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != 200 { - t.Errorf("GET %s failed: %q", authURL, resp.Status) - } - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - t.Fatal(err) - } - cookies := resp.Cookies() - if len(cookies) == 0 || cookies[0].Name != "tailcontrol" { - t.Logf("GET %s: %s", authURL, string(body)) - t.Fatalf("GET %s: bad cookie: %v", authURL, cookies) - } + c1.postAuthURL(t, "testuser1@tailscale.com", status.New) c1.waitStatus(t, stateAuthenticated) status = c1.waitStatus(t, stateSynchronized) if status.New.Err != "" { @@ -443,8 +427,8 @@ func TestLoginInterrupt(t *testing.T) { t.Errorf("auth URLs match for subsequent logins: %s", authURL) } - form := url.Values{"user": []string{loginName}} - req, err := http.NewRequest("POST", authURL2, strings.NewReader(form.Encode())) + // Direct auth URL visit is not enough because our cookie is no longer fresh. + req, err := http.NewRequest("GET", authURL2, nil) if err != nil { t.Fatal(err) } @@ -453,10 +437,37 @@ func TestLoginInterrupt(t *testing.T) { if err != nil { t.Fatal(err) } + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if i := bytes.Index(b, []byte(" header + b = b[i:] + } + if !bytes.Contains(b, []byte(" header + b = b[i:] + } got = string(b) if !strings.Contains(got, "This is a new machine") { t.Fatalf("no machine authorization message:\n\n%s", got) @@ -830,11 +845,15 @@ func TestGoogleSigninButton(t *testing.T) { if err != nil { t.Fatal(err) } + if i := bytes.Index(b, []byte(" header + b = b[i:] + } got := string(b) if !strings.Contains(got, `Sign in with Google`) { t.Fatalf("page does not mention google signin button:\n\n%s", got) } + authURL = authURLForPOST(authURL) resp, err = c.httpc.PostForm(authURL, url.Values{"provider": []string{"google"}}) if err != nil { t.Fatal(err) @@ -846,6 +865,9 @@ func TestGoogleSigninButton(t *testing.T) { if err != nil { t.Fatal(err) } + if i := bytes.Index(b, []byte(" header + b = b[i:] + } got = string(b) if !strings.Contains(got, "Authorization successful") { t.Fatalf("no machine authorization message:\n\n%s", got) @@ -990,6 +1012,11 @@ func (s *server) newClient(t *testing.T, name string) *client { ch := make(chan statusChange, 1024) httpc := s.http.Client() + var err error + httpc.Jar, err = cookiejar.New(nil) + if err != nil { + t.Fatal(err) + } hi := NewHostinfo() hi.FrontendLogID = "go-test-only" hi.BackendLogID = "go-test-only" @@ -1063,9 +1090,18 @@ func (c *client) postAuthURL(t *testing.T, user string, status Status) *http.Coo return postAuthURL(t, c.ctx, c.httpc, user, authURL) } +func authURLForPOST(authURL string) string { + i := strings.Index(authURL, "/a/") + if i == -1 { + panic("bad authURL: " + authURL) + } + return authURL[:i] + "/login?refresh=true&next_url=" + url.PathEscape(authURL[i:]) +} + func postAuthURL(t *testing.T, ctx context.Context, httpc *http.Client, user string, authURL string) *http.Cookie { t.Helper() + authURL = authURLForPOST(authURL) form := url.Values{"user": []string{user}} req, err := http.NewRequest("POST", authURL, strings.NewReader(form.Encode())) if err != nil { @@ -1076,12 +1112,21 @@ func postAuthURL(t *testing.T, ctx context.Context, httpc *http.Client, user str if err != nil { t.Fatal(err) } - if resp.StatusCode != 200 { - t.Fatalf("POST %s failed: %q", authURL, resp.Status) + b, _ := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if i := bytes.Index(b, []byte(" header + b = b[i:] } - cookies := resp.Cookies() + if resp.StatusCode != 200 { + t.Fatalf("POST %s failed: %q, body: %s", authURL, resp.Status, b) + } + cookieURL, err := url.Parse(authURL) + if err != nil { + t.Fatal(err) + } + cookies := httpc.Jar.Cookies(cookieURL) if len(cookies) == 0 || cookies[0].Name != "tailcontrol" { - t.Fatalf("POST %s: bad cookie: %v", authURL, cookies) + t.Fatalf("POST %s: bad cookie: %v, body: %s", authURL, cookies, b) } return cookies[0] } diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go index 799b72c75..1b3d7ea19 100644 --- a/control/controlclient/direct_test.go +++ b/control/controlclient/direct_test.go @@ -10,6 +10,7 @@ import ( "context" "io/ioutil" "net/http" + "net/http/cookiejar" "net/http/httptest" "os" "testing" @@ -30,6 +31,11 @@ func TestClientsReusingKeys(t *testing.T) { httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server.ServeHTTP(w, r) })) + httpc := httpsrv.Client() + httpc.Jar, err = cookiejar.New(nil) + if err != nil { + t.Fatal(err) + } server, err = control.New(tmpdir, tmpdir, httpsrv.URL, true) if err != nil { t.Fatal(err) @@ -64,7 +70,7 @@ func TestClientsReusingKeys(t *testing.T) { t.Fatal(err) } const user = "testuser1@tailscale.onmicrosoft.com" - postAuthURL(t, ctx, httpsrv.Client(), user, authURL) + postAuthURL(t, ctx, httpc, user, authURL) newURL, err := c1.WaitLoginURL(ctx, authURL) if err != nil { t.Fatal(err) @@ -132,6 +138,11 @@ func TestClientsReusingOldKey(t *testing.T) { httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server.ServeHTTP(w, r) })) + httpc := httpsrv.Client() + httpc.Jar, err = cookiejar.New(nil) + if err != nil { + t.Fatal(err) + } server, err = control.New(tmpdir, tmpdir, httpsrv.URL, true) if err != nil { t.Fatal(err) @@ -149,7 +160,7 @@ func TestClientsReusingOldKey(t *testing.T) { genOpts := func() Options { return Options{ ServerURL: httpsrv.URL, - HTTPC: httpsrv.Client(), + HTTPC: httpc, //TimeNow: s.control.TimeNow, Logf: func(fmt string, args ...interface{}) { t.Helper() @@ -171,7 +182,7 @@ func TestClientsReusingOldKey(t *testing.T) { t.Fatal(err) } const user = "testuser1@tailscale.onmicrosoft.com" - postAuthURL(t, ctx, httpsrv.Client(), user, authURL) + postAuthURL(t, ctx, httpc, user, authURL) newURL, err := c1.WaitLoginURL(ctx, authURL) if err != nil { t.Fatal(err) @@ -212,7 +223,7 @@ func TestClientsReusingOldKey(t *testing.T) { } else if authURL == "" { t.Fatal("expected authURL for reused oldNodeKey, got none") } else { - postAuthURL(t, ctx, httpsrv.Client(), user, authURL) + postAuthURL(t, ctx, httpc, user, authURL) if newURL, err := c1.WaitLoginURL(ctx, authURL); err != nil { t.Fatal(err) } else if newURL != "" { @@ -245,7 +256,7 @@ func TestClientsReusingOldKey(t *testing.T) { } else if authURL == "" { t.Fatal("expected authURL for reused oldNodeKey, got none") } else { - postAuthURL(t, ctx, httpsrv.Client(), user, authURL) + postAuthURL(t, ctx, httpc, user, authURL) if newURL, err := c1.WaitLoginURL(ctx, authURL); err != nil { t.Fatal(err) } else if newURL != "" { @@ -287,7 +298,7 @@ func TestClientsReusingOldKey(t *testing.T) { } else if authURL == "" { t.Fatal("expected authURL for reused nodeKey, got none") } else { - postAuthURL(t, ctx, httpsrv.Client(), user, authURL) + postAuthURL(t, ctx, httpc, user, authURL) if newURL, err := c1.WaitLoginURL(ctx, authURL); err != nil { t.Fatal(err) } else if newURL != "" { diff --git a/ipn/e2e_test.go b/ipn/e2e_test.go index a9a7202cd..14aa9e6e2 100644 --- a/ipn/e2e_test.go +++ b/ipn/e2e_test.go @@ -10,7 +10,10 @@ import ( "bytes" "io/ioutil" "net/http" + "net/http/cookiejar" "net/http/httptest" + "net/url" + "strings" "testing" "time" @@ -177,6 +180,13 @@ func newNode(t *testing.T, prefix string, https *httptest.Server) testNode { t.Logf(prefix+": "+fmt, args...) } + var err error + httpc := https.Client() + httpc.Jar, err = cookiejar.New(nil) + if err != nil { + t.Fatal(err) + } + tun := tuntest.NewChannelTUN() e1, err := wgengine.NewUserspaceEngineAdvanced(logfe, tun.TUN(), wgengine.NewFakeRouter, 0) if err != nil { @@ -200,10 +210,23 @@ func newNode(t *testing.T, prefix string, https *httptest.Server) testNode { Notify: func(n Notify) { // Automatically visit auth URLs if n.BrowseToURL != nil { - t.Logf("\n\n\nURL! %vv\n", *n.BrowseToURL) - hc := https.Client() - _, err := hc.Get(*n.BrowseToURL) + t.Logf("BrowseToURL: %v", *n.BrowseToURL) + + authURL := *n.BrowseToURL + i := strings.Index(authURL, "/a/") + if i == -1 { + panic("bad authURL: " + authURL) + } + authURL = authURL[:i] + "/login?refresh=true&next_url=" + url.PathEscape(authURL[i:]) + + form := url.Values{"user": []string{c.LoginName}} + req, err := http.NewRequest("POST", authURL, strings.NewReader(form.Encode())) if err != nil { + t.Fatal(err) + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + if _, err := httpc.Do(req); err != nil { t.Logf("BrowseToURL: %v\n", err) } }