diff --git a/dhcpd/db.go b/dhcpd/db.go new file mode 100644 index 00000000..1caa7d81 --- /dev/null +++ b/dhcpd/db.go @@ -0,0 +1,98 @@ +// On-disk database for lease table + +package dhcpd + +import ( + "encoding/json" + "io/ioutil" + "net" + "os" + "time" + + "github.com/AdguardTeam/golibs/file" + "github.com/AdguardTeam/golibs/log" + "github.com/krolaw/dhcp4" +) + +const dbFilename = "leases.db" + +type leaseJSON struct { + HWAddr []byte `json:"mac"` + IP []byte `json:"ip"` + Hostname string `json:"host"` + Expiry int64 `json:"exp"` +} + +// Load lease table from DB +func (s *Server) dbLoad() { + data, err := ioutil.ReadFile(dbFilename) + if err != nil { + if !os.IsNotExist(err) { + log.Error("DHCP: can't read file %s: %v", dbFilename, err) + } + return + } + + obj := []leaseJSON{} + err = json.Unmarshal(data, &obj) + if err != nil { + log.Error("DHCP: invalid DB: %v", err) + return + } + + s.leases = nil + s.IPpool = make(map[[4]byte]net.HardwareAddr) + + numLeases := len(obj) + for i := range obj { + + if !dhcp4.IPInRange(s.leaseStart, s.leaseStop, obj[i].IP) { + log.Tracef("Skipping a lease with IP %s: not within current IP range", obj[i].IP) + continue + } + + lease := Lease{ + HWAddr: obj[i].HWAddr, + IP: obj[i].IP, + Hostname: obj[i].Hostname, + Expiry: time.Unix(obj[i].Expiry, 0), + } + + s.leases = append(s.leases, &lease) + + s.reserveIP(lease.IP, lease.HWAddr) + } + log.Info("DHCP: loaded %d leases from DB", numLeases) +} + +// Store lease table in DB +func (s *Server) dbStore() { + var leases []leaseJSON + + for i := range s.leases { + if s.leases[i].Expiry.Unix() == 0 { + continue + } + lease := leaseJSON{ + HWAddr: s.leases[i].HWAddr, + IP: s.leases[i].IP, + Hostname: s.leases[i].Hostname, + Expiry: s.leases[i].Expiry.Unix(), + } + leases = append(leases, lease) + } + + data, err := json.Marshal(leases) + if err != nil { + log.Error("json.Marshal: %v", err) + return + } + + err = file.SafeWrite(dbFilename, data) + if err != nil { + log.Error("DHCP: can't store lease table on disk: %v filename: %s", + err, dbFilename) + return + } + log.Info("DHCP: stored %d leases in DB", len(leases)) +} diff --git a/dhcpd/dhcpd.go b/dhcpd/dhcpd.go index b9affb5a..3920bcef 100644 --- a/dhcpd/dhcpd.go +++ b/dhcpd/dhcpd.go @@ -122,6 +122,8 @@ func (s *Server) Start(config *ServerConfig) error { s.closeConn() } + s.dbLoad() + c, err := newFilterConn(*iface, ":67") // it has to be bound to 0.0.0.0:67, otherwise it won't see DHCP discover/request packets if err != nil { return wrapErrPrint(err, "Couldn't start listening socket on 0.0.0.0:67") @@ -153,6 +155,7 @@ func (s *Server) Stop() error { return wrapErrPrint(err, "Couldn't close UDP listening socket") } + s.dbStore() return nil } diff --git a/dhcpd/dhcpd_test.go b/dhcpd/dhcpd_test.go index 10ca1b59..2d675ed9 100644 --- a/dhcpd/dhcpd_test.go +++ b/dhcpd/dhcpd_test.go @@ -3,6 +3,7 @@ package dhcpd import ( "bytes" "net" + "os" "testing" "time" @@ -113,3 +114,47 @@ func misc(t *testing.T, s *Server) { check(t, bytes.Equal(opt[dhcp4.OptionIPAddressLeaseTime], dhcp4.OptionsLeaseTime(5*time.Second)), "OptionIPAddressLeaseTime") check(t, bytes.Equal(opt[dhcp4.OptionServerIdentifier], s.ipnet.IP), "OptionServerIdentifier") } + +// Leases database store/load +func TestDB(t *testing.T) { + var s = Server{} + var p dhcp4.Packet + var hw1, hw2 net.HardwareAddr + var lease *Lease + + s.leaseStart = []byte{1, 1, 1, 1} + s.leaseStop = []byte{1, 1, 1, 2} + s.leaseTime = 5 * time.Second + s.leaseOptions = dhcp4.Options{} + s.ipnet = &net.IPNet{ + IP: []byte{1, 2, 3, 4}, + Mask: []byte{0xff, 0xff, 0xff, 0xff}, + } + + p = make(dhcp4.Packet, 241) + + hw1 = []byte{1, 2, 3, 4, 5, 6} + p.SetCHAddr(hw1) + lease, _ = s.reserveLease(p) + lease.Expiry = time.Unix(4000000001, 0) + + hw2 = []byte{2, 2, 3, 4, 5, 6} + p.SetCHAddr(hw2) + lease, _ = s.reserveLease(p) + lease.Expiry = time.Unix(4000000002, 0) + + os.Remove("leases.db") + s.dbStore() + s.reset() + + s.dbLoad() + check(t, bytes.Equal(s.leases[0].HWAddr, hw1), "leases[0].HWAddr") + check(t, bytes.Equal(s.leases[0].IP, []byte{1, 1, 1, 1}), "leases[0].IP") + check(t, s.leases[0].Expiry.Unix() == 4000000001, "leases[0].Expiry") + + check(t, bytes.Equal(s.leases[1].HWAddr, hw2), "leases[1].HWAddr") + check(t, bytes.Equal(s.leases[1].IP, []byte{1, 1, 1, 2}), "leases[1].IP") + check(t, s.leases[1].Expiry.Unix() == 4000000002, "leases[1].Expiry") + + os.Remove("leases.db") +}