// authentik/internal/outpost/radius/radius.go
// (120 lines, 2.5 KiB, Go)

package radius
import (
"context"
"net"
"sort"
"sync"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/config"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/radius/metrics"
"layeh.com/radius"
)
// ProviderInstance holds the per-provider settings the RADIUS server consults
// when deciding which shared secret applies to a client.
type ProviderInstance struct {
// ClientNetworks lists the networks whose clients belong to this provider;
// RADIUSSecret tests the remote IP against each entry.
ClientNetworks []*net.IPNet
// SharedSecret is the value RADIUSSecret returns when this provider is
// selected for a client.
SharedSecret []byte
// NOTE(review): appSlug and flowSlug are not read in the code visible here;
// presumably the application and flow slugs of the provider — confirm
// against the code that populates them.
appSlug string
flowSlug string
// s points back to the RadiusServer; its use is not visible in this part of
// the file.
s *RadiusServer
// log is the logrus entry for this provider.
log *log.Entry
}
// RadiusServer bundles the RADIUS packet server with the state needed to
// answer shared-secret lookups for incoming packets.
type RadiusServer struct {
// s is the packet server; NewServer registers the RadiusServer itself as
// both its Handler and its SecretSource.
s radius.PacketServer
// log is the logrus entry created in NewServer with the "logger" field set
// to "authentik.outpost.radius".
log *log.Entry
// ac is the API controller handed to NewServer.
ac *ak.APIController
// providers is the list RADIUSSecret scans for a matching client network.
// NewServer starts it out empty; it is filled elsewhere.
providers []*ProviderInstance
}
// NewServer builds a RadiusServer bound to the given API controller.
//
// The returned server starts with an empty provider list and a packet server
// that uses the RadiusServer itself as both Handler and SecretSource and
// listens on the address taken from config.Get().Listen.Radius. Nothing is
// started here; the caller only gets the configured value back.
func NewServer(ac *ak.APIController) *RadiusServer {
	logger := log.WithField("logger", "authentik.outpost.radius")
	srv := &RadiusServer{
		providers: []*ProviderInstance{},
		ac:        ac,
		log:       logger,
	}
	// The packet server calls back into srv, so it can only be wired up once
	// srv exists.
	srv.s.Handler = srv
	srv.s.SecretSource = srv
	srv.s.Addr = config.Get().Listen.Radius
	return srv
}
// cpp pairs a client network with the provider it belongs to. RADIUSSecret
// collects one of these for every network that contains the remote IP.
type cpp struct {
// c is the matching client network.
c *net.IPNet
// p is the provider that lists c among its ClientNetworks.
p *ProviderInstance
}
// RADIUSSecret returns the shared secret of the provider whose client network
// best matches the address a packet was received from.
//
// Every configured client network that contains the remote IP is a candidate.
// The most specific candidate (the one with the fewest host bits, i.e. the
// longest prefix) wins. Candidates of equal specificity keep the order in
// which they appear in rs.providers and ClientNetworks, so the result is
// deterministic.
//
// An error is returned only when remoteAddr cannot be split into host and
// port. If the host is not a parseable IP, or no network contains it, the
// result is an empty secret and a nil error.
func (rs *RadiusServer) RADIUSSecret(ctx context.Context, remoteAddr net.Addr) ([]byte, error) {
	host, _, err := net.SplitHostPort(remoteAddr.String())
	if err != nil {
		rs.log.WithError(err).Warning("Failed to get remote IP")
		return nil, err
	}
	ip := net.ParseIP(host)
	if ip == nil {
		// A nil IP can never match a network; log the raw host so the
		// message is useful instead of printing "<nil>".
		rs.log.WithField("ip", host).Warning("No matching prefixes found")
		return []byte{}, nil
	}

	// Check all networks we have and remember all matching prefixes.
	var matchedPrefixes []cpp
	for _, p := range rs.providers {
		for _, cidr := range p.ClientNetworks {
			if cidr.Contains(ip) {
				matchedPrefixes = append(matchedPrefixes, cpp{
					c: cidr,
					p: p,
				})
			}
		}
	}
	if len(matchedPrefixes) == 0 {
		rs.log.WithField("ip", ip.String()).Warning("No matching prefixes found")
		return []byte{}, nil
	}

	// IPMask.Size returns (ones, bits): ones is the prefix length, bits the
	// total mask width (32 or 128). Ranking by the number of host bits
	// (bits - ones) compares specificity correctly even when IPv4 networks
	// are stored with 16-byte masks.
	hostBits := func(n *net.IPNet) int {
		ones, bits := n.Mask.Size()
		if bits == 0 {
			// Non-canonical mask: treat as least specific.
			return 8*net.IPv6len + 1
		}
		return bits - ones
	}
	// Most specific network first; stable so ties keep configuration order.
	sort.SliceStable(matchedPrefixes, func(i, j int) bool {
		return hostBits(matchedPrefixes[i].c) < hostBits(matchedPrefixes[j].c)
	})

	candidate := matchedPrefixes[0]
	rs.log.WithField("ip", ip.String()).WithField("cidr", candidate.c.String()).Debug("Matched CIDR")
	return candidate.p.SharedSecret, nil
}
// Start runs the metrics server and the RADIUS server, each in its own
// goroutine, and blocks until both have returned.
//
// A non-nil error from StartRadiusServer is raised as a panic inside its
// goroutine; the method itself always returns nil.
func (rs *RadiusServer) Start() error {
	var running sync.WaitGroup
	// spawn runs task in a goroutine tracked by the wait group.
	spawn := func(task func()) {
		running.Add(1)
		go func() {
			defer running.Done()
			task()
		}()
	}

	spawn(func() {
		metrics.RunServer()
	})
	spawn(func() {
		if err := rs.StartRadiusServer(); err != nil {
			panic(err)
		}
	})

	running.Wait()
	return nil
}
// Stop shuts the packet server down and returns whatever error Shutdown
// reports. The context handed to Shutdown has no deadline and is cancelled
// only after Shutdown has returned.
func (rs *RadiusServer) Stop() error {
	shutdownCtx, release := context.WithCancel(context.Background())
	defer release()
	return rs.s.Shutdown(shutdownCtx)
}
// TimerFlowCacheExpiry does nothing for the RADIUS server; the context is
// ignored.
func (rs *RadiusServer) TimerFlowCacheExpiry(_ context.Context) {
	// Intentionally empty.
}
// Type returns the fixed identifier "radius" for this server.
func (rs *RadiusServer) Type() string {
	const kind = "radius"
	return kind
}