tailscale/util/linuxfw/nftables_for_svcs.go

246 lines
7.6 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
package linuxfw
import (
"errors"
"fmt"
"net/netip"
"reflect"
"strings"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"golang.org/x/sys/unix"
)
// This file contains functionality that is currently (09/2024) used to set up
// routing for the Tailscale Kubernetes operator egress proxies. A tailnet
// service (identified by tailnet IP or FQDN) that gets exposed to cluster
// workloads gets a separate prerouting chain created for it for each IP family
// of the chain's target addresses. Each service's prerouting chain contains one
// or more portmapping rules. A portmapping rule DNATs traffic received on a
// particular port to a port of the tailnet service. Creating a chain per
// service makes it easier to delete a service when no longer needed and helps
// with readability.
// EnsurePortMapRuleForSvc makes sure a portmapping rule for the given service,
// target address and port mapping is installed. It:
//   - ensures that the nat table exists for the IP family of targetIP
//   - ensures that the nat table contains a prerouting chain named after the
//     service
//   - inserts a portmapping rule matching the given portmap, unless a rule
//     carrying the same metadata is already present in that chain
func (n *nftablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
	natTable, svcChain, err := n.ensureChainForSvc(svc, targetIP)
	if err != nil {
		return fmt.Errorf("error ensuring chain for %s: %w", svc, err)
	}

	// Rules are recognised by the metadata stored alongside them, so look
	// for an existing rule with the metadata this portmap would produce.
	ruleMeta := svcPortMapRuleMeta(svc, targetIP, pm)
	existing, err := n.findRuleByMetadata(natTable, svcChain, ruleMeta)
	switch {
	case err != nil:
		return fmt.Errorf("error looking up rule: %w", err)
	case existing != nil:
		return nil // the rule is already in place
	}

	l4proto, err := protoFromString(pm.Protocol)
	if err != nil {
		return fmt.Errorf("error converting protocol %s: %w", pm.Protocol, err)
	}
	n.conn.InsertRule(portMapRule(natTable, svcChain, tun, targetIP, pm.MatchPort, pm.TargetPort, l4proto, ruleMeta))
	return n.conn.Flush()
}
// DeletePortMapRuleForSvc removes a single portmapping rule from the chain of
// the given service in the nat table of targetIP's IP family.
// The rule is located via the metadata attached to it. A missing nat table,
// chain or rule is not an error: there is nothing to delete in that case.
// Removing the chain itself is left to DeleteSvc, so the case where this is
// the last rule of the chain is deliberately not handled here.
// The tun argument is not used.
func (n *nftablesRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
	nft, err := n.getNFTByAddr(targetIP)
	if err != nil {
		return fmt.Errorf("error setting up nftables for IP family of %s: %w", targetIP, err)
	}

	natTable, err := getTableIfExists(n.conn, nft.Proto, "nat")
	if err != nil {
		return fmt.Errorf("error checking if nat table exists: %w", err)
	}
	if natTable == nil {
		return nil // no nat table, hence no rule
	}

	svcChain, err := getChainFromTable(n.conn, natTable, svc)
	if err != nil {
		if errors.Is(err, errorChainNotFound{natTable.Name, svc}) {
			return nil // service chain does not exist, so neither does the portmapping rule
		}
		return fmt.Errorf("error checking if chain %s exists: %w", svc, err)
	}

	found, err := n.findRuleByMetadata(natTable, svcChain, svcPortMapRuleMeta(svc, targetIP, pm))
	if err != nil {
		return fmt.Errorf("error checking if rule exists: %w", err)
	}
	if found == nil {
		return nil // rule is already gone
	}

	if err := n.conn.DelRule(found); err != nil {
		return fmt.Errorf("error deleting rule: %w", err)
	}
	return n.conn.Flush()
}
// DeleteSvc deletes the chains for the given service if any exist.
//
// For every address in targetIPs the chain named svc is looked up in the nat
// table of that address's IP family and queued for deletion. An IP family
// that has no nat table, or no chain for the service, is skipped rather than
// ending the whole operation, so that chains already queued for another
// family are still flushed. All queued deletions are committed with a single
// Flush at the end.
//
// The tun and pm arguments are not used.
func (n *nftablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error {
	for _, tip := range targetIPs {
		table, err := n.getNFTByAddr(tip)
		if err != nil {
			return fmt.Errorf("error setting up nftables for IP family of %s: %w", tip, err)
		}
		t, err := getTableIfExists(n.conn, table.Proto, "nat")
		if err != nil {
			return fmt.Errorf("error checking if nat table exists: %w", err)
		}
		if t == nil {
			// Nothing to delete for this IP family; other families may
			// still have a chain, so keep going instead of returning.
			continue
		}
		ch, err := getChainFromTable(n.conn, t, svc)
		if err != nil {
			if errors.Is(err, errorChainNotFound{t.Name, svc}) {
				// Chain is already gone for this family. Returning here
				// would skip the Flush below and drop deletions that were
				// queued in earlier iterations.
				continue
			}
			return fmt.Errorf("error checking if chain %s exists: %w", svc, err)
		}
		n.conn.DelChain(ch)
	}
	return n.conn.Flush()
}
// portMapRule builds the DNAT rule for one port mapping of a service. It only
// constructs the rule; installing it is up to the caller.
//
// The resulting rule matches packets that
//   - did not arrive on the interface named tun,
//   - use the L4 protocol proto, and
//   - carry matchPort in bytes 2-3 of the transport header (the destination
//     port for TCP and UDP),
//
// and rewrites their destination to targetIP:targetPort. The NAT family is
// IPv4 when targetIP.Is4() reports true and IPv6 otherwise. meta is stored as
// the rule's UserData so that the rule can be found again later.
func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule {
	var fam uint32
	if targetIP.Is4() {
		fam = unix.NFPROTO_IPV4
	} else {
		fam = unix.NFPROTO_IPV6
	}
	rule := &nftables.Rule{
		Table:    t,
		Chain:    ch,
		UserData: meta,
		Exprs: []expr.Any{
			// Skip traffic that came in over the tunnel interface. The
			// service chains are hooked at prerouting (see
			// ensureChainForSvc), where no output interface has been
			// selected yet, so the input interface is the one to compare;
			// an output interface match would never exclude anything there.
			&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
			&expr.Cmp{
				Op:       expr.CmpOpNeq,
				Register: 1,
				Data:     []byte(tun),
			},
			// Match the transport protocol.
			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     []byte{proto},
			},
			// Load 2 bytes at offset 2 of the transport header and compare
			// them with the port to match, in network byte order.
			&expr.Payload{
				DestRegister: 1,
				Base:         expr.PayloadBaseTransportHeader,
				Offset:       2,
				Len:          2,
			},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     binaryutil.BigEndian.PutUint16(matchPort),
			},
			// Register 1 holds the new destination address, register 2 the
			// new destination port; the NAT expression reads both.
			&expr.Immediate{
				Register: 1,
				Data:     targetIP.AsSlice(),
			},
			&expr.Immediate{
				Register: 2,
				Data:     binaryutil.BigEndian.PutUint16(targetPort),
			},
			&expr.NAT{
				Type:        expr.NATTypeDestNAT,
				Family:      fam,
				RegAddrMin:  1,
				RegAddrMax:  1,
				RegProtoMin: 2,
				RegProtoMax: 2,
			},
		},
	}
	return rule
}
// svcPortMapRuleMeta renders the metadata that identifies a portmapping rule:
// a text encoding of the service name, target IP, match port, target port and
// protocol. The same inputs always yield the same bytes, which is what allows
// a rule to be looked up again by its metadata.
// https://github.com/google/nftables/issues/48
func svcPortMapRuleMeta(svcName string, targetIP netip.Addr, pm PortMap) []byte {
	// The layout must stay byte-for-byte stable: rules that are already
	// installed are found by comparing against exactly this encoding.
	const layout = "svc:%s,targetIP:%s:matchPort:%v,targetPort:%v,proto:%v"
	encoded := fmt.Sprintf(layout, svcName, targetIP.String(), pm.MatchPort, pm.TargetPort, pm.Protocol)
	return []byte(encoded)
}
// findRuleByMetadata returns the first rule in chain ch of table t whose
// UserData equals meta. It returns (nil, nil) when no rule matches, and also
// when the runner has no connection, t or ch is nil, or meta is empty. An
// error is returned only if listing the chain's rules fails.
func (n *nftablesRunner) findRuleByMetadata(t *nftables.Table, ch *nftables.Chain, meta []byte) (*nftables.Rule, error) {
	switch {
	case n.conn == nil, t == nil, ch == nil, len(meta) == 0:
		// Nothing to search in, or nothing to search for.
		return nil, nil
	}

	candidates, err := n.conn.GetRules(t, ch)
	if err != nil {
		return nil, fmt.Errorf("error listing rules: %w", err)
	}
	for i := range candidates {
		if candidate := candidates[i]; reflect.DeepEqual(candidate.UserData, meta) {
			return candidate, nil
		}
	}
	return nil, nil
}
// ensureChainForSvc makes sure that the nat table for the IP family of
// targetIP exists and that it contains a chain named svc. The chain is
// requested as a NAT-type chain on the prerouting hook, with destination-NAT
// priority and an accept policy. It returns the nat table and the service
// chain.
func (n *nftablesRunner) ensureChainForSvc(svc string, targetIP netip.Addr) (*nftables.Table, *nftables.Chain, error) {
	nft, err := n.getNFTByAddr(targetIP)
	if err != nil {
		return nil, nil, fmt.Errorf("error setting up nftables for IP family of %v: %w", targetIP, err)
	}

	natTable, err := createTableIfNotExist(n.conn, nft.Proto, "nat")
	if err != nil {
		return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
	}

	// chainInfo takes the policy by pointer, so it needs an addressable copy.
	acceptPolicy := nftables.ChainPolicyAccept
	wanted := chainInfo{
		table:         natTable,
		name:          svc,
		chainType:     nftables.ChainTypeNAT,
		chainHook:     nftables.ChainHookPrerouting,
		chainPriority: nftables.ChainPriorityNATDest,
		chainPolicy:   &acceptPolicy,
	}
	svcChain, err := getOrCreateChain(n.conn, wanted)
	if err != nil {
		return nil, nil, fmt.Errorf("error ensuring prerouting chain: %w", err)
	}
	return natTable, svcChain, nil
}
// PortMap is the port mapping for a service rule.
type PortMap struct {
// MatchPort is the local port to which the rule should apply.
MatchPort uint16
// TargetPort is the port to which the traffic should be forwarded.
TargetPort uint16
// Protocol is the protocol to match packets on. Only TCP and UDP are
// supported; protoFromString accepts the names "tcp" and "udp" in any
// letter case.
Protocol string
}
// protoFromString maps a protocol name to its IP protocol number. The name is
// compared case-insensitively; "tcp" and "udp" are the only names accepted,
// anything else yields an error that quotes the original input.
func protoFromString(s string) (uint8, error) {
	lowered := strings.ToLower(s)
	if lowered == "tcp" {
		return unix.IPPROTO_TCP, nil
	}
	if lowered == "udp" {
		return unix.IPPROTO_UDP, nil
	}
	return 0, fmt.Errorf("unrecognized protocol: %q", s)
}