From dd46cd5f36ae31dfb4f31dcb4dec65921510db17 Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Wed, 19 Aug 2020 18:35:54 +0300 Subject: [PATCH] * copy dhcpv4/nclient4 package with minor enhancement The original package can be built only on Linux. --- dhcpd/check_other_dhcp.go | 2 +- dhcpd/nclient4/client.go | 584 ++++++++++++++++++++++++++++++++++ dhcpd/nclient4/client_test.go | 333 +++++++++++++++++++ dhcpd/nclient4/conn_unix.go | 144 +++++++++ dhcpd/nclient4/ipv4.go | 376 ++++++++++++++++++++++ go.mod | 22 +- go.sum | 23 +- 7 files changed, 1471 insertions(+), 13 deletions(-) create mode 100644 dhcpd/nclient4/client.go create mode 100644 dhcpd/nclient4/client_test.go create mode 100644 dhcpd/nclient4/conn_unix.go create mode 100644 dhcpd/nclient4/ipv4.go diff --git a/dhcpd/check_other_dhcp.go b/dhcpd/check_other_dhcp.go index d8c8aa33..ebcfd363 100644 --- a/dhcpd/check_other_dhcp.go +++ b/dhcpd/check_other_dhcp.go @@ -10,9 +10,9 @@ import ( "runtime" "time" + "github.com/AdguardTeam/AdGuardHome/dhcpd/nclient4" "github.com/AdguardTeam/golibs/log" "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/insomniacslk/dhcp/dhcpv4/nclient4" "github.com/insomniacslk/dhcp/dhcpv6" "github.com/insomniacslk/dhcp/dhcpv6/nclient6" "github.com/insomniacslk/dhcp/iana" diff --git a/dhcpd/nclient4/client.go b/dhcpd/nclient4/client.go new file mode 100644 index 00000000..f2eabb5d --- /dev/null +++ b/dhcpd/nclient4/client.go @@ -0,0 +1,584 @@ +// Copyright 2018 the u-root Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.12 + +// Package nclient4 is a small, minimum-functionality client for DHCPv4. +// +// It only supports the 4-way DHCPv4 Discover-Offer-Request-Ack handshake as +// well as the Request-Ack renewal process. +// Originally from here: github.com/insomniacslk/dhcp/dhcpv4/nclient4 +// with the difference that this package can be built on UNIX (not just Linux), +// because github.com/mdlayher/raw package supports it. +package nclient4 + +import ( + "bytes" + "context" + "errors" + "fmt" + "log" + "net" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/insomniacslk/dhcp/dhcpv4" +) + +const ( + defaultBufferCap = 5 + + // DefaultTimeout is the default value for read-timeout if option WithTimeout is not set + DefaultTimeout = 5 * time.Second + + // DefaultRetries is amount of retries will be done if no answer was received within read-timeout amount of time + DefaultRetries = 3 + + // MaxMessageSize is the value to be used for DHCP option "MaxMessageSize". + MaxMessageSize = 1500 + + // ClientPort is the port that DHCP clients listen on. + ClientPort = 68 + + // ServerPort is the port that DHCP servers and relay agents listen on. + ServerPort = 67 +) + +var ( + // DefaultServers is the address of all link-local DHCP servers and + // relay agents. + DefaultServers = &net.UDPAddr{ + IP: net.IPv4bcast, + Port: ServerPort, + } +) + +var ( + // ErrNoResponse is returned when no response packet is received. + ErrNoResponse = errors.New("no matching response packet received") + + // ErrNoConn is returned when NewWithConn is called with nil-value as conn. + ErrNoConn = errors.New("conn is nil") + + // ErrNoIfaceHWAddr is returned when NewWithConn is called with nil-value as ifaceHWAddr + ErrNoIfaceHWAddr = errors.New("ifaceHWAddr is nil") +) + +// pendingCh is a channel associated with a pending TransactionID. +type pendingCh struct { + // SendAndRead closes done to indicate that it wishes for no more + // messages for this particular XID. + done <-chan struct{} + + // ch is used by the receive loop to distribute DHCP messages. + ch chan<- *dhcpv4.DHCPv4 +} + +// Logger is a handler which will be used to output logging messages +type Logger interface { + // PrintMessage print _all_ DHCP messages + PrintMessage(prefix string, message *dhcpv4.DHCPv4) + + // Printf is use to print the rest debugging information + Printf(format string, v ...interface{}) +} + +// EmptyLogger prints nothing +type EmptyLogger struct{} + +// Printf is just a dummy function that does nothing +func (e EmptyLogger) Printf(format string, v ...interface{}) {} + +// PrintMessage is just a dummy function that does nothing +func (e EmptyLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {} + +// Printfer is used for actual output of the logger. For example *log.Logger is a Printfer. +type Printfer interface { + // Printf is the function for logging output. Arguments are handled in the manner of fmt.Printf. + Printf(format string, v ...interface{}) +} + +// ShortSummaryLogger is a wrapper for Printfer to implement interface Logger. +// DHCP messages are printed in the short format. +type ShortSummaryLogger struct { + // Printfer is used for actual output of the logger + Printfer +} + +// Printf prints a log message as-is via predefined Printfer +func (s ShortSummaryLogger) Printf(format string, v ...interface{}) { + s.Printfer.Printf(format, v...) +} + +// PrintMessage prints a DHCP message in the short format via predefined Printfer +func (s ShortSummaryLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) { + s.Printf("%s: %s", prefix, message) +} + +// DebugLogger is a wrapper for Printfer to implement interface Logger. +// DHCP messages are printed in the long format. +type DebugLogger struct { + // Printfer is used for actual output of the logger + Printfer +} + +// Printf prints a log message as-is via predefined Printfer +func (d DebugLogger) Printf(format string, v ...interface{}) { + d.Printfer.Printf(format, v...) +} + +// PrintMessage prints a DHCP message in the long format via predefined Printfer +func (d DebugLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) { + d.Printf("%s: %s", prefix, message.Summary()) +} + +// Client is an IPv4 DHCP client. +type Client struct { + ifaceHWAddr net.HardwareAddr + conn net.PacketConn + timeout time.Duration + retry int + logger Logger + + // bufferCap is the channel capacity for each TransactionID. + bufferCap int + + // serverAddr is the UDP address to send all packets to. + // + // This may be an actual broadcast address, or a unicast address. + serverAddr *net.UDPAddr + + // closed is an atomic bool set to 1 when done is closed. + closed uint32 + + // done is closed to unblock the receive loop. + done chan struct{} + + // wg protects any spawned goroutines, namely the receiveLoop. + wg sync.WaitGroup + + pendingMu sync.Mutex + // pending stores the distribution channels for each pending + // TransactionID. receiveLoop uses this map to determine which channel + // to send a new DHCP message to. + pending map[dhcpv4.TransactionID]*pendingCh +} + +// New returns a client usable with an unconfigured interface. +func New(iface string, opts ...ClientOpt) (*Client, error) { + return new(iface, nil, nil, opts...) +} + +// NewWithConn creates a new DHCP client that sends and receives packets on the +// given interface. +func NewWithConn(conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) { + return new(``, conn, ifaceHWAddr, opts...) +} + +func new(iface string, conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) { + c := &Client{ + ifaceHWAddr: ifaceHWAddr, + timeout: DefaultTimeout, + retry: DefaultRetries, + serverAddr: DefaultServers, + bufferCap: defaultBufferCap, + conn: conn, + logger: EmptyLogger{}, + + done: make(chan struct{}), + pending: make(map[dhcpv4.TransactionID]*pendingCh), + } + + for _, opt := range opts { + err := opt(c) + if err != nil { + return nil, fmt.Errorf("unable to apply option: %w", err) + } + } + + if c.ifaceHWAddr == nil { + if iface == `` { + return nil, ErrNoIfaceHWAddr + } + + i, err := net.InterfaceByName(iface) + if err != nil { + return nil, fmt.Errorf("unable to get interface information: %w", err) + } + + c.ifaceHWAddr = i.HardwareAddr + } + + if c.conn == nil { + var err error + if iface == `` { + return nil, ErrNoConn + } + c.conn, err = NewRawUDPConn(iface, ClientPort) // broadcast + if err != nil { + return nil, fmt.Errorf("unable to open a broadcasting socket: %w", err) + } + } + c.wg.Add(1) + go c.receiveLoop() + return c, nil +} + +// Close closes the underlying connection. +func (c *Client) Close() error { + // Make sure not to close done twice. + if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + return nil + } + + err := c.conn.Close() + + // Closing c.done sets off a chain reaction: + // + // Any SendAndRead unblocks trying to receive more messages, which + // means rem() gets called. + // + // rem() should be unblocking receiveLoop if it is blocked. + // + // receiveLoop should then exit gracefully. + close(c.done) + + // Wait for receiveLoop to stop. + c.wg.Wait() + + return err +} + +func (c *Client) isClosed() bool { + return atomic.LoadUint32(&c.closed) != 0 +} + +func (c *Client) receiveLoop() { + defer c.wg.Done() + for { + // TODO: Clients can send a "max packet size" option in their + // packets, IIRC. Choose a reasonable size and set it. + b := make([]byte, MaxMessageSize) + n, _, err := c.conn.ReadFrom(b) + if err != nil { + if !c.isClosed() { + c.logger.Printf("error reading from UDP connection: %v", err) + } + return + } + + msg, err := dhcpv4.FromBytes(b[:n]) + if err != nil { + // Not a valid DHCP packet; keep listening. + continue + } + + if msg.OpCode != dhcpv4.OpcodeBootReply { + // Not a response message. + continue + } + + // This is a somewhat non-standard check, by the looks + // of RFC 2131. It should work as long as the DHCP + // server is spec-compliant for the HWAddr field. + if c.ifaceHWAddr != nil && !bytes.Equal(c.ifaceHWAddr, msg.ClientHWAddr) { + // Not for us. + continue + } + + c.pendingMu.Lock() + p, ok := c.pending[msg.TransactionID] + if ok { + select { + case <-p.done: + close(p.ch) + delete(c.pending, msg.TransactionID) + + // This send may block. + case p.ch <- msg: + } + } + c.pendingMu.Unlock() + } +} + +// ClientOpt is a function that configures the Client. +type ClientOpt func(c *Client) error + +// WithTimeout configures the retransmission timeout. +// +// Default is 5 seconds. +func WithTimeout(d time.Duration) ClientOpt { + return func(c *Client) (err error) { + c.timeout = d + return + } +} + +// WithSummaryLogger logs one-line DHCPv4 message summaries when sent & received. +func WithSummaryLogger() ClientOpt { + return func(c *Client) (err error) { + c.logger = ShortSummaryLogger{ + Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags), + } + return + } +} + +// WithDebugLogger logs multi-line full DHCPv4 messages when sent & received. +func WithDebugLogger() ClientOpt { + return func(c *Client) (err error) { + c.logger = DebugLogger{ + Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags), + } + return + } +} + +// WithLogger set the logger (see interface Logger). +func WithLogger(newLogger Logger) ClientOpt { + return func(c *Client) (err error) { + c.logger = newLogger + return + } +} + +// WithUnicast forces client to send messages as unicast frames. +// By default client sends messages as broadcast frames even if server address is defined. +// +// srcAddr is both: +// * The source address of outgoing frames. +// * The address to be listened for incoming frames. +func WithUnicast(srcAddr *net.UDPAddr) ClientOpt { + return func(c *Client) (err error) { + if srcAddr == nil { + srcAddr = &net.UDPAddr{Port: ServerPort} + } + c.conn, err = net.ListenUDP("udp4", srcAddr) + if err != nil { + err = fmt.Errorf("unable to start listening UDP port: %w", err) + } + return + } +} + +// WithHWAddr tells to the Client to receive messages destinated to selected +// hardware address +func WithHWAddr(hwAddr net.HardwareAddr) ClientOpt { + return func(c *Client) (err error) { + c.ifaceHWAddr = hwAddr + return + } +} + +func withBufferCap(n int) ClientOpt { + return func(c *Client) (err error) { + c.bufferCap = n + return + } +} + +// WithRetry configures the number of retransmissions to attempt. +// +// Default is 3. +func WithRetry(r int) ClientOpt { + return func(c *Client) (err error) { + c.retry = r + return + } +} + +// WithServerAddr configures the address to send messages to. +func WithServerAddr(n *net.UDPAddr) ClientOpt { + return func(c *Client) (err error) { + c.serverAddr = n + return + } +} + +// Matcher matches DHCP packets. +type Matcher func(*dhcpv4.DHCPv4) bool + +// IsMessageType returns a matcher that checks for the message type. +// +// If t is MessageTypeNone, all packets are matched. +func IsMessageType(t dhcpv4.MessageType) Matcher { + return func(p *dhcpv4.DHCPv4) bool { + return p.MessageType() == t || t == dhcpv4.MessageTypeNone + } +} + +// DiscoverOffer sends a DHCPDiscover message and returns the first valid offer +// received. +func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer *dhcpv4.DHCPv4, err error) { + // RFC 2131, Section 4.4.1, Table 5 details what a DISCOVER packet should + // contain. + discover, err := dhcpv4.NewDiscovery(c.ifaceHWAddr, dhcpv4.PrependModifiers(modifiers, + dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...) + if err != nil { + err = fmt.Errorf("unable to create a discovery request: %w", err) + return + } + + offer, err = c.SendAndRead(ctx, c.serverAddr, discover, IsMessageType(dhcpv4.MessageTypeOffer)) + if err != nil { + err = fmt.Errorf("got an error while the discovery request: %w", err) + return + } + + return +} + +// Request completes the 4-way Discover-Offer-Request-Ack handshake. +// +// Note that modifiers will be applied *both* to Discover and Request packets. +func (c *Client) Request(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer, ack *dhcpv4.DHCPv4, err error) { + offer, err = c.DiscoverOffer(ctx, modifiers...) + if err != nil { + err = fmt.Errorf("unable to receive an offer: %w", err) + return + } + + // TODO(chrisko): should this be unicast to the server? + request, err := dhcpv4.NewRequestFromOffer(offer, dhcpv4.PrependModifiers(modifiers, + dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...) + if err != nil { + err = fmt.Errorf("unable to create a request: %w", err) + return + } + + ack, err = c.SendAndRead(ctx, c.serverAddr, request, nil) + if err != nil { + err = fmt.Errorf("got an error while processing the request: %w", err) + return + } + + return +} + +// ErrTransactionIDInUse is returned if there were an attempt to send a message +// with the same TransactionID as we are already waiting an answer for. +type ErrTransactionIDInUse struct { + // TransactionID is the transaction ID of the message which the error is related to. + TransactionID dhcpv4.TransactionID +} + +// Error is just the method to comply interface "error". +func (err *ErrTransactionIDInUse) Error() string { + return fmt.Sprintf("transaction ID %s already in use", err.TransactionID) +} + +// send sends p to destination and returns a response channel. +// +// Responses will be matched by transaction ID and ClientHWAddr. +// +// The returned lambda function must be called after all desired responses have +// been received in order to return the Transaction ID to the usable pool. +func (c *Client) send(dest *net.UDPAddr, msg *dhcpv4.DHCPv4) (resp <-chan *dhcpv4.DHCPv4, cancel func(), err error) { + c.pendingMu.Lock() + if _, ok := c.pending[msg.TransactionID]; ok { + c.pendingMu.Unlock() + return nil, nil, &ErrTransactionIDInUse{msg.TransactionID} + } + + ch := make(chan *dhcpv4.DHCPv4, c.bufferCap) + done := make(chan struct{}) + c.pending[msg.TransactionID] = &pendingCh{done: done, ch: ch} + c.pendingMu.Unlock() + + cancel = func() { + // Why can't we just close ch here? + // + // Because receiveLoop may potentially be blocked trying to + // send on ch. We gotta unblock it first, and then we can take + // the lock and remove the XID from the pending transaction + // map. + close(done) + + c.pendingMu.Lock() + if p, ok := c.pending[msg.TransactionID]; ok { + close(p.ch) + delete(c.pending, msg.TransactionID) + } + c.pendingMu.Unlock() + } + + if _, err := c.conn.WriteTo(msg.ToBytes(), dest); err != nil { + cancel() + return nil, nil, fmt.Errorf("error writing packet to connection: %w", err) + } + return ch, cancel, nil +} + +// This error should never be visible to users. +// It is used only to increase the timeout in retryFn. +var errDeadlineExceeded = errors.New("INTERNAL ERROR: deadline exceeded") + +// SendAndRead sends a packet p to a destination dest and waits for the first +// response matching `match` as well as its Transaction ID and ClientHWAddr. +// +// If match is nil, the first packet matching the Transaction ID and +// ClientHWAddr is returned. +func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, p *dhcpv4.DHCPv4, match Matcher) (*dhcpv4.DHCPv4, error) { + var response *dhcpv4.DHCPv4 + err := c.retryFn(func(timeout time.Duration) error { + ch, rem, err := c.send(dest, p) + if err != nil { + return err + } + c.logger.PrintMessage("sent message", p) + defer rem() + + for { + select { + case <-c.done: + return ErrNoResponse + + case <-time.After(timeout): + return errDeadlineExceeded + + case <-ctx.Done(): + return ctx.Err() + + case packet := <-ch: + if match == nil || match(packet) { + c.logger.PrintMessage("received message", packet) + response = packet + return nil + } + } + } + }) + if err == errDeadlineExceeded { + return nil, ErrNoResponse + } + if err != nil { + return nil, err + } + return response, nil +} + +func (c *Client) retryFn(fn func(timeout time.Duration) error) error { + timeout := c.timeout + + // Each retry takes the amount of timeout at worst. + for i := 0; i < c.retry || c.retry < 0; i++ { // TODO: why is this called "retry" if this is "tries" ("retries"+1)? + switch err := fn(timeout); err { + case nil: + // Got it! + return nil + + case errDeadlineExceeded: + // Double timeout, then retry. + timeout *= 2 + + default: + return err + } + } + + return errDeadlineExceeded +} diff --git a/dhcpd/nclient4/client_test.go b/dhcpd/nclient4/client_test.go new file mode 100644 index 00000000..ddad1144 --- /dev/null +++ b/dhcpd/nclient4/client_test.go @@ -0,0 +1,333 @@ +// Copyright 2018 the u-root Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build linux +// github.com/hugelgupf/socketpair is Linux-only +// +build go1.12 + +package nclient4 + +import ( + "bytes" + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/hugelgupf/socketpair" + "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/insomniacslk/dhcp/dhcpv4/server4" +) + +type handler struct { + mu sync.Mutex + received []*dhcpv4.DHCPv4 + + // Each received packet can have more than one response (in theory, + // from different servers sending different Advertise, for example). + responses [][]*dhcpv4.DHCPv4 +} + +func (h *handler) handle(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) { + h.mu.Lock() + defer h.mu.Unlock() + + h.received = append(h.received, m) + + if len(h.responses) > 0 { + for _, resp := range h.responses[0] { + _, _ = conn.WriteTo(resp.ToBytes(), peer) + } + h.responses = h.responses[1:] + } +} + +func serveAndClient(ctx context.Context, responses [][]*dhcpv4.DHCPv4, opts ...ClientOpt) (*Client, net.PacketConn) { + // Fake PacketConn connection. + clientRawConn, serverRawConn, err := socketpair.PacketSocketPair() + if err != nil { + panic(err) + } + + clientConn := NewBroadcastUDPConn(clientRawConn, &net.UDPAddr{Port: ClientPort}) + serverConn := NewBroadcastUDPConn(serverRawConn, &net.UDPAddr{Port: ServerPort}) + + o := []ClientOpt{WithRetry(1), WithTimeout(2 * time.Second)} + o = append(o, opts...) + mc, err := NewWithConn(clientConn, net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, o...) + if err != nil { + panic(err) + } + + h := &handler{responses: responses} + s, err := server4.NewServer("", nil, h.handle, server4.WithConn(serverConn)) + if err != nil { + panic(err) + } + go func() { + _ = s.Serve() + }() + + return mc, serverConn +} + +func ComparePacket(got *dhcpv4.DHCPv4, want *dhcpv4.DHCPv4) error { + if got == nil && got == want { + return nil + } + if (want == nil || got == nil) && (got != want) { + return fmt.Errorf("packet got %v, want %v", got, want) + } + if !bytes.Equal(got.ToBytes(), want.ToBytes()) { + return fmt.Errorf("packet got %v, want %v", got, want) + } + return nil +} + +func pktsExpected(got []*dhcpv4.DHCPv4, want []*dhcpv4.DHCPv4) error { + if len(got) != len(want) { + return fmt.Errorf("got %d packets, want %d packets", len(got), len(want)) + } + + for i := range got { + if err := ComparePacket(got[i], want[i]); err != nil { + return err + } + } + return nil +} + +func newPacketWeirdHWAddr(op dhcpv4.OpcodeType, xid dhcpv4.TransactionID) *dhcpv4.DHCPv4 { + p, err := dhcpv4.New() + if err != nil { + panic(fmt.Sprintf("newpacket: %v", err)) + } + p.OpCode = op + p.TransactionID = xid + p.ClientHWAddr = net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 1, 2, 3, 4, 5, 6} + return p +} + +func newPacket(op dhcpv4.OpcodeType, xid dhcpv4.TransactionID) *dhcpv4.DHCPv4 { + p, err := dhcpv4.New() + if err != nil { + panic(fmt.Sprintf("newpacket: %v", err)) + } + p.OpCode = op + p.TransactionID = xid + p.ClientHWAddr = net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf} + return p +} + +func TestSendAndRead(t *testing.T) { + for _, tt := range []struct { + desc string + send *dhcpv4.DHCPv4 + server []*dhcpv4.DHCPv4 + + // If want is nil, we assume server[0] contains what is wanted. + want *dhcpv4.DHCPv4 + wantErr error + }{ + { + desc: "two response packets", + send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), + server: []*dhcpv4.DHCPv4{ + newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), + newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), + newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), + newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), + newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), + }, + want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), + }, + { + desc: "one response packet", + send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), + server: []*dhcpv4.DHCPv4{ + newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), + }, + want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), + }, + { + desc: "one response packet, one invalid XID, one invalid opcode, one invalid hwaddr", + send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), + server: []*dhcpv4.DHCPv4{ + newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x77, 0x33, 0x33, 0x33}), + newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), + newPacketWeirdHWAddr(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), + newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), + }, + want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), + }, + { + desc: "discard wrong XID", + send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), + server: []*dhcpv4.DHCPv4{ + newPacket(dhcpv4.OpcodeBootReply, [4]byte{0, 0, 0, 0}), + }, + want: nil, // Explicitly empty. + wantErr: ErrNoResponse, + }, + { + desc: "no response, timeout", + send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), + wantErr: ErrNoResponse, + }, + } { + t.Run(tt.desc, func(t *testing.T) { + // Both server and client only get 2 seconds. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{tt.server}, + // Use an unbuffered channel to make sure we + // have no deadlocks. + withBufferCap(0)) + defer mc.Close() + + rcvd, err := mc.SendAndRead(context.Background(), DefaultServers, tt.send, nil) + if err != tt.wantErr { + t.Error(err) + } + + if err := ComparePacket(rcvd, tt.want); err != nil { + t.Errorf("got unexpected packets: %v", err) + } + }) + } +} + +func TestParallelSendAndRead(t *testing.T) { + pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}) + + // Both the server and client only get 2 seconds. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{}, + WithTimeout(10*time.Second), + // Use an unbuffered channel to make sure nothing blocks. + withBufferCap(0)) + defer mc.Close() + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse { + t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + + time.Sleep(4 * time.Second) + + if err := mc.Close(); err != nil { + t.Errorf("closing failed: %v", err) + } + }() + + wg.Wait() +} + +func TestReuseXID(t *testing.T) { + pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}) + + // Both the server and client only get 2 seconds. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{}) + defer mc.Close() + + if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse { + t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse) + } + + if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse { + t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse) + } +} + +func TestSimpleSendAndReadDiscardGarbage(t *testing.T) { + pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}) + + responses := newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}) + + // Both the server and client only get 2 seconds. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + mc, udpConn := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{{responses}}) + defer mc.Close() + + // Too short for valid DHCPv4 packet. + _, _ = udpConn.WriteTo([]byte{0x01}, nil) + _, _ = udpConn.WriteTo([]byte{0x01, 0x2}, nil) + + rcvd, err := mc.SendAndRead(ctx, DefaultServers, pkt, nil) + if err != nil { + t.Errorf("SendAndRead(%v) = %v, want nil", pkt, err) + } + + if err := ComparePacket(rcvd, responses); err != nil { + t.Errorf("got unexpected packets: %v", err) + } +} + +func TestMultipleSendAndRead(t *testing.T) { + for _, tt := range []struct { + desc string + send []*dhcpv4.DHCPv4 + server [][]*dhcpv4.DHCPv4 + wantErr []error + }{ + { + desc: "two requests, two responses", + send: []*dhcpv4.DHCPv4{ + newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), + newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x44, 0x44, 0x44, 0x44}), + }, + server: [][]*dhcpv4.DHCPv4{ + []*dhcpv4.DHCPv4{ // Response for first packet. + newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), + }, + []*dhcpv4.DHCPv4{ // Response for second packet. + newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x44, 0x44, 0x44, 0x44}), + }, + }, + wantErr: []error{ + nil, + nil, + }, + }, + } { + // Both server and client only get 2 seconds. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + mc, _ := serveAndClient(ctx, tt.server) + defer mc.Close() + + for i, send := range tt.send { + ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + rcvd, err := mc.SendAndRead(ctx, DefaultServers, send, nil) + + if wantErr := tt.wantErr[i]; err != wantErr { + t.Errorf("SendAndReadOne(%v): got %v, want %v", send, err, wantErr) + } + if err := pktsExpected([]*dhcpv4.DHCPv4{rcvd}, tt.server[i]); err != nil { + t.Errorf("got unexpected packets: %v", err) + } + } + } +} diff --git a/dhcpd/nclient4/conn_unix.go b/dhcpd/nclient4/conn_unix.go new file mode 100644 index 00000000..51ec98cb --- /dev/null +++ b/dhcpd/nclient4/conn_unix.go @@ -0,0 +1,144 @@ +// Copyright 2018 the u-root Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd solaris +// +build go1.12 + +package nclient4 + +import ( + "errors" + "io" + "net" + + "github.com/mdlayher/ethernet" + "github.com/mdlayher/raw" + "github.com/u-root/u-root/pkg/uio" +) + +var ( + // BroadcastMac is the broadcast MAC address. + // + // Any UDP packet sent to this address is broadcast on the subnet. + BroadcastMac = net.HardwareAddr([]byte{255, 255, 255, 255, 255, 255}) +) + +var ( + // ErrUDPAddrIsRequired is an error used when a passed argument is not of type "*net.UDPAddr". + ErrUDPAddrIsRequired = errors.New("must supply UDPAddr") +) + +// NewRawUDPConn returns a UDP connection bound to the interface and port +// given based on a raw packet socket. All packets are broadcasted. +// +// The interface can be completely unconfigured. +func NewRawUDPConn(iface string, port int) (net.PacketConn, error) { + ifc, err := net.InterfaceByName(iface) + if err != nil { + return nil, err + } + rawConn, err := raw.ListenPacket(ifc, uint16(ethernet.EtherTypeIPv4), &raw.Config{LinuxSockDGRAM: true}) + if err != nil { + return nil, err + } + return NewBroadcastUDPConn(rawConn, &net.UDPAddr{Port: port}), nil +} + +// BroadcastRawUDPConn uses a raw socket to send UDP packets to the broadcast +// MAC address. +type BroadcastRawUDPConn struct { + // PacketConn is a raw DGRAM socket. + net.PacketConn + + // boundAddr is the address this RawUDPConn is "bound" to. + // + // Calls to ReadFrom will only return packets destined to this address. + boundAddr *net.UDPAddr +} + +// NewBroadcastUDPConn returns a PacketConn that marshals and unmarshals UDP +// packets, sending them to the broadcast MAC at on rawPacketConn. +// +// Calls to ReadFrom will only return packets destined to boundAddr. +func NewBroadcastUDPConn(rawPacketConn net.PacketConn, boundAddr *net.UDPAddr) net.PacketConn { + return &BroadcastRawUDPConn{ + PacketConn: rawPacketConn, + boundAddr: boundAddr, + } +} + +func udpMatch(addr *net.UDPAddr, bound *net.UDPAddr) bool { + if bound == nil { + return true + } + if bound.IP != nil && !bound.IP.Equal(addr.IP) { + return false + } + return bound.Port == addr.Port +} + +// ReadFrom implements net.PacketConn.ReadFrom. +// +// ReadFrom reads raw IP packets and will try to match them against +// upc.boundAddr. Any matching packets are returned via the given buffer. +func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { + ipHdrMaxLen := IPv4MaximumHeaderSize + udpHdrLen := UDPMinimumSize + + for { + pkt := make([]byte, ipHdrMaxLen+udpHdrLen+len(b)) + n, _, err := upc.PacketConn.ReadFrom(pkt) + if err != nil { + return 0, nil, err + } + if n == 0 { + return 0, nil, io.EOF + } + pkt = pkt[:n] + buf := uio.NewBigEndianBuffer(pkt) + + // To read the header length, access data directly. + ipHdr := IPv4(buf.Data()) + ipHdr = IPv4(buf.Consume(int(ipHdr.HeaderLength()))) + + if ipHdr.TransportProtocol() != UDPProtocolNumber { + continue + } + udpHdr := UDP(buf.Consume(udpHdrLen)) + + addr := &net.UDPAddr{ + IP: ipHdr.DestinationAddress(), + Port: int(udpHdr.DestinationPort()), + } + if !udpMatch(addr, upc.boundAddr) { + continue + } + srcAddr := &net.UDPAddr{ + IP: ipHdr.SourceAddress(), + Port: int(udpHdr.SourcePort()), + } + // Extra padding after end of IP packet should be ignored, + // if not dhcp option parsing will fail. + dhcpLen := int(ipHdr.PayloadLength()) - udpHdrLen + return copy(b, buf.Consume(dhcpLen)), srcAddr, nil + } +} + +// WriteTo implements net.PacketConn.WriteTo and broadcasts all packets at the +// raw socket level. +// +// WriteTo wraps the given packet in the appropriate UDP and IP header before +// sending it on the packet conn. +func (upc *BroadcastRawUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return 0, ErrUDPAddrIsRequired + } + + // Using the boundAddr is not quite right here, but it works. + packet := udp4pkt(b, udpAddr, upc.boundAddr) + + // Broadcasting is not always right, but hell, what the ARP do I know. + return upc.PacketConn.WriteTo(packet, &raw.Addr{HardwareAddr: BroadcastMac}) +} diff --git a/dhcpd/nclient4/ipv4.go b/dhcpd/nclient4/ipv4.go new file mode 100644 index 00000000..5733eb46 --- /dev/null +++ b/dhcpd/nclient4/ipv4.go @@ -0,0 +1,376 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// This file contains code taken from gVisor. + +// +build go1.12 + +package nclient4 + +import ( + "encoding/binary" + "net" + + "github.com/u-root/u-root/pkg/uio" +) + +const ( + versIHL = 0 + tos = 1 + totalLen = 2 + id = 4 + flagsFO = 6 + ttl = 8 + protocol = 9 + checksum = 10 + srcAddr = 12 + dstAddr = 16 +) + +// TransportProtocolNumber is the number of a transport protocol. +type TransportProtocolNumber uint32 + +// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the +// fields of a packet that needs to be encoded. +type IPv4Fields struct { + // IHL is the "internet header length" field of an IPv4 packet. + IHL uint8 + + // TOS is the "type of service" field of an IPv4 packet. + TOS uint8 + + // TotalLength is the "total length" field of an IPv4 packet. + TotalLength uint16 + + // ID is the "identification" field of an IPv4 packet. + ID uint16 + + // Flags is the "flags" field of an IPv4 packet. + Flags uint8 + + // FragmentOffset is the "fragment offset" field of an IPv4 packet. + FragmentOffset uint16 + + // TTL is the "time to live" field of an IPv4 packet. + TTL uint8 + + // Protocol is the "protocol" field of an IPv4 packet. + Protocol uint8 + + // Checksum is the "checksum" field of an IPv4 packet. + Checksum uint16 + + // SrcAddr is the "source ip address" of an IPv4 packet. + SrcAddr net.IP + + // DstAddr is the "destination ip address" of an IPv4 packet. + DstAddr net.IP +} + +// IPv4 represents an ipv4 header stored in a byte array. +// Most of the methods of IPv4 access to the underlying slice without +// checking the boundaries and could panic because of 'index out of range'. +// Always call IsValid() to validate an instance of IPv4 before using other methods. +type IPv4 []byte + +const ( + // IPv4MinimumSize is the minimum size of a valid IPv4 packet. + IPv4MinimumSize = 20 + + // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given + // that there are only 4 bits to represents the header length in 32-bit + // units, the header cannot exceed 15*4 = 60 bytes. + IPv4MaximumHeaderSize = 60 + + // IPv4AddressSize is the size, in bytes, of an IPv4 address. + IPv4AddressSize = 4 + + // IPv4Version is the version of the ipv4 protocol. + IPv4Version = 4 +) + +var ( + // IPv4Broadcast is the broadcast address of the IPv4 protocol. + IPv4Broadcast = net.IP{0xff, 0xff, 0xff, 0xff} + + // IPv4Any is the non-routable IPv4 "any" meta address. + IPv4Any = net.IP{0, 0, 0, 0} +) + +// Flags that may be set in an IPv4 packet. +const ( + IPv4FlagMoreFragments = 1 << iota + IPv4FlagDontFragment +) + +// HeaderLength returns the value of the "header length" field of the ipv4 +// header. +func (b IPv4) HeaderLength() uint8 { + return (b[versIHL] & 0xf) * 4 +} + +// Protocol returns the value of the protocol field of the ipv4 header. +func (b IPv4) Protocol() uint8 { + return b[protocol] +} + +// SourceAddress returns the "source address" field of the ipv4 header. +func (b IPv4) SourceAddress() net.IP { + return net.IP(b[srcAddr : srcAddr+IPv4AddressSize]) +} + +// DestinationAddress returns the "destination address" field of the ipv4 +// header. +func (b IPv4) DestinationAddress() net.IP { + return net.IP(b[dstAddr : dstAddr+IPv4AddressSize]) +} + +// TransportProtocol implements Network.TransportProtocol. +func (b IPv4) TransportProtocol() TransportProtocolNumber { + return TransportProtocolNumber(b.Protocol()) +} + +// Payload implements Network.Payload. +func (b IPv4) Payload() []byte { + return b[b.HeaderLength():][:b.PayloadLength()] +} + +// PayloadLength returns the length of the payload portion of the ipv4 packet. +func (b IPv4) PayloadLength() uint16 { + return b.TotalLength() - uint16(b.HeaderLength()) +} + +// TotalLength returns the "total length" field of the ipv4 header. +func (b IPv4) TotalLength() uint16 { + return binary.BigEndian.Uint16(b[totalLen:]) +} + +// SetTotalLength sets the "total length" field of the ipv4 header. +func (b IPv4) SetTotalLength(totalLength uint16) { + binary.BigEndian.PutUint16(b[totalLen:], totalLength) +} + +// SetChecksum sets the checksum field of the ipv4 header. +func (b IPv4) SetChecksum(v uint16) { + binary.BigEndian.PutUint16(b[checksum:], v) +} + +// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the +// ipv4 header. +func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) { + v := (uint16(flags) << 13) | (offset >> 3) + binary.BigEndian.PutUint16(b[flagsFO:], v) +} + +// SetSourceAddress sets the "source address" field of the ipv4 header. +func (b IPv4) SetSourceAddress(addr net.IP) { + copy(b[srcAddr:srcAddr+IPv4AddressSize], addr.To4()) +} + +// SetDestinationAddress sets the "destination address" field of the ipv4 +// header. +func (b IPv4) SetDestinationAddress(addr net.IP) { + copy(b[dstAddr:dstAddr+IPv4AddressSize], addr.To4()) +} + +// CalculateChecksum calculates the checksum of the ipv4 header. +func (b IPv4) CalculateChecksum() uint16 { + return Checksum(b[:b.HeaderLength()], 0) +} + +// Encode encodes all the fields of the ipv4 header. +func (b IPv4) Encode(i *IPv4Fields) { + b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf) + b[tos] = i.TOS + b.SetTotalLength(i.TotalLength) + binary.BigEndian.PutUint16(b[id:], i.ID) + b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset) + b[ttl] = i.TTL + b[protocol] = i.Protocol + b.SetChecksum(i.Checksum) + copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr) + copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr) +} + +const ( + udpSrcPort = 0 + udpDstPort = 2 + udpLength = 4 + udpChecksum = 6 +) + +// UDPFields contains the fields of a UDP packet. It is used to describe the +// fields of a packet that needs to be encoded. +type UDPFields struct { + // SrcPort is the "source port" field of a UDP packet. + SrcPort uint16 + + // DstPort is the "destination port" field of a UDP packet. + DstPort uint16 + + // Length is the "length" field of a UDP packet. + Length uint16 + + // Checksum is the "checksum" field of a UDP packet. + Checksum uint16 +} + +// UDP represents a UDP header stored in a byte array. +type UDP []byte + +const ( + // UDPMinimumSize is the minimum size of a valid UDP packet. + UDPMinimumSize = 8 + + // UDPProtocolNumber is UDP's transport protocol number. + UDPProtocolNumber TransportProtocolNumber = 17 +) + +// SourcePort returns the "source port" field of the udp header. +func (b UDP) SourcePort() uint16 { + return binary.BigEndian.Uint16(b[udpSrcPort:]) +} + +// DestinationPort returns the "destination port" field of the udp header. +func (b UDP) DestinationPort() uint16 { + return binary.BigEndian.Uint16(b[udpDstPort:]) +} + +// Length returns the "length" field of the udp header. +func (b UDP) Length() uint16 { + return binary.BigEndian.Uint16(b[udpLength:]) +} + +// SetSourcePort sets the "source port" field of the udp header. +func (b UDP) SetSourcePort(port uint16) { + binary.BigEndian.PutUint16(b[udpSrcPort:], port) +} + +// SetDestinationPort sets the "destination port" field of the udp header. +func (b UDP) SetDestinationPort(port uint16) { + binary.BigEndian.PutUint16(b[udpDstPort:], port) +} + +// SetChecksum sets the "checksum" field of the udp header. +func (b UDP) SetChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[udpChecksum:], checksum) +} + +// Payload returns the data contained in the UDP datagram. +func (b UDP) Payload() []byte { + return b[UDPMinimumSize:] +} + +// Checksum returns the "checksum" field of the udp header. +func (b UDP) Checksum() uint16 { + return binary.BigEndian.Uint16(b[udpChecksum:]) +} + +// CalculateChecksum calculates the checksum of the udp packet, given the total +// length of the packet and the checksum of the network-layer pseudo-header +// (excluding the total length) and the checksum of the payload. +func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 { + // Add the length portion of the checksum to the pseudo-checksum. + tmp := make([]byte, 2) + binary.BigEndian.PutUint16(tmp, totalLen) + checksum := Checksum(tmp, partialChecksum) + + // Calculate the rest of the checksum. + return Checksum(b[:UDPMinimumSize], checksum) +} + +// Encode encodes all the fields of the udp header. +func (b UDP) Encode(u *UDPFields) { + binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort) + binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort) + binary.BigEndian.PutUint16(b[udpLength:], u.Length) + binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum) +} + +func calculateChecksum(buf []byte, initial uint32) uint16 { + v := initial + + l := len(buf) + if l&1 != 0 { + l-- + v += uint32(buf[l]) << 8 + } + + for i := 0; i < l; i += 2 { + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + } + + return ChecksumCombine(uint16(v), uint16(v>>16)) +} + +// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the +// given byte array. +// +// The initial checksum must have been computed on an even number of bytes. +func Checksum(buf []byte, initial uint16) uint16 { + return calculateChecksum(buf, uint32(initial)) +} + +// ChecksumCombine combines the two uint16 to form their checksum. This is done +// by adding them and the carry. +// +// Note that checksum a must have been computed on an even number of bytes. +func ChecksumCombine(a, b uint16) uint16 { + v := uint32(a) + uint32(b) + return uint16(v + v>>16) +} + +// PseudoHeaderChecksum calculates the pseudo-header checksum for the +// given destination protocol and network address, ignoring the length +// field. Pseudo-headers are needed by transport layers when calculating +// their own checksum. +func PseudoHeaderChecksum(protocol TransportProtocolNumber, srcAddr net.IP, dstAddr net.IP) uint16 { + xsum := Checksum([]byte(srcAddr), 0) + xsum = Checksum([]byte(dstAddr), xsum) + return Checksum([]byte{0, uint8(protocol)}, xsum) +} + +func udp4pkt(packet []byte, dest *net.UDPAddr, src *net.UDPAddr) []byte { + ipLen := IPv4MinimumSize + udpLen := UDPMinimumSize + + h := make([]byte, 0, ipLen+udpLen+len(packet)) + hdr := uio.NewBigEndianBuffer(h) + + ipv4fields := &IPv4Fields{ + IHL: IPv4MinimumSize, + TotalLength: uint16(ipLen + udpLen + len(packet)), + TTL: 64, // Per RFC 1700's recommendation for IP time to live + Protocol: uint8(UDPProtocolNumber), + SrcAddr: src.IP.To4(), + DstAddr: dest.IP.To4(), + } + ipv4hdr := IPv4(hdr.WriteN(ipLen)) + ipv4hdr.Encode(ipv4fields) + ipv4hdr.SetChecksum(^ipv4hdr.CalculateChecksum()) + + udphdr := UDP(hdr.WriteN(udpLen)) + udphdr.Encode(&UDPFields{ + SrcPort: uint16(src.Port), + DstPort: uint16(dest.Port), + Length: uint16(udpLen + len(packet)), + }) + + xsum := Checksum(packet, PseudoHeaderChecksum( + ipv4hdr.TransportProtocol(), ipv4fields.SrcAddr, ipv4fields.DstAddr)) + udphdr.SetChecksum(^udphdr.CalculateChecksum(xsum, udphdr.Length())) + + hdr.WriteBytes(packet) + return hdr.Data() +} diff --git a/go.mod b/go.mod index fc774a4f..7bae4404 100644 --- a/go.mod +++ b/go.mod @@ -7,23 +7,25 @@ require ( github.com/AdguardTeam/golibs v0.4.2 github.com/AdguardTeam/urlfilter v0.11.2 github.com/NYTimes/gziphandler v1.1.1 - github.com/fsnotify/fsnotify v1.4.7 + github.com/fsnotify/fsnotify v1.4.9 github.com/gobuffalo/packr v1.30.1 - github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 // indirect - github.com/insomniacslk/dhcp v0.0.0-20200420235442-ed3125c2efe7 + github.com/google/go-cmp v0.4.0 // indirect + github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 + github.com/insomniacslk/dhcp v0.0.0-20200621044212-d74cd86ad5b8 github.com/joomcode/errorx v1.0.1 github.com/kardianos/service v1.1.0 - github.com/mdlayher/ethernet v0.0.0-20190606142754-0394541c37b7 // indirect - github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065 // indirect + github.com/mdlayher/ethernet v0.0.0-20190606142754-0394541c37b7 + github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065 github.com/miekg/dns v1.1.29 github.com/pkg/errors v0.9.1 + github.com/sirupsen/logrus v1.6.0 // indirect github.com/sparrc/go-ping v0.0.0-20190613174326-4e5b6552494c github.com/stretchr/testify v1.5.1 - github.com/u-root/u-root v6.0.0+incompatible // indirect + github.com/u-root/u-root v6.0.0+incompatible go.etcd.io/bbolt v1.3.4 - golang.org/x/crypto v0.0.0-20200403201458-baeed622b8d8 - golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e - golang.org/x/sys v0.0.0-20200331124033-c3d80250170d + golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 + golang.org/x/net v0.0.0-20200625001655-4c5254603344 + golang.org/x/sys v0.0.0-20200519105757-fe76b779f299 gopkg.in/natefinch/lumberjack.v2 v2.0.0 - gopkg.in/yaml.v2 v2.2.8 + gopkg.in/yaml.v2 v2.3.0 ) diff --git a/go.sum b/go.sum index 2df1d656..d043c5ec 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI= github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= github.com/go-test/deep v1.0.5 h1:AKODKU3pDH1RzZzm6YZu77YWtEAq6uh1rLIAQlay2qc= @@ -52,14 +54,16 @@ github.com/gobuffalo/packr/v2 v2.5.1/go.mod h1:8f9c96ITobJlPzI44jj+4tHnEKNt0xXWS github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 h1:/jC7qQFrv8CrSJVmaolDVOxTfS9kc36uB6H40kdbQq8= github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714/go.mod h1:2Goc3h8EklBH5mspfHFxBnEoURQCGzQQH1ga9Myjvis= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/insomniacslk/dhcp v0.0.0-20200420235442-ed3125c2efe7 h1:iaCm+9nZdYb8XCSU2TfIb0qYTcAlIv2XzyKR2d2xZ38= -github.com/insomniacslk/dhcp v0.0.0-20200420235442-ed3125c2efe7/go.mod h1:CfMdguCK66I5DAUJgGKyNz8aB6vO5dZzkm9Xep6WGvw= +github.com/insomniacslk/dhcp v0.0.0-20200621044212-d74cd86ad5b8 h1:u+vle+5E78+cT/CSMD5/Y3NUpMgA83Yu2KhG+Zbco/k= +github.com/insomniacslk/dhcp v0.0.0-20200621044212-d74cd86ad5b8/go.mod h1:CfMdguCK66I5DAUJgGKyNz8aB6vO5dZzkm9Xep6WGvw= github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= @@ -73,6 +77,8 @@ github.com/karrick/godirwalk v1.10.12/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0L github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8= +github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= @@ -105,6 +111,8 @@ github.com/shirou/gopsutil v2.20.3+incompatible h1:0JVooMPsT7A7HqEYdydp/OfjSOYSj github.com/shirou/gopsutil v2.20.3+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= +github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sparrc/go-ping v0.0.0-20190613174326-4e5b6552494c h1:gqEdF4VwBu3lTKGHS9rXE9x1/pEaSwCXRLOZRF6qtlw= github.com/sparrc/go-ping v0.0.0-20190613174326-4e5b6552494c/go.mod h1:eMyUVp6f/5jnzM+3zahzl7q6UXLbgSc3MKg/+ow9QW0= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= @@ -135,6 +143,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200403201458-baeed622b8d8 h1:fpnn/HnJONpIu6hkXi1u/7rR0NzilgWr4T0JmWkEitk= golang.org/x/crypto v0.0.0-20200403201458-baeed622b8d8/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -145,6 +155,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e h1:3G+cUijn7XD+S4eJFddp53Pv7+slrESplyjG25HgL+k= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a h1:WXEvlFVvvGxCJLG6REjsT03iWnKLEWinaScsxF2Vm2o= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -157,10 +169,13 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190515120540-06a5c4944438/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606122018-79a91cf218c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200331124033-c3d80250170d h1:nc5K6ox/4lTFbMVSL9WRR81ixkcwXThoiF6yf+R9scA= golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299 h1:DYfZAGf2WMFjMxbgTjaC+2HC7NkNAQs+6Q8b9WEB/F4= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -169,6 +184,8 @@ golang.org/x/tools v0.0.0-20190624180213-70d37148ca0c/go.mod h1:/rFqwRUd4F7ZHNgw golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 h1:VvQyQJN0tSuecqgcIxMWnnfG5kSmgy9KZR9sW3W5QeA= golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= @@ -179,3 +196,5 @@ gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24 gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=