AdGuardHome/internal/dhcpsvc/server.go

352 lines
8.5 KiB
Go

package dhcpsvc
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/mapsutil"
)
// DHCPServer is a DHCP server for both IPv4 and IPv6 address families. It
// keeps a lease index guarded by leasesMu and a separate set of handled
// network interfaces for each address family.
type DHCPServer struct {
// enabled indicates whether the DHCP server is enabled and can provide
// information about its clients. It is stored once in [New] and read by
// [DHCPServer.Enabled].
enabled *atomic.Bool
// logger logs common DHCP events. It is taken from the configuration
// passed to [New].
logger *slog.Logger
// localTLD is the top-level domain name to use for resolving DHCP clients'
// hostnames. It is taken from the LocalDomainName field of the
// configuration.
localTLD string
// dbFilePath is the path to the database file containing the DHCP leases.
// It is taken from the DBFilePath field of the configuration.
//
// TODO(e.burkov): Consider extracting the database logic into a separate
// interface to prevent packages that only need lease data from depending on
// the entire server and to simplify testing.
dbFilePath string
// leasesMu protects the leases index as well as leases in the interfaces.
// Lookup methods take it for reading, while methods that add, update,
// remove, or reset leases take it for writing.
leasesMu *sync.RWMutex
// leases stores the DHCP leases for quick lookups, such as by IP address
// and by hostname.
leases *leaseIndex
// interfaces4 is the set of IPv4 interfaces sorted by interface name.
interfaces4 dhcpInterfacesV4
// interfaces6 is the set of IPv6 interfaces sorted by interface name.
interfaces6 dhcpInterfacesV6
// icmpTimeout is the timeout for checking another DHCP server's presence.
// It is taken from the ICMPTimeout field of the configuration.
icmpTimeout time.Duration
}
// New builds a DHCP server from conf. When conf.Enabled is false it logs the
// fact and returns a nil server together with a nil error. Otherwise it
// creates the network interfaces, assembles the server, and calls dbLoad on
// it; a failure at either step is returned as is. conf must not be nil.
//
// TODO(e.burkov): Use.
func New(ctx context.Context, conf *Config) (srv *DHCPServer, err error) {
	logger := conf.Logger
	if !conf.Enabled {
		logger.DebugContext(ctx, "disabled")

		// TODO(e.burkov): Perhaps return [Empty]?
		return nil, nil
	}

	v4, v6, err := newInterfaces(ctx, logger, conf.Interfaces)
	if err != nil {
		// The error is already descriptive, so it is passed through unwrapped.
		return nil, err
	}

	srv = &DHCPServer{
		enabled:     &atomic.Bool{},
		logger:      logger,
		localTLD:    conf.LocalDomainName,
		dbFilePath:  conf.DBFilePath,
		leasesMu:    &sync.RWMutex{},
		leases:      newLeaseIndex(),
		interfaces4: v4,
		interfaces6: v6,
		icmpTimeout: conf.ICMPTimeout,
	}
	srv.enabled.Store(conf.Enabled)

	if err = srv.dbLoad(ctx); err != nil {
		// The error is already descriptive, so it is passed through unwrapped.
		return nil, err
	}

	return srv, nil
}
// newInterfaces builds the IPv4 and IPv6 DHCP interfaces from the given map of
// interface names to their configurations, visiting the names in sorted order.
// Errors from creating IPv4 interfaces are collected and joined, so one bad
// interface does not hide problems with the others. On error both returned
// sets are nil.
func newInterfaces(
	ctx context.Context,
	l *slog.Logger,
	ifaces map[string]*InterfaceConfig,
) (v4 dhcpInterfacesV4, v6 dhcpInterfacesV6, err error) {
	defer func() { err = errors.Annotate(err, "creating interfaces: %w") }()

	// TODO(e.burkov): Add validations scoped to the network interfaces set.

	v4 = make(dhcpInterfacesV4, 0, len(ifaces))
	v6 = make(dhcpInterfacesV6, 0, len(ifaces))

	var ifaceErrs []error
	addIface := func(name string, c *InterfaceConfig) (cont bool) {
		var (
			dhcp4 *dhcpInterfaceV4
			v4Err error
		)

		dhcp4, v4Err = newDHCPInterfaceV4(ctx, l, name, c.IPv4)
		switch {
		case v4Err != nil:
			ifaceErrs = append(ifaceErrs, fmt.Errorf("interface %q: ipv4: %w", name, v4Err))
		case dhcp4 != nil:
			v4 = append(v4, dhcp4)
		}

		if dhcp6 := newDHCPInterfaceV6(ctx, l, name, c.IPv6); dhcp6 != nil {
			v6 = append(v6, dhcp6)
		}

		// Always continue, so that every interface is visited.
		return true
	}
	mapsutil.SortedRange(ifaces, addIface)

	err = errors.Join(ifaceErrs...)
	if err != nil {
		return nil, nil, err
	}

	return v4, v6, nil
}
// type check
//
// TODO(e.burkov): Uncomment when the [Interface] interface is implemented.
// var _ Interface = (*DHCPServer)(nil)
// Enabled implements the [Interface] interface for *DHCPServer. It reports
// the current value of the server's atomic enabled flag.
func (srv *DHCPServer) Enabled() (ok bool) {
	ok = srv.enabled.Load()

	return ok
}
// Leases implements the [Interface] interface for *DHCPServer. It returns
// clones of all leases in the index, collected under the read lock. The
// result is nil when the index has no leases.
func (srv *DHCPServer) Leases() (leases []*Lease) {
	srv.leasesMu.RLock()
	defer srv.leasesMu.RUnlock()

	collect := func(stored *Lease) (cont bool) {
		leases = append(leases, stored.Clone())

		return true
	}
	srv.leases.rangeLeases(collect)

	return leases
}
// HostByIP implements the [Interface] interface for *DHCPServer. It returns
// the hostname of the lease with the given IP address, or an empty string if
// the index has no such lease.
func (srv *DHCPServer) HostByIP(ip netip.Addr) (host string) {
	srv.leasesMu.RLock()
	defer srv.leasesMu.RUnlock()

	lease, found := srv.leases.leaseByAddr(ip)
	if !found {
		return ""
	}

	return lease.Hostname
}
// MACByIP implements the [Interface] interface for *DHCPServer. It returns
// the hardware address of the lease with the given IP address, or nil if the
// index has no such lease.
//
// The returned slice is a copy of the stored address. Handing out the stored
// slice itself would let callers share its backing array with the lease index
// after the read lock is released, which is the same reason
// [DHCPServer.Leases] returns clones.
func (srv *DHCPServer) MACByIP(ip netip.Addr) (mac net.HardwareAddr) {
	srv.leasesMu.RLock()
	defer srv.leasesMu.RUnlock()

	l, ok := srv.leases.leaseByAddr(ip)
	if !ok {
		return nil
	}

	// Appending to a nil slice copies the bytes into fresh storage and yields
	// nil for an empty address.
	return append(net.HardwareAddr(nil), l.HWAddr...)
}
// IPByHost implements the [Interface] interface for *DHCPServer. It returns
// the IP address of the lease with the given hostname, or the zero
// [netip.Addr] if the index has no such lease.
func (srv *DHCPServer) IPByHost(host string) (ip netip.Addr) {
	srv.leasesMu.RLock()
	defer srv.leasesMu.RUnlock()

	lease, found := srv.leases.leaseByName(host)
	if !found {
		return netip.Addr{}
	}

	return lease.IP
}
// Reset implements the [Interface] interface for *DHCPServer. Under the write
// lock it resets the leases of all interfaces and the index, and then calls
// dbStore. Any error is annotated with the operation name.
func (srv *DHCPServer) Reset(ctx context.Context) (err error) {
	defer func() { err = errors.Annotate(err, "resetting leases: %w") }()

	srv.leasesMu.Lock()
	defer srv.leasesMu.Unlock()

	srv.resetLeases()

	if err = srv.dbStore(ctx); err != nil {
		// No extra wrapping here, since the deferred annotation adds context.
		return err
	}

	srv.logger.DebugContext(ctx, "reset leases")

	return nil
}
// resetLeases resets the common state of every IPv4 and IPv6 interface of the
// server and then clears the lease index. The caller must hold
// DHCPServer.leasesMu locked.
func (srv *DHCPServer) resetLeases() {
	for i := range srv.interfaces4 {
		srv.interfaces4[i].common.reset()
	}

	for i := range srv.interfaces6 {
		srv.interfaces6[i].common.reset()
	}

	srv.leases.clear()
}
// AddLease implements the [Interface] interface for *DHCPServer. It finds the
// interface that handles l.IP, adds l to the lease index under the write lock,
// and then calls dbStore. Any error is annotated with the operation name. l
// must not be nil.
func (srv *DHCPServer) AddLease(ctx context.Context, l *Lease) (err error) {
	defer func() { err = errors.Annotate(err, "adding lease: %w") }()

	iface, err := srv.ifaceForAddr(l.IP)
	if err != nil {
		// The error is already descriptive, so it is passed through as is.
		return err
	}

	srv.leasesMu.Lock()
	defer srv.leasesMu.Unlock()

	if err = srv.leases.add(l, iface); err != nil {
		// No extra wrapping here, since the deferred annotation adds context.
		return err
	}

	if err = srv.dbStore(ctx); err != nil {
		// The error is already descriptive, so it is passed through as is.
		return err
	}

	logArgs := []any{
		"hostname", l.Hostname,
		"ip", l.IP,
		"mac", l.HWAddr,
		"static", l.IsStatic,
	}
	iface.logger.DebugContext(ctx, "added lease", logArgs...)

	return nil
}
// UpdateStaticLease implements the [Interface] interface for *DHCPServer. It
// finds the interface that handles l.IP, updates the lease in the index under
// the write lock, and then calls dbStore. Any error is annotated with the
// operation name. l must not be nil.
//
// TODO(e.burkov): Support moving leases between interfaces.
func (srv *DHCPServer) UpdateStaticLease(ctx context.Context, l *Lease) (err error) {
	defer func() { err = errors.Annotate(err, "updating static lease: %w") }()

	iface, err := srv.ifaceForAddr(l.IP)
	if err != nil {
		// The error is already descriptive, so it is passed through as is.
		return err
	}

	srv.leasesMu.Lock()
	defer srv.leasesMu.Unlock()

	if err = srv.leases.update(l, iface); err != nil {
		// No extra wrapping here, since the deferred annotation adds context.
		return err
	}

	if err = srv.dbStore(ctx); err != nil {
		// The error is already descriptive, so it is passed through as is.
		return err
	}

	logArgs := []any{
		"hostname", l.Hostname,
		"ip", l.IP,
		"mac", l.HWAddr,
		"static", l.IsStatic,
	}
	iface.logger.DebugContext(ctx, "updated lease", logArgs...)

	return nil
}
// RemoveLease implements the [Interface] interface for *DHCPServer. It finds
// the interface that handles l.IP, removes the lease from the index under the
// write lock, and then calls dbStore. Any error is annotated with the
// operation name. l must not be nil.
func (srv *DHCPServer) RemoveLease(ctx context.Context, l *Lease) (err error) {
	defer func() { err = errors.Annotate(err, "removing lease: %w") }()

	iface, err := srv.ifaceForAddr(l.IP)
	if err != nil {
		// The error is already descriptive, so it is passed through as is.
		return err
	}

	srv.leasesMu.Lock()
	defer srv.leasesMu.Unlock()

	if err = srv.leases.remove(l, iface); err != nil {
		// No extra wrapping here, since the deferred annotation adds context.
		return err
	}

	if err = srv.dbStore(ctx); err != nil {
		// The error is already descriptive, so it is passed through as is.
		return err
	}

	logArgs := []any{
		"hostname", l.Hostname,
		"ip", l.IP,
		"mac", l.HWAddr,
		"static", l.IsStatic,
	}
	iface.logger.DebugContext(ctx, "removed lease", logArgs...)

	return nil
}
// ifaceForAddr looks up the handled network interface for the given IP
// address. Addresses for which Is4 reports true are searched among the IPv4
// interfaces and all others among the IPv6 ones. It returns an error if no
// interface is found.
func (srv *DHCPServer) ifaceForAddr(addr netip.Addr) (iface *netInterface, err error) {
	var found bool
	switch {
	case addr.Is4():
		iface, found = srv.interfaces4.find(addr)
	default:
		iface, found = srv.interfaces6.find(addr)
	}

	if found {
		return iface, nil
	}

	return nil, fmt.Errorf("no interface for ip %s", addr)
}