net/netcheck: fix HTTPS fallback bug from earlier today
My earlier 3fa58303d0
tried to implement
the net/http.Tranhsport.DialTLSContext hook, but I didn't return a
*tls.Conn, so we ended up sending a plaintext HTTP request to an HTTPS
port. The response ended up being Go telling as such, not the
/derp/latency-check handler's response (which is currently still a
404). But we didn't even get the 404.
This happened to work well enough because Go's built-in error response
was still a valid HTTP response that we can measure for timing
purposes, but it's not a great answer. Notably, it means we wouldn't
be able to get a future handler to run server-side and count those
latency requests.
This commit is contained in:
parent
1407540b52
commit
7f68e097dd
|
@ -142,8 +142,8 @@ func (c *Client) useHTTPS() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSServerName returns which TLS cert name to expect for the given node.
|
// tlsServerName returns which TLS cert name to expect for the given node.
|
||||||
func (c *Client) TLSServerName(node *tailcfg.DERPNode) string {
|
func (c *Client) tlsServerName(node *tailcfg.DERPNode) string {
|
||||||
if c.url != nil {
|
if c.url != nil {
|
||||||
return c.url.Host
|
return c.url.Host
|
||||||
}
|
}
|
||||||
|
@ -217,7 +217,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
|
||||||
tcpConn, err = c.dialURL(ctx)
|
tcpConn, err = c.dialURL(ctx)
|
||||||
} else {
|
} else {
|
||||||
c.logf("%s: connecting to derp-%d (%v)", caller, reg.RegionID, reg.RegionCode)
|
c.logf("%s: connecting to derp-%d (%v)", caller, reg.RegionID, reg.RegionCode)
|
||||||
tcpConn, node, err = c.DialRegion(ctx, reg)
|
tcpConn, node, err = c.dialRegion(ctx, reg)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -249,11 +249,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
|
||||||
|
|
||||||
var httpConn net.Conn // a TCP conn or a TLS conn; what we speak HTTP to
|
var httpConn net.Conn // a TCP conn or a TLS conn; what we speak HTTP to
|
||||||
if c.useHTTPS() {
|
if c.useHTTPS() {
|
||||||
tlsConf := tlsdial.Config(c.TLSServerName(node), c.TLSConfig)
|
httpConn = c.tlsClient(tcpConn, node)
|
||||||
if node != nil && node.DERPTestPort != 0 {
|
|
||||||
tlsConf.InsecureSkipVerify = true
|
|
||||||
}
|
|
||||||
httpConn = tls.Client(tcpConn, tlsConf)
|
|
||||||
} else {
|
} else {
|
||||||
httpConn = tcpConn
|
httpConn = tcpConn
|
||||||
}
|
}
|
||||||
|
@ -329,10 +325,10 @@ func (c *Client) dialURL(ctx context.Context) (net.Conn, error) {
|
||||||
return tcpConn, nil
|
return tcpConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialRegion returns a TCP connection to the provided region, trying
|
// dialRegion returns a TCP connection to the provided region, trying
|
||||||
// each node in order (with dialNode) until one connects or ctx is
|
// each node in order (with dialNode) until one connects or ctx is
|
||||||
// done.
|
// done.
|
||||||
func (c *Client) DialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.Conn, *tailcfg.DERPNode, error) {
|
func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.Conn, *tailcfg.DERPNode, error) {
|
||||||
if len(reg.Nodes) == 0 {
|
if len(reg.Nodes) == 0 {
|
||||||
return nil, nil, fmt.Errorf("no nodes for %s", c.targetString(reg))
|
return nil, nil, fmt.Errorf("no nodes for %s", c.targetString(reg))
|
||||||
}
|
}
|
||||||
|
@ -352,6 +348,42 @@ func (c *Client) DialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.C
|
||||||
return nil, nil, firstErr
|
return nil, nil, firstErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn {
|
||||||
|
tlsConf := tlsdial.Config(c.tlsServerName(node), c.TLSConfig)
|
||||||
|
if node != nil && node.DERPTestPort != 0 {
|
||||||
|
tlsConf.InsecureSkipVerify = true
|
||||||
|
}
|
||||||
|
return tls.Client(nc, tlsConf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) DialRegionTLS(ctx context.Context, reg *tailcfg.DERPRegion) (tlsConn *tls.Conn, connClose io.Closer, err error) {
|
||||||
|
tcpConn, node, err := c.dialRegion(ctx, reg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
done := make(chan bool) // unbufferd
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
tlsConn = c.tlsClient(tcpConn, node)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-ctx.Done():
|
||||||
|
tcpConn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
err = tlsConn.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case done <- true:
|
||||||
|
return tlsConn, tcpConn, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) dialContext(ctx context.Context, proto, addr string) (net.Conn, error) {
|
func (c *Client) dialContext(ctx context.Context, proto, addr string) (net.Conn, error) {
|
||||||
var stdDialer dialer = netns.Dialer()
|
var stdDialer dialer = netns.Dialer()
|
||||||
var dialer = stdDialer
|
var dialer = stdDialer
|
||||||
|
|
|
@ -8,6 +8,7 @@ package netcheck
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -786,23 +787,26 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio
|
||||||
var ip netaddr.IP
|
var ip netaddr.IP
|
||||||
|
|
||||||
dc := derphttp.NewNetcheckClient(c.logf)
|
dc := derphttp.NewNetcheckClient(c.logf)
|
||||||
nc, node, err := dc.DialRegion(ctx, reg)
|
tlsConn, tcpConn, err := dc.DialRegionTLS(ctx, reg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, ip, err
|
return 0, ip, err
|
||||||
}
|
}
|
||||||
defer nc.Close()
|
defer tcpConn.Close()
|
||||||
|
|
||||||
if ta, ok := nc.RemoteAddr().(*net.TCPAddr); ok {
|
if ta, ok := tlsConn.RemoteAddr().(*net.TCPAddr); ok {
|
||||||
ip, _ = netaddr.FromStdIP(ta.IP)
|
ip, _ = netaddr.FromStdIP(ta.IP)
|
||||||
}
|
}
|
||||||
if ip == (netaddr.IP{}) {
|
if ip == (netaddr.IP{}) {
|
||||||
return 0, ip, fmt.Errorf("no unexpected RemoteAddr %#v", nc.RemoteAddr())
|
return 0, ip, fmt.Errorf("no unexpected RemoteAddr %#v", tlsConn.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
connc := make(chan net.Conn, 1)
|
connc := make(chan *tls.Conn, 1)
|
||||||
connc <- nc
|
connc <- tlsConn
|
||||||
|
|
||||||
tr := &http.Transport{
|
tr := &http.Transport{
|
||||||
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return nil, errors.New("unexpected DialContext dial")
|
||||||
|
},
|
||||||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
select {
|
select {
|
||||||
case nc := <-connc:
|
case nc := <-connc:
|
||||||
|
@ -814,9 +818,7 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio
|
||||||
}
|
}
|
||||||
hc := &http.Client{Transport: tr}
|
hc := &http.Client{Transport: tr}
|
||||||
|
|
||||||
host := dc.TLSServerName(node)
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://derp-unused-hostname.tld/derp/latency-check", nil)
|
||||||
u := fmt.Sprintf("https://%s/derp/latency-check", host)
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", u, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, ip, err
|
return 0, ip, err
|
||||||
}
|
}
|
||||||
|
@ -827,7 +829,7 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
_, err = io.Copy(ioutil.Discard, resp.Body)
|
_, err = io.Copy(ioutil.Discard, io.LimitReader(resp.Body, 8<<10))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, ip, err
|
return 0, ip, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue