diff --git a/go.mod b/go.mod index c6eb56a2..d31b9dfc 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( github.com/stretchr/testify v1.6.1 github.com/ti-mo/netfilter v0.4.0 github.com/u-root/u-root v7.0.0+incompatible + github.com/willf/bitset v1.1.11 go.etcd.io/bbolt v1.3.5 golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 diff --git a/go.sum b/go.sum index c26a26cd..1d9bebfb 100644 --- a/go.sum +++ b/go.sum @@ -425,6 +425,8 @@ github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGr github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/willf/bitset v1.1.11 h1:N7Z7E9UvjW+sGsEl7k/SJrvY2reP1A07MrGuCjIOjRE= +github.com/willf/bitset v1.1.11/go.mod h1:83CECat5yLh5zVOf4P1ErAgKA5UDvKtgyUABdr3+MjI= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= diff --git a/internal/agherr/agherr.go b/internal/agherr/agherr.go index aedf2a8b..eb61206f 100644 --- a/internal/agherr/agherr.go +++ b/internal/agherr/agherr.go @@ -1,5 +1,4 @@ -// Package agherr contains the extended error type, and the function for -// wrapping several errors. +// Package agherr contains AdGuard Home's error handling helpers. package agherr import ( @@ -23,8 +22,10 @@ type manyError struct { } // Many wraps several errors and returns a single error. -func Many(message string, underlying ...error) error { - err := &manyError{ +// +// TODO(a.garipov): Add formatting to message. +func Many(message string, underlying ...error) (err error) { + err = &manyError{ message: message, underlying: underlying, } @@ -33,7 +34,7 @@ func Many(message string, underlying ...error) error { } // Error implements the error interface for *manyError. -func (e *manyError) Error() string { +func (e *manyError) Error() (msg string) { switch len(e.underlying) { case 0: return e.message @@ -58,7 +59,7 @@ func (e *manyError) Error() string { } // Unwrap implements the hidden errors.wrapper interface for *manyError. -func (e *manyError) Unwrap() error { +func (e *manyError) Unwrap() (err error) { if len(e.underlying) == 0 { return nil } @@ -71,3 +72,38 @@ func (e *manyError) Unwrap() error { type wrapper interface { Unwrap() error } + +// Annotate annotates the error with the message, unless the error is nil. This +// is a helper function to simplify code like this: +// +// func (f *foo) doStuff(s string) (err error) { +// defer func() { +// if err != nil { +// err = fmt.Errorf("bad foo string %q: %w", s, err) +// } +// }() +// +// // … +// } +// +// Instead, write: +// +// func (f *foo) doStuff(s string) (err error) { +// defer agherr.Annotate("bad foo string %q: %w", &err, s) +// +// // … +// } +// +// msg must contain the final ": %w" verb. +func Annotate(msg string, errPtr *error, args ...interface{}) { + if errPtr == nil { + return + } + + err := *errPtr + if err != nil { + args = append(args, err) + + *errPtr = fmt.Errorf(msg, args...) + } +} diff --git a/internal/agherr/agherr_test.go b/internal/agherr/agherr_test.go index 3ac5aeab..b9f3183c 100644 --- a/internal/agherr/agherr_test.go +++ b/internal/agherr/agherr_test.go @@ -6,30 +6,32 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestError_Error(t *testing.T) { testCases := []struct { + err error name string want string - err error }{{ + err: Many("a"), name: "simple", want: "a", - err: Many("a"), }, { + err: Many("a", errors.New("b")), name: "wrapping", want: "a: b", - err: Many("a", errors.New("b")), }, { + err: Many("a", errors.New("b"), errors.New("c"), errors.New("d")), name: "wrapping several", want: "a: b (hidden: c, d)", - err: Many("a", errors.New("b"), errors.New("c"), errors.New("d")), }, { + err: Many("a", Many("b", errors.New("c"), errors.New("d"))), name: "wrapping wrapper", want: "a: b: c (hidden: d)", - err: Many("a", Many("b", errors.New("c"), errors.New("d"))), }} + for _, tc := range testCases { assert.Equal(t, tc.want, tc.err.Error(), tc.name) } @@ -43,33 +45,78 @@ func TestError_Unwrap(t *testing.T) { errWrapped errNil ) + errs := []error{ errSimple: errors.New("a"), errWrapped: fmt.Errorf("err: %w", errors.New("nested")), errNil: nil, } + testCases := []struct { - name string want error wrapped error + name string }{{ - name: "simple", want: errs[errSimple], wrapped: Many("a", errs[errSimple]), + name: "simple", }, { - name: "nested", want: errs[errWrapped], wrapped: Many("b", errs[errWrapped]), + name: "nested", }, { - name: "nil passed", want: errs[errNil], wrapped: Many("c", errs[errNil]), + name: "nil passed", }, { - name: "nil not passed", want: nil, wrapped: Many("d"), + name: "nil not passed", }} + for _, tc := range testCases { assert.Equal(t, tc.want, errors.Unwrap(tc.wrapped), tc.name) } } + +func TestAnnotate(t *testing.T) { + const s = "1234" + const wantMsg = `bad string "1234": test` + + // Don't use const, because we can't take a pointer of a constant. + var errTest error = Error("test") + + t.Run("nil", func(t *testing.T) { + var errPtr *error + assert.NotPanics(t, func() { + Annotate("bad string %q: %w", errPtr, s) + }) + }) + + t.Run("non_nil", func(t *testing.T) { + errPtr := &errTest + assert.NotPanics(t, func() { + Annotate("bad string %q: %w", errPtr, s) + }) + + require.NotNil(t, errPtr) + + err := *errPtr + require.NotNil(t, err) + + assert.Equal(t, wantMsg, err.Error()) + }) + + t.Run("defer", func(t *testing.T) { + f := func() (err error) { + defer Annotate("bad string %q: %w", &errTest, s) + + return errTest + } + + err := f() + require.NotNil(t, err) + + assert.Equal(t, wantMsg, err.Error()) + }) +} diff --git a/internal/dhcpd/iprange.go b/internal/dhcpd/iprange.go new file mode 100644 index 00000000..50242255 --- /dev/null +++ b/internal/dhcpd/iprange.go @@ -0,0 +1,99 @@ +package dhcpd + +import ( + "fmt" + "math" + "math/big" + "net" + + "github.com/AdguardTeam/AdGuardHome/internal/agherr" +) + +// ipRange is an inclusive range of IP addresses. +// +// It is safe for concurrent use. +// +// TODO(a.garipov): Perhaps create an optimised version with uint32 for +// IPv4 ranges? Or use one of uint128 packages? +type ipRange struct { + start *big.Int + end *big.Int +} + +// maxRangeLen is the maximum IP range length. The bitsets used in servers only +// accept uints, which can have the size of 32 bit. +const maxRangeLen = math.MaxUint32 + +// newIPRange creates a new IP address range. start must be less than end. The +// resulting range must not be greater than maxRangeLen. +func newIPRange(start, end net.IP) (r *ipRange, err error) { + defer agherr.Annotate("invalid ip range: %w", &err) + + // Make sure that both are 16 bytes long to simplify handling in + // methods. + start, end = start.To16(), end.To16() + + startInt := (&big.Int{}).SetBytes(start) + endInt := (&big.Int{}).SetBytes(end) + diff := (&big.Int{}).Sub(endInt, startInt) + + if diff.Sign() <= 0 { + return nil, fmt.Errorf("start is greater than or equal to end") + } else if !diff.IsUint64() || diff.Uint64() > maxRangeLen { + return nil, fmt.Errorf("range is too large") + } + + r = &ipRange{ + start: startInt, + end: endInt, + } + + return r, nil +} + +// contains returns true if r contains ip. +func (r *ipRange) contains(ip net.IP) (ok bool) { + ipInt := (&big.Int{}).SetBytes(ip.To16()) + + return r.containsInt(ipInt) +} + +// containsInt returns true if r contains ipInt. +func (r *ipRange) containsInt(ipInt *big.Int) (ok bool) { + return ipInt.Cmp(r.start) >= 0 && ipInt.Cmp(r.end) <= 0 +} + +// ipPredicate is a function that is called on every IP address in +// (*ipRange).find. ip is given in the 16-byte form. +type ipPredicate func(ip net.IP) (ok bool) + +// find finds the first IP address in r for which p returns true. ip is in the +// 16-byte form. +func (r *ipRange) find(p ipPredicate) (ip net.IP) { + ip = make(net.IP, net.IPv6len) + _1 := big.NewInt(1) + for i := (&big.Int{}).Set(r.start); i.Cmp(r.end) <= 0; i.Add(i, _1) { + i.FillBytes(ip) + if p(ip) { + return ip + } + } + + return nil +} + +// offset returns the offset of ip from the beginning of r. It returns 0 and +// false if ip is not in r. +func (r *ipRange) offset(ip net.IP) (offset uint, ok bool) { + ip = ip.To16() + ipInt := (&big.Int{}).SetBytes(ip) + if !r.containsInt(ipInt) { + return 0, false + } + + offsetInt := (&big.Int{}).Sub(ipInt, r.start) + + // Assume that the range was checked against maxRangeLen during + // construction. + return uint(offsetInt.Uint64()), true +} diff --git a/internal/dhcpd/iprange_test.go b/internal/dhcpd/iprange_test.go new file mode 100644 index 00000000..01991532 --- /dev/null +++ b/internal/dhcpd/iprange_test.go @@ -0,0 +1,154 @@ +package dhcpd + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewIPRange(t *testing.T) { + start4 := net.IP{0, 0, 0, 1} + end4 := net.IP{0, 0, 0, 3} + start6 := net.IP{ + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, + } + end6 := net.IP{ + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x03, + } + end6Large := net.IP{ + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x03, + } + + testCases := []struct { + name string + wantErrMsg string + start net.IP + end net.IP + }{{ + name: "success_ipv4", + wantErrMsg: "", + start: start4, + end: end4, + }, { + name: "success_ipv6", + wantErrMsg: "", + start: start6, + end: end6, + }, { + name: "start_gt_end", + wantErrMsg: "invalid ip range: start is greater than or equal to end", + start: end4, + end: start4, + }, { + name: "start_eq_end", + wantErrMsg: "invalid ip range: start is greater than or equal to end", + start: start4, + end: start4, + }, { + name: "too_large", + wantErrMsg: "invalid ip range: range is too large", + start: start6, + end: end6Large, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r, err := newIPRange(tc.start, tc.end) + if tc.wantErrMsg == "" { + assert.Nil(t, err) + assert.NotNil(t, r) + } else { + require.NotNil(t, err) + assert.Equal(t, tc.wantErrMsg, err.Error()) + } + }) + } +} + +func TestIPRange_Contains(t *testing.T) { + start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 3} + r, err := newIPRange(start, end) + require.Nil(t, err) + + assert.True(t, r.contains(start)) + assert.True(t, r.contains(net.IP{0, 0, 0, 2})) + assert.True(t, r.contains(end)) + + assert.False(t, r.contains(net.IP{0, 0, 0, 0})) + assert.False(t, r.contains(net.IP{0, 0, 0, 4})) +} + +func TestIPRange_Find(t *testing.T) { + start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 5} + r, err := newIPRange(start, end) + require.Nil(t, err) + + want := net.IPv4(0, 0, 0, 2) + got := r.find(func(ip net.IP) (ok bool) { + return ip[len(ip)-1]%2 == 0 + }) + + assert.Equal(t, want, got) + + got = r.find(func(ip net.IP) (ok bool) { + return ip[len(ip)-1]%10 == 0 + }) + assert.Nil(t, got) +} + +func TestIPRange_Offset(t *testing.T) { + start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 5} + r, err := newIPRange(start, end) + require.Nil(t, err) + + testCases := []struct { + name string + in net.IP + wantOffset uint + wantOK bool + }{{ + name: "in", + in: net.IP{0, 0, 0, 2}, + wantOffset: 1, + wantOK: true, + }, { + name: "in_start", + in: start, + wantOffset: 0, + wantOK: true, + }, { + name: "in_end", + in: end, + wantOffset: 4, + wantOK: true, + }, { + name: "out_after", + in: net.IP{0, 0, 0, 6}, + wantOffset: 0, + wantOK: false, + }, { + name: "out_before", + in: net.IP{0, 0, 0, 0}, + wantOffset: 0, + wantOK: false, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + offset, ok := r.offset(tc.in) + assert.Equal(t, tc.wantOffset, offset) + assert.Equal(t, tc.wantOK, ok) + }) + } +} diff --git a/internal/dhcpd/options.go b/internal/dhcpd/options.go index 9992764e..780eeeab 100644 --- a/internal/dhcpd/options.go +++ b/internal/dhcpd/options.go @@ -100,11 +100,7 @@ func newDHCPOptionParser() (p *dhcpOptionParser) { // parse parses an option. See the handlers' documentation for more info. func (p *dhcpOptionParser) parse(s string) (code uint8, data []byte, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("invalid option string %q: %w", s, err) - } - }() + defer agherr.Annotate("invalid option string %q: %w", &err, s) s = strings.TrimSpace(s) parts := strings.SplitN(s, " ", 3) diff --git a/internal/dhcpd/options_test.go b/internal/dhcpd/options_test.go index 411aec65..63f736d1 100644 --- a/internal/dhcpd/options_test.go +++ b/internal/dhcpd/options_test.go @@ -8,7 +8,7 @@ import ( ) func TestDHCPOptionParser(t *testing.T) { - testCasesA := []struct { + testCases := []struct { name string in string wantErrMsg string @@ -92,7 +92,7 @@ func TestDHCPOptionParser(t *testing.T) { p := newDHCPOptionParser() - for _, tc := range testCasesA { + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { code, data, err := p.parse(tc.in) if tc.wantErrMsg == "" { diff --git a/internal/dhcpd/server.go b/internal/dhcpd/server.go index a1f6f2f6..945c5dec 100644 --- a/internal/dhcpd/server.go +++ b/internal/dhcpd/server.go @@ -60,8 +60,8 @@ type V4ServerConf struct { // DEC_CODE ip IP_ADDR Options []string `yaml:"options" json:"-"` - ipStart net.IP // starting IP address for dynamic leases - ipEnd net.IP // ending IP address for dynamic leases + ipRange *ipRange + leaseTime time.Duration // the time during which a dynamic lease is considered valid dnsIPAddrs []net.IP // IPv4 addresses to return to DHCP clients as DNS server addresses routerIP net.IP // value for Option Router diff --git a/internal/dhcpd/v4.go b/internal/dhcpd/v4.go index 37341974..a4d7c5c8 100644 --- a/internal/dhcpd/v4.go +++ b/internal/dhcpd/v4.go @@ -4,7 +4,6 @@ package dhcpd import ( "bytes" - "encoding/binary" "fmt" "net" "sync" @@ -14,19 +13,25 @@ import ( "github.com/go-ping/ping" "github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4/server4" + "github.com/willf/bitset" ) // v4Server is a DHCPv4 server. // // TODO(a.garipov): Think about unifying this and v6Server. type v4Server struct { - srv *server4.Server - leasesLock sync.Mutex - leases []*Lease - // TODO(e.burkov): This field type should be a normal bitmap. - ipAddrs [256]byte - conf V4ServerConf + srv *server4.Server + + // leasedOffsets contains offsets from conf.ipRange.start that have been + // leased. + leasedOffsets *bitset.BitSet + + // leases contains all dynamic and static leases. + leases []*Lease + + // leasesLock protects leases and leasedOffsets. + leasesLock sync.Mutex } // WriteDiskConfig4 - write configuration @@ -38,27 +43,14 @@ func (s *v4Server) WriteDiskConfig4(c *V4ServerConf) { func (s *v4Server) WriteDiskConfig6(c *V6ServerConf) { } -// Return TRUE if IP address is within range [start..stop] -func ip4InRange(start, stop, ip net.IP) bool { - if len(start) != 4 || len(stop) != 4 { - return false - } - from := binary.BigEndian.Uint32(start) - to := binary.BigEndian.Uint32(stop) - check := binary.BigEndian.Uint32(ip) - return from <= check && check <= to -} - // ResetLeases - reset leases func (s *v4Server) ResetLeases(leases []*Lease) { s.leases = nil for _, l := range leases { + if l.Expiry.Unix() != leaseExpireStatic && !s.conf.ipRange.contains(l.IP) { + log.Debug("dhcpv4: skipping a lease with ip %v: not within current ip range", l.IP) - if l.Expiry.Unix() != leaseExpireStatic && - !ip4InRange(s.conf.ipStart, s.conf.ipEnd, l.IP) { - - log.Debug("dhcpv4: skipping a lease with IP %v: not within current IP range", l.IP) continue } @@ -127,16 +119,18 @@ func (s *v4Server) blacklistLease(lease *Lease) { lease.Expiry = time.Now().Add(s.conf.leaseTime) } -// Remove (swap) lease by index -func (s *v4Server) leaseRemoveSwapByIndex(i int) { - s.ipAddrs[s.leases[i].IP[3]] = 0 - log.Debug("dhcpv4: removed lease %s", s.leases[i].HWAddr) +// rmLeaseByIndex removes a lease by its index in the leases slice. +func (s *v4Server) rmLeaseByIndex(i int) { + l := s.leases[i] + s.leases = append(s.leases[:i], s.leases[i+1:]...) - n := len(s.leases) - if i != n-1 { - s.leases[i] = s.leases[n-1] // swap with the last element + r := s.conf.ipRange + offset, ok := r.offset(l.IP) + if ok { + s.leasedOffsets.Clear(offset) } - s.leases = s.leases[:n-1] + + log.Debug("dhcpv4: removed lease %s (%s)", l.IP, l.HWAddr) } // Remove a dynamic lease with the same properties @@ -146,51 +140,61 @@ func (s *v4Server) rmDynamicLease(lease Lease) error { l := s.leases[i] if bytes.Equal(l.HWAddr, lease.HWAddr) { - if l.Expiry.Unix() == leaseExpireStatic { return fmt.Errorf("static lease already exists") } - s.leaseRemoveSwapByIndex(i) + s.rmLeaseByIndex(i) if i == len(s.leases) { break } + l = s.leases[i] } if net.IP.Equal(l.IP, lease.IP) { - if l.Expiry.Unix() == leaseExpireStatic { return fmt.Errorf("static lease already exists") } - s.leaseRemoveSwapByIndex(i) + s.rmLeaseByIndex(i) } } return nil } -// Add a lease +// addLease adds a lease. func (s *v4Server) addLease(l *Lease) { + r := s.conf.ipRange + offset, ok := r.offset(l.IP) + if !ok { + // TODO(a.garipov): Better error handling. + log.Debug("dhcpv4: lease %s (%s) out of range, not adding", l.IP, l.HWAddr) + + return + } + s.leases = append(s.leases, l) - s.ipAddrs[l.IP[3]] = 1 - log.Debug("dhcpv4: added lease %s <-> %s", l.IP, l.HWAddr) + s.leasedOffsets.Set(uint(offset)) + + log.Debug("dhcpv4: added lease %s (%s)", l.IP, l.HWAddr) } // Remove a lease with the same properties func (s *v4Server) rmLease(lease Lease) error { for i, l := range s.leases { - if net.IP.Equal(l.IP, lease.IP) { - + if l.IP.Equal(lease.IP) { if !bytes.Equal(l.HWAddr, lease.HWAddr) || l.Hostname != lease.Hostname { return fmt.Errorf("lease not found") } - s.leaseRemoveSwapByIndex(i) + s.rmLeaseByIndex(i) + return nil } } + return fmt.Errorf("lease not found") } @@ -258,7 +262,7 @@ func (s *v4Server) addrAvailable(target net.IP) bool { pinger.Timeout = time.Duration(s.conf.ICMPTimeout) * time.Millisecond pinger.Count = 1 reply := false - pinger.OnRecv = func(pkt *ping.Packet) { + pinger.OnRecv = func(_ *ping.Packet) { reply = true } log.Debug("dhcpv4: Sending ICMP Echo to %v", target) @@ -278,30 +282,31 @@ func (s *v4Server) addrAvailable(target net.IP) bool { return true } -// Find lease by MAC -func (s *v4Server) findLease(mac net.HardwareAddr) *Lease { - for i := range s.leases { - if bytes.Equal(mac, s.leases[i].HWAddr) { - return s.leases[i] +// findLease finds a lease by its MAC-address. +func (s *v4Server) findLease(mac net.HardwareAddr) (l *Lease) { + for _, l = range s.leases { + if bytes.Equal(mac, l.HWAddr) { + return l } } + return nil } -// Get next free IP -func (s *v4Server) findFreeIP() net.IP { - for i := s.conf.ipStart[3]; ; i++ { - if s.ipAddrs[i] == 0 { - ip := make([]byte, 4) - copy(ip, s.conf.ipStart) - ip[3] = i - return ip +// nextIP generates a new free IP. +func (s *v4Server) nextIP() (ip net.IP) { + r := s.conf.ipRange + ip = r.find(func(next net.IP) (ok bool) { + offset, ok := r.offset(next) + if !ok { + // Shouldn't happen. + return false } - if i == s.conf.ipEnd[3] { - break - } - } - return nil + + return !s.leasedOffsets.Test(uint(offset)) + }) + + return ip.To4() } // Find an expired lease and return its index or -1 @@ -316,24 +321,30 @@ func (s *v4Server) findExpiredLease() int { return -1 } -// Reserve lease for MAC -func (s *v4Server) reserveLease(mac net.HardwareAddr) *Lease { - l := Lease{} - l.HWAddr = make([]byte, 6) +// reserveLease reserves a lease for a client by its MAC-address. It returns +// nil if it couldn't allocate a new lease. +func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease) { + l = &Lease{ + HWAddr: make([]byte, 6), + } + copy(l.HWAddr, mac) - l.IP = s.findFreeIP() + l.IP = s.nextIP() if l.IP == nil { i := s.findExpiredLease() if i < 0 { return nil } + copy(s.leases[i].HWAddr, mac) + return s.leases[i] } - s.addLease(&l) - return &l + s.addLease(l) + + return l } func (s *v4Server) commitLease(l *Lease) { @@ -650,22 +661,12 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) { s.conf.subnetMask = make([]byte, 4) copy(s.conf.subnetMask, s.conf.SubnetMask.To4()) - s.conf.ipStart, err = tryTo4(conf.RangeStart) - if s.conf.ipStart == nil { + s.conf.ipRange, err = newIPRange(conf.RangeStart, conf.RangeEnd) + if err != nil { return s, fmt.Errorf("dhcpv4: %w", err) } - if s.conf.ipStart[0] == 0 { - return s, fmt.Errorf("dhcpv4: invalid range start IP") - } - s.conf.ipEnd, err = tryTo4(conf.RangeEnd) - if s.conf.ipEnd == nil { - return s, fmt.Errorf("dhcpv4: %w", err) - } - if !net.IP.Equal(s.conf.ipStart[:3], s.conf.ipEnd[:3]) || - s.conf.ipStart[3] > s.conf.ipEnd[3] { - return s, fmt.Errorf("dhcpv4: range end IP should match range start IP") - } + s.leasedOffsets = &bitset.BitSet{} if conf.LeaseDuration == 0 { s.conf.leaseTime = time.Hour * 24 diff --git a/internal/dhcpd/v4_test.go b/internal/dhcpd/v4_test.go index 51ce92c2..9c589742 100644 --- a/internal/dhcpd/v4_test.go +++ b/internal/dhcpd/v4_test.go @@ -212,18 +212,26 @@ func TestV4DynamicLease_Get(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, s.process(req, resp)) }) + + // Don't continue if we got any errors in the previous subtest. require.Nil(t, err) t.Run("offer", func(t *testing.T) { assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType()) assert.Equal(t, mac, resp.ClientHWAddr) - assert.True(t, s.conf.RangeStart.Equal(resp.YourIPAddr)) - assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0])) - assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier())) + + assert.Equal(t, s.conf.RangeStart, resp.YourIPAddr) + assert.Equal(t, s.conf.GatewayIP, resp.ServerIdentifier()) + + router := resp.Router() + require.Len(t, router, 1) + assert.Equal(t, s.conf.GatewayIP, router[0]) + assert.Equal(t, s.conf.subnetMask, resp.SubnetMask()) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)]) - assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)]))) + + assert.Equal(t, net.IP{1, 2, 3, 4}, net.IP(resp.RelayAgentInfo().ToBytes())) }) t.Run("request", func(t *testing.T) { @@ -260,31 +268,3 @@ func TestV4DynamicLease_Get(t *testing.T) { assert.Equal(t, mac, ls[0].HWAddr) }) } - -func TestIP4InRange(t *testing.T) { - start := net.IP{192, 168, 10, 100} - stop := net.IP{192, 168, 10, 200} - - testCases := []struct { - ip net.IP - want bool - }{{ - ip: net.IP{192, 168, 10, 99}, - want: false, - }, { - ip: net.IP{192, 168, 11, 100}, - want: false, - }, { - ip: net.IP{192, 168, 11, 201}, - want: false, - }, { - ip: start, - want: true, - }} - - for _, tc := range testCases { - t.Run(tc.ip.String(), func(t *testing.T) { - assert.Equal(t, tc.want, ip4InRange(start, stop, tc.ip)) - }) - } -} diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index bfd0c11b..85eb866e 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -231,8 +231,17 @@ func TestClientsAddExisting(t *testing.T) { // First, init a DHCP server with a single static lease. config := dhcpd.ServerConfig{ + Enabled: true, DBFilePath: "leases.db", + Conf4: dhcpd.V4ServerConf{ + Enabled: true, + GatewayIP: net.IP{1, 2, 3, 1}, + SubnetMask: net.IP{255, 255, 255, 0}, + RangeStart: net.IP{1, 2, 3, 2}, + RangeEnd: net.IP{1, 2, 3, 10}, + }, } + clients.dhcpServer = dhcpd.Create(config) t.Cleanup(func() { _ = os.Remove("leases.db") }) diff --git a/staticcheck.conf b/staticcheck.conf index 4dd93176..b997f6a9 100644 --- a/staticcheck.conf +++ b/staticcheck.conf @@ -1,6 +1,8 @@ checks = ["all"] initialisms = [ # See https://github.com/dominikh/go-tools/blob/master/config/config.go. + # + # Do not add "PTR" since we use "Ptr" as a suffix. "inherit" , "DHCP" , "DOH" @@ -8,7 +10,6 @@ initialisms = [ , "DOT" , "EDNS" , "MX" -, "PTR" , "QUIC" , "RA" , "SDNS"