control/controlhttp: add AcceptHTTP hook to add coalesced Server->Client write
New plan for #5972. Instead of sending the public key in the clear
(from earlier unreleased 246274b8e9
) where the client might have to
worry about it being dropped or tampered with and retrying, we'll
instead send it post-Noise handshake but before the HTTP/2 connection
begins.
This replaces the earlier extraHeaders hook with a different sort of
hook that allows us to combine two writes on the wire in one packet.
Updates #5972
Change-Id: I42cdf7c1859b53ca4dfa5610bd1b840c6986e09c
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
c21a3c4733
commit
5e9e57ecf5
|
@ -37,6 +37,8 @@ type httpTestParam struct {
|
|||
// makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a
|
||||
// 101 switching protocols.
|
||||
makeHTTPHangAfterUpgrade bool
|
||||
|
||||
doEarlyWrite bool
|
||||
}
|
||||
|
||||
func TestControlHTTP(t *testing.T) {
|
||||
|
@ -111,6 +113,11 @@ func TestControlHTTP(t *testing.T) {
|
|||
allowHTTP: true,
|
||||
},
|
||||
},
|
||||
// Early write
|
||||
{
|
||||
name: "early_write",
|
||||
doEarlyWrite: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
@ -125,9 +132,21 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
|
|||
client, server := key.NewMachine(), key.NewMachine()
|
||||
|
||||
const testProtocolVersion = 1
|
||||
const earlyWriteMsg = "Hello, world!"
|
||||
sch := make(chan serverResult, 1)
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := AcceptHTTP(context.Background(), w, r, server, nil)
|
||||
var earlyWriteFn func(protocolVersion int, w io.Writer) error
|
||||
if param.doEarlyWrite {
|
||||
earlyWriteFn = func(protocolVersion int, w io.Writer) error {
|
||||
if protocolVersion != testProtocolVersion {
|
||||
t.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion)
|
||||
return fmt.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion)
|
||||
}
|
||||
_, err := io.WriteString(w, earlyWriteMsg)
|
||||
return err
|
||||
}
|
||||
}
|
||||
conn, err := AcceptHTTP(context.Background(), w, r, server, earlyWriteFn)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
}
|
||||
|
@ -228,6 +247,15 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
|
|||
if proxy != nil && !proxy.ConnIsFromProxy(si.clientAddr) {
|
||||
t.Fatalf("client connected from %s, which isn't the proxy", si.clientAddr)
|
||||
}
|
||||
if param.doEarlyWrite {
|
||||
buf := make([]byte, len(earlyWriteMsg))
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
t.Fatalf("reading early write: %v", err)
|
||||
}
|
||||
if string(buf) != earlyWriteMsg {
|
||||
t.Errorf("early write = %q; want %q", buf, earlyWriteMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type serverResult struct {
|
||||
|
|
|
@ -9,7 +9,10 @@ import (
|
|||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"nhooyr.io/websocket"
|
||||
"tailscale.com/control/controlbase"
|
||||
|
@ -18,16 +21,20 @@ import (
|
|||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// AcceptHTTP upgrades the HTTP request given by w and r into a
|
||||
// Tailscale control protocol base transport connection.
|
||||
// AcceptHTTP upgrades the HTTP request given by w and r into a Tailscale
|
||||
// control protocol base transport connection.
|
||||
//
|
||||
// AcceptHTTP always writes an HTTP response to w. The caller must not
|
||||
// attempt their own response after calling AcceptHTTP.
|
||||
// AcceptHTTP always writes an HTTP response to w. The caller must not attempt
|
||||
// their own response after calling AcceptHTTP.
|
||||
//
|
||||
// extraHeader optionally specifies extra header(s) to send in the
|
||||
// 101 Switching Protocols Upgrade response. It must not include the "Upgrade"
|
||||
// or "Connection" headers; they will be replaced.
|
||||
func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, extraHeader http.Header) (*controlbase.Conn, error) {
|
||||
// earlyWrite optionally specifies a func to write to the noise connection
|
||||
// (encrypted). It receives the negotiated version and a writer to write to, if
|
||||
// desired.
|
||||
func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, earlyWrite func(protocolVersion int, w io.Writer) error) (*controlbase.Conn, error) {
|
||||
return acceptHTTP(ctx, w, r, private, earlyWrite)
|
||||
}
|
||||
|
||||
func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, earlyWrite func(protocolVersion int, w io.Writer) error) (_ *controlbase.Conn, retErr error) {
|
||||
next := r.Header.Get("Upgrade")
|
||||
if next == "" {
|
||||
http.Error(w, "missing next protocol", http.StatusBadRequest)
|
||||
|
@ -58,9 +65,6 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
|
|||
return nil, errors.New("can't hijack client connection")
|
||||
}
|
||||
|
||||
for k, vv := range extraHeader {
|
||||
w.Header()[k] = vv
|
||||
}
|
||||
w.Header().Set("Upgrade", upgradeHeaderValue)
|
||||
w.Header().Set("Connection", "upgrade")
|
||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||
|
@ -69,18 +73,41 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("hijacking client connection: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if retErr != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if err := brw.Flush(); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("flushing hijacked HTTP buffer: %w", err)
|
||||
}
|
||||
conn = netutil.NewDrainBufConn(conn, brw.Reader)
|
||||
|
||||
nc, err := controlbase.Server(ctx, conn, private, init)
|
||||
cwc := newWriteCorkingConn(conn)
|
||||
|
||||
nc, err := controlbase.Server(ctx, cwc, private, init)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("noise handshake failed: %w", err)
|
||||
}
|
||||
|
||||
if earlyWrite != nil {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := conn.SetDeadline(deadline); err != nil {
|
||||
return nil, fmt.Errorf("setting conn deadline: %w", err)
|
||||
}
|
||||
defer conn.SetDeadline(time.Time{})
|
||||
}
|
||||
if err := earlyWrite(nc.ProtocolVersion(), nc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := cwc.uncork(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nc, nil
|
||||
}
|
||||
|
||||
|
@ -128,3 +155,61 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request
|
|||
|
||||
return nc, nil
|
||||
}
|
||||
|
||||
// corkConn is a net.Conn wrapper that initially buffers all writes until uncork
|
||||
// is called. If the conn is corked and a Read occurs, the Read will flush any
|
||||
// buffered (corked) write.
|
||||
//
|
||||
// Until uncorked, Read/Write/uncork may be not called concurrently.
|
||||
//
|
||||
// Deadlines still work, but a corked write ignores deadlines until a Read or
|
||||
// uncork goes to do that Write.
|
||||
//
|
||||
// Use newWriteCorkingConn to create one.
|
||||
type corkConn struct {
|
||||
net.Conn
|
||||
corked bool
|
||||
buf []byte // corked data
|
||||
}
|
||||
|
||||
func newWriteCorkingConn(c net.Conn) *corkConn {
|
||||
return &corkConn{Conn: c, corked: true}
|
||||
}
|
||||
|
||||
func (c *corkConn) Write(b []byte) (int, error) {
|
||||
if c.corked {
|
||||
c.buf = append(c.buf, b...)
|
||||
return len(b), nil
|
||||
}
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (c *corkConn) Read(b []byte) (int, error) {
|
||||
if c.corked {
|
||||
if err := c.flush(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
// uncork flushes any buffered data and uncorks the connection so future Writes
|
||||
// don't buffer. It may not be called concurrently with reads or writes and
|
||||
// may only be called once.
|
||||
func (c *corkConn) uncork() error {
|
||||
if !c.corked {
|
||||
panic("usage error; uncork called twice") // worth panicking to catch misuse
|
||||
}
|
||||
err := c.flush()
|
||||
c.corked = false
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *corkConn) flush() error {
|
||||
if len(c.buf) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := c.Conn.Write(c.buf)
|
||||
c.buf = nil
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue