control/controlhttp: add AcceptHTTP hook to add coalesced Server->Client write
New plan for #5972. Instead of sending the public key in the clear
(from earlier unreleased 246274b8e9
) where the client might have to
worry about it being dropped or tampered with and retrying, we'll
instead send it post-Noise handshake but before the HTTP/2 connection
begins.
This replaces the earlier extraHeaders hook with a different sort of
hook that allows us to combine two writes on the wire in one packet.
Updates #5972
Change-Id: I42cdf7c1859b53ca4dfa5610bd1b840c6986e09c
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
c21a3c4733
commit
5e9e57ecf5
|
@ -37,6 +37,8 @@ type httpTestParam struct {
|
||||||
// makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a
|
// makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a
|
||||||
// 101 switching protocols.
|
// 101 switching protocols.
|
||||||
makeHTTPHangAfterUpgrade bool
|
makeHTTPHangAfterUpgrade bool
|
||||||
|
|
||||||
|
doEarlyWrite bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestControlHTTP(t *testing.T) {
|
func TestControlHTTP(t *testing.T) {
|
||||||
|
@ -111,6 +113,11 @@ func TestControlHTTP(t *testing.T) {
|
||||||
allowHTTP: true,
|
allowHTTP: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
// Early write
|
||||||
|
{
|
||||||
|
name: "early_write",
|
||||||
|
doEarlyWrite: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
@ -125,9 +132,21 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
|
||||||
client, server := key.NewMachine(), key.NewMachine()
|
client, server := key.NewMachine(), key.NewMachine()
|
||||||
|
|
||||||
const testProtocolVersion = 1
|
const testProtocolVersion = 1
|
||||||
|
const earlyWriteMsg = "Hello, world!"
|
||||||
sch := make(chan serverResult, 1)
|
sch := make(chan serverResult, 1)
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
conn, err := AcceptHTTP(context.Background(), w, r, server, nil)
|
var earlyWriteFn func(protocolVersion int, w io.Writer) error
|
||||||
|
if param.doEarlyWrite {
|
||||||
|
earlyWriteFn = func(protocolVersion int, w io.Writer) error {
|
||||||
|
if protocolVersion != testProtocolVersion {
|
||||||
|
t.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion)
|
||||||
|
return fmt.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion)
|
||||||
|
}
|
||||||
|
_, err := io.WriteString(w, earlyWriteMsg)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn, err := AcceptHTTP(context.Background(), w, r, server, earlyWriteFn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
}
|
}
|
||||||
|
@ -228,6 +247,15 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
|
||||||
if proxy != nil && !proxy.ConnIsFromProxy(si.clientAddr) {
|
if proxy != nil && !proxy.ConnIsFromProxy(si.clientAddr) {
|
||||||
t.Fatalf("client connected from %s, which isn't the proxy", si.clientAddr)
|
t.Fatalf("client connected from %s, which isn't the proxy", si.clientAddr)
|
||||||
}
|
}
|
||||||
|
if param.doEarlyWrite {
|
||||||
|
buf := make([]byte, len(earlyWriteMsg))
|
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||||
|
t.Fatalf("reading early write: %v", err)
|
||||||
|
}
|
||||||
|
if string(buf) != earlyWriteMsg {
|
||||||
|
t.Errorf("early write = %q; want %q", buf, earlyWriteMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type serverResult struct {
|
type serverResult struct {
|
||||||
|
|
|
@ -9,7 +9,10 @@ import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"nhooyr.io/websocket"
|
"nhooyr.io/websocket"
|
||||||
"tailscale.com/control/controlbase"
|
"tailscale.com/control/controlbase"
|
||||||
|
@ -18,16 +21,20 @@ import (
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AcceptHTTP upgrades the HTTP request given by w and r into a
|
// AcceptHTTP upgrades the HTTP request given by w and r into a Tailscale
|
||||||
// Tailscale control protocol base transport connection.
|
// control protocol base transport connection.
|
||||||
//
|
//
|
||||||
// AcceptHTTP always writes an HTTP response to w. The caller must not
|
// AcceptHTTP always writes an HTTP response to w. The caller must not attempt
|
||||||
// attempt their own response after calling AcceptHTTP.
|
// their own response after calling AcceptHTTP.
|
||||||
//
|
//
|
||||||
// extraHeader optionally specifies extra header(s) to send in the
|
// earlyWrite optionally specifies a func to write to the noise connection
|
||||||
// 101 Switching Protocols Upgrade response. It must not include the "Upgrade"
|
// (encrypted). It receives the negotiated version and a writer to write to, if
|
||||||
// or "Connection" headers; they will be replaced.
|
// desired.
|
||||||
func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, extraHeader http.Header) (*controlbase.Conn, error) {
|
func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, earlyWrite func(protocolVersion int, w io.Writer) error) (*controlbase.Conn, error) {
|
||||||
|
return acceptHTTP(ctx, w, r, private, earlyWrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, earlyWrite func(protocolVersion int, w io.Writer) error) (_ *controlbase.Conn, retErr error) {
|
||||||
next := r.Header.Get("Upgrade")
|
next := r.Header.Get("Upgrade")
|
||||||
if next == "" {
|
if next == "" {
|
||||||
http.Error(w, "missing next protocol", http.StatusBadRequest)
|
http.Error(w, "missing next protocol", http.StatusBadRequest)
|
||||||
|
@ -58,9 +65,6 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
|
||||||
return nil, errors.New("can't hijack client connection")
|
return nil, errors.New("can't hijack client connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, vv := range extraHeader {
|
|
||||||
w.Header()[k] = vv
|
|
||||||
}
|
|
||||||
w.Header().Set("Upgrade", upgradeHeaderValue)
|
w.Header().Set("Upgrade", upgradeHeaderValue)
|
||||||
w.Header().Set("Connection", "upgrade")
|
w.Header().Set("Connection", "upgrade")
|
||||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||||
|
@ -69,18 +73,41 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("hijacking client connection: %w", err)
|
return nil, fmt.Errorf("hijacking client connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if retErr != nil {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if err := brw.Flush(); err != nil {
|
if err := brw.Flush(); err != nil {
|
||||||
conn.Close()
|
|
||||||
return nil, fmt.Errorf("flushing hijacked HTTP buffer: %w", err)
|
return nil, fmt.Errorf("flushing hijacked HTTP buffer: %w", err)
|
||||||
}
|
}
|
||||||
conn = netutil.NewDrainBufConn(conn, brw.Reader)
|
conn = netutil.NewDrainBufConn(conn, brw.Reader)
|
||||||
|
|
||||||
nc, err := controlbase.Server(ctx, conn, private, init)
|
cwc := newWriteCorkingConn(conn)
|
||||||
|
|
||||||
|
nc, err := controlbase.Server(ctx, cwc, private, init)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
|
||||||
return nil, fmt.Errorf("noise handshake failed: %w", err)
|
return nil, fmt.Errorf("noise handshake failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if earlyWrite != nil {
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
if err := conn.SetDeadline(deadline); err != nil {
|
||||||
|
return nil, fmt.Errorf("setting conn deadline: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.SetDeadline(time.Time{})
|
||||||
|
}
|
||||||
|
if err := earlyWrite(nc.ProtocolVersion(), nc); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cwc.uncork(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return nc, nil
|
return nc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -128,3 +155,61 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request
|
||||||
|
|
||||||
return nc, nil
|
return nc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// corkConn is a net.Conn wrapper that initially buffers all writes until uncork
|
||||||
|
// is called. If the conn is corked and a Read occurs, the Read will flush any
|
||||||
|
// buffered (corked) write.
|
||||||
|
//
|
||||||
|
// Until uncorked, Read/Write/uncork may be not called concurrently.
|
||||||
|
//
|
||||||
|
// Deadlines still work, but a corked write ignores deadlines until a Read or
|
||||||
|
// uncork goes to do that Write.
|
||||||
|
//
|
||||||
|
// Use newWriteCorkingConn to create one.
|
||||||
|
type corkConn struct {
|
||||||
|
net.Conn
|
||||||
|
corked bool
|
||||||
|
buf []byte // corked data
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWriteCorkingConn(c net.Conn) *corkConn {
|
||||||
|
return &corkConn{Conn: c, corked: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *corkConn) Write(b []byte) (int, error) {
|
||||||
|
if c.corked {
|
||||||
|
c.buf = append(c.buf, b...)
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
return c.Conn.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *corkConn) Read(b []byte) (int, error) {
|
||||||
|
if c.corked {
|
||||||
|
if err := c.flush(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.Conn.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// uncork flushes any buffered data and uncorks the connection so future Writes
|
||||||
|
// don't buffer. It may not be called concurrently with reads or writes and
|
||||||
|
// may only be called once.
|
||||||
|
func (c *corkConn) uncork() error {
|
||||||
|
if !c.corked {
|
||||||
|
panic("usage error; uncork called twice") // worth panicking to catch misuse
|
||||||
|
}
|
||||||
|
err := c.flush()
|
||||||
|
c.corked = false
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *corkConn) flush() error {
|
||||||
|
if len(c.buf) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, err := c.Conn.Write(c.buf)
|
||||||
|
c.buf = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue