2006 lines
60 KiB
Go
2006 lines
60 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
//go:build linux
|
|
|
|
package linuxfw
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/google/nftables"
|
|
"github.com/google/nftables/expr"
|
|
"golang.org/x/sys/unix"
|
|
"tailscale.com/net/tsaddr"
|
|
"tailscale.com/types/logger"
|
|
"tailscale.com/types/ptr"
|
|
)
|
|
|
|
// Names of the Tailscale-owned chains. AddChains creates the forward and
// input chains in the filter table and the postrouting chain in the nat table.
const (
	// chainNameForward is the Tailscale chain for forwarded traffic.
	chainNameForward = "ts-forward"
	// chainNameInput is the Tailscale chain for traffic to the local host.
	chainNameInput = "ts-input"
	// chainNamePostrouting is the Tailscale chain in the nat table.
	chainNamePostrouting = "ts-postrouting"
)

// chainTypeRegular is the (empty) chain type used for an nftables chain that
// is not attached to any hook.
const chainTypeRegular = ""
|
|
|
|
type chainInfo struct {
|
|
table *nftables.Table
|
|
name string
|
|
chainType nftables.ChainType
|
|
chainHook *nftables.ChainHook
|
|
chainPriority *nftables.ChainPriority
|
|
chainPolicy *nftables.ChainPolicy
|
|
}
|
|
|
|
type nftable struct {
|
|
Proto nftables.TableFamily
|
|
Filter *nftables.Table
|
|
Nat *nftables.Table
|
|
}
|
|
|
|
// nftablesRunner implements a netfilterRunner using the netlink based nftables
|
|
// library. As nftables allows for arbitrary tables and chains, there is a need
|
|
// to follow conventions in order to integrate well with a surrounding
|
|
// ecosystem. The rules installed by nftablesRunner have the following
|
|
// properties:
|
|
// - Install rules that intend to take precedence over rules installed by
|
|
// other software. Tailscale provides packet filtering for tailnet traffic
|
|
// inside the daemon based on the tailnet ACL rules.
|
|
// - As nftables "accept" is not final, rules from high priority tables (low
|
|
// numbers) will fall through to lower priority tables (high numbers). In
|
|
// order to effectively be 'final', we install "jump" rules into conventional
|
|
// tables and chains that will reach an accept verdict inside those tables.
|
|
// - The table and chain conventions followed here are those used by
|
|
// `iptables-nft` and `ufw`, so that those tools co-exist and do not
|
|
// negatively affect Tailscale function.
|
|
// - Be mindful that 1) all chains attached to a given hook (i.e the forward hook)
|
|
// will be processed in priority order till either a rule in one of the chains issues a drop verdict
|
|
// or there are no more chains for that hook
|
|
// 2) processing of individual rules within a chain will stop once one of them issues a final verdict (accept, drop).
|
|
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains
|
|
type nftablesRunner struct {
|
|
conn *nftables.Conn
|
|
nft4 *nftable
|
|
nft6 *nftable
|
|
|
|
v6Available bool
|
|
v6NATAvailable bool
|
|
}
|
|
|
|
func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) {
|
|
polAccept := nftables.ChainPolicyAccept
|
|
table := n.getNFTByAddr(dst)
|
|
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
|
|
}
|
|
|
|
// ensure prerouting chain exists
|
|
preroutingCh, err := getOrCreateChain(n.conn, chainInfo{
|
|
table: nat,
|
|
name: "PREROUTING",
|
|
chainType: nftables.ChainTypeNAT,
|
|
chainHook: nftables.ChainHookPrerouting,
|
|
chainPriority: nftables.ChainPriorityNATDest,
|
|
chainPolicy: &polAccept,
|
|
})
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("error ensuring prerouting chain: %w", err)
|
|
}
|
|
return nat, preroutingCh, nil
|
|
}
|
|
|
|
func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
|
|
nat, preroutingCh, err := n.ensurePreroutingChain(dst)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var daddrOffset, fam, dadderLen uint32
|
|
if origDst.Is4() {
|
|
daddrOffset = 16
|
|
dadderLen = 4
|
|
fam = unix.NFPROTO_IPV4
|
|
} else {
|
|
daddrOffset = 24
|
|
dadderLen = 16
|
|
fam = unix.NFPROTO_IPV6
|
|
}
|
|
dnatRule := &nftables.Rule{
|
|
Table: nat,
|
|
Chain: preroutingCh,
|
|
Exprs: []expr.Any{
|
|
&expr.Payload{
|
|
DestRegister: 1,
|
|
Base: expr.PayloadBaseNetworkHeader,
|
|
Offset: daddrOffset,
|
|
Len: dadderLen,
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: origDst.AsSlice(),
|
|
},
|
|
&expr.Immediate{
|
|
Register: 1,
|
|
Data: dst.AsSlice(),
|
|
},
|
|
&expr.NAT{
|
|
Type: expr.NATTypeDestNAT,
|
|
Family: fam,
|
|
RegAddrMin: 1,
|
|
},
|
|
},
|
|
}
|
|
n.conn.InsertRule(dnatRule)
|
|
return n.conn.Flush()
|
|
}
|
|
|
|
// DNATWithLoadBalancer currently just forwards all traffic destined for origDst
|
|
// to the first IP address from the backend targets.
|
|
// TODO (irbekrm): instead of doing this load balance traffic evenly to all
|
|
// backend destinations.
|
|
// https://github.com/tailscale/tailscale/commit/d37f2f508509c6c35ad724fd75a27685b90b575b#diff-a3bcbcd1ca198799f4f768dc56fea913e1945a6b3ec9dbec89325a84a19a85e7R148-R232
|
|
func (n *nftablesRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error {
|
|
return n.AddDNATRule(origDst, dsts[0])
|
|
}
|
|
|
|
func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr) error {
|
|
nat, preroutingCh, err := n.ensurePreroutingChain(dst)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var famConst uint32
|
|
if dst.Is4() {
|
|
famConst = unix.NFPROTO_IPV4
|
|
} else {
|
|
famConst = unix.NFPROTO_IPV6
|
|
}
|
|
|
|
dnatRule := &nftables.Rule{
|
|
Table: nat,
|
|
Chain: preroutingCh,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpNeq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
&expr.Immediate{
|
|
Register: 1,
|
|
Data: dst.AsSlice(),
|
|
},
|
|
&expr.NAT{
|
|
Type: expr.NATTypeDestNAT,
|
|
Family: famConst,
|
|
RegAddrMin: 1,
|
|
},
|
|
},
|
|
}
|
|
n.conn.InsertRule(dnatRule)
|
|
return n.conn.Flush()
|
|
}
|
|
|
|
func (n *nftablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error {
|
|
polAccept := nftables.ChainPolicyAccept
|
|
table := n.getNFTByAddr(dst)
|
|
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
|
|
if err != nil {
|
|
return fmt.Errorf("error ensuring nat table exists: %w", err)
|
|
}
|
|
|
|
// ensure postrouting chain exists
|
|
postRoutingCh, err := getOrCreateChain(n.conn, chainInfo{
|
|
table: nat,
|
|
name: "POSTROUTING",
|
|
chainType: nftables.ChainTypeNAT,
|
|
chainHook: nftables.ChainHookPostrouting,
|
|
chainPriority: nftables.ChainPriorityNATSource,
|
|
chainPolicy: &polAccept,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("error ensuring postrouting chain: %w", err)
|
|
}
|
|
var daddrOffset, fam, daddrLen uint32
|
|
if dst.Is4() {
|
|
daddrOffset = 16
|
|
daddrLen = 4
|
|
fam = unix.NFPROTO_IPV4
|
|
} else {
|
|
daddrOffset = 24
|
|
daddrLen = 16
|
|
fam = unix.NFPROTO_IPV6
|
|
}
|
|
|
|
snatRule := &nftables.Rule{
|
|
Table: nat,
|
|
Chain: postRoutingCh,
|
|
Exprs: []expr.Any{
|
|
&expr.Payload{
|
|
DestRegister: 1,
|
|
Base: expr.PayloadBaseNetworkHeader,
|
|
Offset: daddrOffset,
|
|
Len: daddrLen,
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: dst.AsSlice(),
|
|
},
|
|
&expr.Immediate{
|
|
Register: 1,
|
|
Data: src.AsSlice(),
|
|
},
|
|
&expr.NAT{
|
|
Type: expr.NATTypeSourceNAT,
|
|
Family: fam,
|
|
RegAddrMin: 1,
|
|
},
|
|
},
|
|
}
|
|
n.conn.AddRule(snatRule)
|
|
return n.conn.Flush()
|
|
}
|
|
|
|
// ClampMSSToPMTU ensures that all packets with TCP flags (SYN, ACK, RST) set
|
|
// being forwarded via the given interface (tun) have MSS set to <MTU of the
|
|
// interface> - 40 (IP and TCP headers). This can be useful if this tailscale
|
|
// instance is expected to run as a forwarding proxy, forwarding packets from an
|
|
// endpoint with higher MTU in an environment where path MTU discovery is
|
|
// expected to not work (such as the proxies created by the Tailscale Kubernetes
|
|
// operator). ClamMSSToPMTU creates a new base-chain ts-clamp in the filter
|
|
// table with accept policy and priority -150. In practice, this means that for
|
|
// SYN packets the clamp rule in this chain will likely run first and accept the
|
|
// packet. This is fine because 1) nftables run ALL chains with the same hook
|
|
// type unless a rule in one of them drops the packet and 2) this chain does not
|
|
// have functionality to drop the packet- so in practice a matching clamp rule
|
|
// will always be followed by the custom tailscale filtering rules in the other
|
|
// chains attached to the filter hook (FORWARD, ts-forward).
|
|
// We do not want to place the clamping rule into FORWARD/ts-forward chains
|
|
// because wgengine populates those chains with rules that contain accept
|
|
// verdicts that would cause no further procesing within that chain. This
|
|
// functionality is currently invoked from outside wgengine (containerboot), so
|
|
// we don't want to race with wgengine for rule ordering within chains.
|
|
func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
|
|
polAccept := nftables.ChainPolicyAccept
|
|
table := n.getNFTByAddr(addr)
|
|
filterTable, err := createTableIfNotExist(n.conn, table.Proto, "filter")
|
|
if err != nil {
|
|
return fmt.Errorf("error ensuring filter table: %w", err)
|
|
}
|
|
|
|
// ensure ts-clamp chain exists
|
|
fwChain, err := getOrCreateChain(n.conn, chainInfo{
|
|
table: filterTable,
|
|
name: "ts-clamp",
|
|
chainType: nftables.ChainTypeFilter,
|
|
chainHook: nftables.ChainHookForward,
|
|
chainPriority: nftables.ChainPriorityMangle,
|
|
chainPolicy: &polAccept,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("error ensuring forward chain: %w", err)
|
|
}
|
|
|
|
clampRule := &nftables.Rule{
|
|
Table: filterTable,
|
|
Chain: fwChain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte(tun),
|
|
},
|
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte{unix.IPPROTO_TCP},
|
|
},
|
|
&expr.Payload{
|
|
DestRegister: 1,
|
|
Base: expr.PayloadBaseTransportHeader,
|
|
Offset: 13,
|
|
Len: 1,
|
|
},
|
|
&expr.Bitwise{
|
|
DestRegister: 1,
|
|
SourceRegister: 1,
|
|
Len: 1,
|
|
Mask: []byte{0x02},
|
|
Xor: []byte{0x00},
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpNeq, // match any packet with a TCP flag set (SYN, ACK, RST)
|
|
Register: 1,
|
|
Data: []byte{0x00},
|
|
},
|
|
&expr.Rt{
|
|
Register: 1,
|
|
Key: expr.RtTCPMSS,
|
|
},
|
|
&expr.Byteorder{
|
|
DestRegister: 1,
|
|
SourceRegister: 1,
|
|
Op: expr.ByteorderHton,
|
|
Len: 2,
|
|
Size: 2,
|
|
},
|
|
&expr.Exthdr{
|
|
SourceRegister: 1,
|
|
Type: 2,
|
|
Offset: 2,
|
|
Len: 2,
|
|
Op: expr.ExthdrOpTcpopt,
|
|
},
|
|
},
|
|
}
|
|
n.conn.AddRule(clampRule)
|
|
return n.conn.Flush()
|
|
}
|
|
|
|
// deleteTableIfExists deletes a nftables table via connection c if it exists
|
|
// within the given family.
|
|
func deleteTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) error {
|
|
t, err := getTableIfExists(c, family, name)
|
|
if err != nil {
|
|
return fmt.Errorf("get table: %w", err)
|
|
}
|
|
if t == nil {
|
|
// Table does not exist, so nothing to delete.
|
|
return nil
|
|
}
|
|
c.DelTable(t)
|
|
if err := c.Flush(); err != nil {
|
|
if t, err = getTableIfExists(c, family, name); t == nil && err == nil {
|
|
// Check if the table still exists. If it does not, then the error
|
|
// is due to the table not existing, so we can ignore it. Maybe a
|
|
// concurrent process deleted the table.
|
|
return nil
|
|
}
|
|
return fmt.Errorf("del table: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getTableIfExists returns the table with the given name from the given family
|
|
// if it exists. If none match, it returns (nil, nil).
|
|
func getTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
|
|
tables, err := c.ListTables()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get tables: %w", err)
|
|
}
|
|
for _, table := range tables {
|
|
if table.Name == name && table.Family == family {
|
|
return table, nil
|
|
}
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// createTableIfNotExist creates a nftables table via connection c if it does
|
|
// not exist within the given family.
|
|
func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
|
|
if t, err := getTableIfExists(c, family, name); err != nil {
|
|
return nil, fmt.Errorf("get table: %w", err)
|
|
} else if t != nil {
|
|
return t, nil
|
|
}
|
|
t := c.AddTable(&nftables.Table{
|
|
Family: family,
|
|
Name: name,
|
|
})
|
|
if err := c.Flush(); err != nil {
|
|
return nil, fmt.Errorf("add table: %w", err)
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
// errorChainNotFound is the error reported when a chain lookup finds no chain
// with the requested name in the requested table.
//
// The fields are declared table-first on purpose: the positional literals in
// this file are written as errorChainNotFound{table.Name, name}, so this
// order makes them populate the matching fields and keeps the error message
// from swapping the chain and table names.
type errorChainNotFound struct {
	tableName string
	chainName string
}

// Error implements the error interface.
func (e errorChainNotFound) Error() string {
	return fmt.Sprintf("chain %s not found in table %s", e.chainName, e.tableName)
}
|
|
|
|
// getChainFromTable returns the chain with the given name from the given table.
|
|
// Note that a chain name is unique within a table.
|
|
func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) {
|
|
chains, err := c.ListChainsOfTableFamily(table.Family)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list chains: %w", err)
|
|
}
|
|
|
|
for _, chain := range chains {
|
|
// Table family is already checked so table name is unique
|
|
if chain.Table.Name == table.Name && chain.Name == name {
|
|
return chain, nil
|
|
}
|
|
}
|
|
|
|
return nil, errorChainNotFound{table.Name, name}
|
|
}
|
|
|
|
// isTSChain reports whether name carries the "ts-" prefix that marks a chain
// as managed by Tailscale.
func isTSChain(name string) bool {
	const tsChainPrefix = "ts-"
	return strings.HasPrefix(name, tsChainPrefix)
}
|
|
|
|
// createChainIfNotExist creates a chain with the given name in the given table
|
|
// if it does not exist.
|
|
func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
|
|
_, err := getOrCreateChain(c, cinfo)
|
|
return err
|
|
}
|
|
|
|
func getOrCreateChain(c *nftables.Conn, cinfo chainInfo) (*nftables.Chain, error) {
|
|
chain, err := getChainFromTable(c, cinfo.table, cinfo.name)
|
|
if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) {
|
|
return nil, fmt.Errorf("get chain: %w", err)
|
|
} else if err == nil {
|
|
// The chain already exists. If it is a TS chain, check the
|
|
// type/hook/priority, but for "conventional chains" assume they're what
|
|
// we expect (in case iptables-nft/ufw make minor behavior changes in
|
|
// the future).
|
|
if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || *chain.Hooknum != *cinfo.chainHook || *chain.Priority != *cinfo.chainPriority) {
|
|
return nil, fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name)
|
|
}
|
|
return chain, nil
|
|
}
|
|
|
|
chain = c.AddChain(&nftables.Chain{
|
|
Name: cinfo.name,
|
|
Table: cinfo.table,
|
|
Type: cinfo.chainType,
|
|
Hooknum: cinfo.chainHook,
|
|
Priority: cinfo.chainPriority,
|
|
Policy: cinfo.chainPolicy,
|
|
})
|
|
|
|
if err := c.Flush(); err != nil {
|
|
return nil, fmt.Errorf("add chain: %w", err)
|
|
}
|
|
|
|
return chain, nil
|
|
}
|
|
|
|
// NetfilterRunner abstracts helpers to run netfilter commands. It is
|
|
// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner.
|
|
type NetfilterRunner interface {
|
|
// AddLoopbackRule adds a rule to permit loopback traffic to addr. This rule
|
|
// is added only if it does not already exist.
|
|
AddLoopbackRule(addr netip.Addr) error
|
|
|
|
// DelLoopbackRule removes the rule added by AddLoopbackRule.
|
|
DelLoopbackRule(addr netip.Addr) error
|
|
|
|
// AddHooks adds rules to conventional chains like "FORWARD", "INPUT" and
|
|
// "POSTROUTING" to jump from those chains to tailscale chains.
|
|
AddHooks() error
|
|
|
|
// DelHooks deletes rules added by AddHooks.
|
|
DelHooks(logf logger.Logf) error
|
|
|
|
// AddChains creates custom Tailscale chains.
|
|
AddChains() error
|
|
|
|
// DelChains removes chains added by AddChains.
|
|
DelChains() error
|
|
|
|
// AddBase adds rules reused by different other rules.
|
|
AddBase(tunname string) error
|
|
|
|
// DelBase removes rules added by AddBase.
|
|
DelBase() error
|
|
|
|
// AddSNATRule adds the netfilter rule to SNAT incoming traffic over
|
|
// the Tailscale interface destined for local subnets. An error is
|
|
// returned if the rule already exists.
|
|
AddSNATRule() error
|
|
|
|
// DelSNATRule removes the rule added by AddSNATRule.
|
|
DelSNATRule() error
|
|
|
|
// AddStatefulRule adds a netfilter rule for stateful packet filtering
|
|
// using conntrack.
|
|
AddStatefulRule(tunname string) error
|
|
|
|
// DelStatefulRule removes a netfilter rule for stateful packet filtering
|
|
// using conntrack.
|
|
DelStatefulRule(tunname string) error
|
|
|
|
// HasIPV6 reports true if the system supports IPv6.
|
|
HasIPV6() bool
|
|
|
|
// HasIPV6NAT reports true if the system supports IPv6 NAT.
|
|
HasIPV6NAT() bool
|
|
|
|
// HasIPV6Filter reports true if the system supports IPv6 filter tables
|
|
// This is only meaningful for iptables implementation, where hosts have
|
|
// partial ipables support (i.e missing filter table). For nftables
|
|
// implementation, this will default to the value of HasIPv6().
|
|
HasIPV6Filter() bool
|
|
|
|
// AddDNATRule adds a rule to the nat/PREROUTING chain to DNAT traffic
|
|
// destined for the given original destination to the given new destination.
|
|
// This is used to forward all traffic destined for the Tailscale interface
|
|
// to the provided destination, as used in the Kubernetes ingress proxies.
|
|
AddDNATRule(origDst, dst netip.Addr) error
|
|
|
|
// DNATWithLoadBalancer adds a rule to the nat/PREROUTING chain to DNAT
|
|
// traffic destined for the given original destination to the given new
|
|
// destination(s) using round robin to load balance if more than one
|
|
// destination is provided. This is used to forward all traffic destined
|
|
// for the Tailscale interface to the provided destination(s), as used
|
|
// in the Kubernetes ingress proxies.
|
|
DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error
|
|
|
|
// AddSNATRuleForDst adds a rule to the nat/POSTROUTING chain to SNAT
|
|
// traffic destined for dst to src.
|
|
// This is used to forward traffic destined for the local machine over
|
|
// the Tailscale interface, as used in the Kubernetes egress proxies.
|
|
AddSNATRuleForDst(src, dst netip.Addr) error
|
|
|
|
// DNATNonTailscaleTraffic adds a rule to the nat/PREROUTING chain to DNAT
|
|
// all traffic inbound from any interface except exemptInterface to dst.
|
|
// This is used to forward traffic destined for the local machine over
|
|
// the Tailscale interface, as used in the Kubernetes egress proxies.
|
|
DNATNonTailscaleTraffic(exemptInterface string, dst netip.Addr) error
|
|
|
|
// ClampMSSToPMTU adds a rule to the mangle/FORWARD chain to clamp MSS for
|
|
// traffic destined for the provided tun interface.
|
|
ClampMSSToPMTU(tun string, addr netip.Addr) error
|
|
|
|
// AddMagicsockPortRule adds a rule to the ts-input chain to accept
|
|
// incoming traffic on the specified port, to allow magicsock to
|
|
// communicate.
|
|
AddMagicsockPortRule(port uint16, network string) error
|
|
|
|
// DelMagicsockPortRule removes the rule created by AddMagicsockPortRule,
|
|
// if it exists.
|
|
DelMagicsockPortRule(port uint16, network string) error
|
|
}
|
|
|
|
// New creates a NetfilterRunner, auto-detecting whether to use
|
|
// nftables or iptables.
|
|
// As nftables is still experimental, iptables will be used unless
|
|
// either the TS_DEBUG_FIREWALL_MODE environment variable, or the prefHint
|
|
// parameter, is set to one of "nftables" or "auto".
|
|
func New(logf logger.Logf, prefHint string) (NetfilterRunner, error) {
|
|
mode := detectFirewallMode(logf, prefHint)
|
|
switch mode {
|
|
case FirewallModeIPTables:
|
|
return newIPTablesRunner(logf)
|
|
case FirewallModeNfTables:
|
|
return newNfTablesRunner(logf)
|
|
default:
|
|
return nil, fmt.Errorf("unknown firewall mode %v", mode)
|
|
}
|
|
}
|
|
|
|
// newNfTablesRunner creates a new nftablesRunner without guaranteeing
|
|
// the existence of the tables and chains.
|
|
func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
|
|
conn, err := nftables.New()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("nftables connection: %w", err)
|
|
}
|
|
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
|
|
|
|
v6err := CheckIPv6(logf)
|
|
if v6err != nil {
|
|
logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
|
|
}
|
|
supportsV6 := v6err == nil
|
|
var nft6 *nftable
|
|
|
|
if supportsV6 {
|
|
nft6 = &nftable{Proto: nftables.TableFamilyIPv6}
|
|
logf("v6nat availability: true")
|
|
}
|
|
|
|
// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables
|
|
|
|
return &nftablesRunner{
|
|
conn: conn,
|
|
nft4: nft4,
|
|
nft6: nft6,
|
|
v6Available: supportsV6,
|
|
}, nil
|
|
}
|
|
|
|
// newLoadSaddrExpr creates a new nftables expression that loads the source
|
|
// address of the packet into the given register.
|
|
func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) {
|
|
switch proto {
|
|
case nftables.TableFamilyIPv4:
|
|
return &expr.Payload{
|
|
DestRegister: destReg,
|
|
Base: expr.PayloadBaseNetworkHeader,
|
|
Offset: 12,
|
|
Len: 4,
|
|
}, nil
|
|
case nftables.TableFamilyIPv6:
|
|
return &expr.Payload{
|
|
DestRegister: destReg,
|
|
Base: expr.PayloadBaseNetworkHeader,
|
|
Offset: 8,
|
|
Len: 16,
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("table family %v is neither IPv4 nor IPv6", proto)
|
|
}
|
|
}
|
|
|
|
// newLoadDportExpr creates a new nftables express that loads the desination port
|
|
// of a TCP/UDP packet into the given register.
|
|
func newLoadDportExpr(destReg uint32) expr.Any {
|
|
return &expr.Payload{
|
|
DestRegister: destReg,
|
|
Base: expr.PayloadBaseTransportHeader,
|
|
Offset: 2,
|
|
Len: 2,
|
|
}
|
|
}
|
|
|
|
// HasIPV6 reports true if the system supports IPv6.
|
|
func (n *nftablesRunner) HasIPV6() bool {
|
|
return n.v6Available
|
|
}
|
|
|
|
// HasIPV6NAT returns true if the system supports IPv6.
|
|
// Kernel support for nftables was added after support for IPv6
|
|
// NAT, so no need for a separate IPv6 NAT support check like we do for iptables.
|
|
// https://tldp.org/HOWTO/Linux+IPv6-HOWTO/ch18s04.html
|
|
// https://wiki.nftables.org/wiki-nftables/index.php/Building_and_installing_nftables_from_sources
|
|
func (n *nftablesRunner) HasIPV6NAT() bool {
|
|
return n.v6Available
|
|
}
|
|
|
|
// HasIPV6Filter returns true if system supports IPv6. There are no known edge
|
|
// cases where nftables running on a host that supports IPv6 would not support
|
|
// filter table.
|
|
func (n *nftablesRunner) HasIPV6Filter() bool {
|
|
return n.v6Available
|
|
}
|
|
|
|
// findRule iterates through the rules to find the rule with matching expressions.
|
|
func findRule(conn *nftables.Conn, rule *nftables.Rule) (*nftables.Rule, error) {
|
|
rules, err := conn.GetRules(rule.Table, rule.Chain)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get nftables rules: %w", err)
|
|
}
|
|
if len(rules) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
ruleLoop:
|
|
for _, r := range rules {
|
|
if len(r.Exprs) != len(rule.Exprs) {
|
|
continue
|
|
}
|
|
|
|
for i, e := range r.Exprs {
|
|
// Skip counter expressions, as they will not match.
|
|
if _, ok := e.(*expr.Counter); ok {
|
|
continue
|
|
}
|
|
if !reflect.DeepEqual(e, rule.Exprs[i]) {
|
|
continue ruleLoop
|
|
}
|
|
}
|
|
return r, nil
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func createLoopbackRule(
|
|
proto nftables.TableFamily,
|
|
table *nftables.Table,
|
|
chain *nftables.Chain,
|
|
addr netip.Addr,
|
|
) (*nftables.Rule, error) {
|
|
saddrExpr, err := newLoadSaddrExpr(proto, 1)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
|
|
}
|
|
loopBackRule := &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{
|
|
Key: expr.MetaKeyIIFNAME,
|
|
Register: 1,
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte("lo"),
|
|
},
|
|
saddrExpr,
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: addr.AsSlice(),
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictAccept,
|
|
},
|
|
},
|
|
}
|
|
return loopBackRule, nil
|
|
}
|
|
|
|
// insertLoopbackRule inserts the TS loop back rule into
|
|
// the given chain as the first rule if it does not exist.
|
|
func insertLoopbackRule(
|
|
conn *nftables.Conn, proto nftables.TableFamily,
|
|
table *nftables.Table, chain *nftables.Chain, addr netip.Addr) error {
|
|
|
|
loopBackRule, err := createLoopbackRule(proto, table, chain, addr)
|
|
if err != nil {
|
|
return fmt.Errorf("create loopback rule: %w", err)
|
|
}
|
|
|
|
// If TestDial is set, we are running in test mode and we should not
|
|
// find rule because header will mismatch.
|
|
if conn.TestDial == nil {
|
|
// Check if the rule already exists.
|
|
rule, err := findRule(conn, loopBackRule)
|
|
if err != nil {
|
|
return fmt.Errorf("find rule: %w", err)
|
|
}
|
|
if rule != nil {
|
|
// Rule already exists, no need to insert.
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// This inserts the rule to the top of the chain
|
|
_ = conn.InsertRule(loopBackRule)
|
|
|
|
if err = conn.Flush(); err != nil {
|
|
return fmt.Errorf("insert rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getNFTByAddr returns the nftables with correct IP family
|
|
// that we will be using for the given address.
|
|
func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) *nftable {
|
|
if addr.Is6() {
|
|
return n.nft6
|
|
}
|
|
return n.nft4
|
|
}
|
|
|
|
// AddLoopbackRule adds an nftables rule to permit loopback traffic to
|
|
// a local Tailscale IP. This rule is added only if it does not already exist.
|
|
func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
|
|
nf := n.getNFTByAddr(addr)
|
|
|
|
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("get input chain: %w", err)
|
|
}
|
|
|
|
if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil {
|
|
return fmt.Errorf("add loopback rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DelLoopbackRule removes the nftables rule permitting loopback
|
|
// traffic to a Tailscale IP.
|
|
func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
|
|
nf := n.getNFTByAddr(addr)
|
|
|
|
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("get input chain: %w", err)
|
|
}
|
|
|
|
loopBackRule, err := createLoopbackRule(nf.Proto, nf.Filter, inputChain, addr)
|
|
if err != nil {
|
|
return fmt.Errorf("create loopback rule: %w", err)
|
|
}
|
|
|
|
existingLoopBackRule, err := findRule(n.conn, loopBackRule)
|
|
if err != nil {
|
|
return fmt.Errorf("find loop back rule: %w", err)
|
|
}
|
|
if existingLoopBackRule == nil {
|
|
// Rule does not exist, no need to delete.
|
|
return nil
|
|
}
|
|
|
|
if err := n.conn.DelRule(existingLoopBackRule); err != nil {
|
|
return fmt.Errorf("delete rule: %w", err)
|
|
}
|
|
|
|
return n.conn.Flush()
|
|
}
|
|
|
|
// getTables gets the available nftable in nftables runner.
|
|
func (n *nftablesRunner) getTables() []*nftable {
|
|
if n.v6Available {
|
|
return []*nftable{n.nft4, n.nft6}
|
|
}
|
|
return []*nftable{n.nft4}
|
|
}
|
|
|
|
// getNATTables gets the available nftable in nftables runner.
|
|
// If the system does not support IPv6 NAT, only the IPv4 nftable
|
|
// will be returned.
|
|
func (n *nftablesRunner) getNATTables() []*nftable {
|
|
if n.v6NATAvailable {
|
|
return n.getTables()
|
|
}
|
|
return []*nftable{n.nft4}
|
|
}
|
|
|
|
// AddChains creates custom Tailscale chains in netfilter via nftables
|
|
// if the ts-chain doesn't already exist.
|
|
func (n *nftablesRunner) AddChains() error {
|
|
polAccept := nftables.ChainPolicyAccept
|
|
for _, table := range n.getTables() {
|
|
// Create the filter table if it doesn't exist, this table name is the same
|
|
// as the name used by iptables-nft and ufw. We install rules into the
|
|
// same conventional table so that `accept` verdicts from our jump
|
|
// chains are conclusive.
|
|
filter, err := createTableIfNotExist(n.conn, table.Proto, "filter")
|
|
if err != nil {
|
|
return fmt.Errorf("create table: %w", err)
|
|
}
|
|
table.Filter = filter
|
|
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
|
return fmt.Errorf("create forward chain: %w", err)
|
|
}
|
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
|
return fmt.Errorf("create input chain: %w", err)
|
|
}
|
|
// Adding the tailscale chains that contain our rules.
|
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
|
|
return fmt.Errorf("create forward chain: %w", err)
|
|
}
|
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
|
|
return fmt.Errorf("create input chain: %w", err)
|
|
}
|
|
}
|
|
|
|
for _, table := range n.getNATTables() {
|
|
// Create the nat table if it doesn't exist, this table name is the same
|
|
// as the name used by iptables-nft and ufw. We install rules into the
|
|
// same conventional table so that `accept` verdicts from our jump
|
|
// chains are conclusive.
|
|
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
|
|
if err != nil {
|
|
return fmt.Errorf("create table: %w", err)
|
|
}
|
|
table.Nat = nat
|
|
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
|
if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
|
|
return fmt.Errorf("create postrouting chain: %w", err)
|
|
}
|
|
// Adding the tailscale chain that contains our rules.
|
|
if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
|
|
return fmt.Errorf("create postrouting chain: %w", err)
|
|
}
|
|
}
|
|
|
|
return n.conn.Flush()
|
|
}
|
|
|
|
// Names of the throwaway table and chain used to probe whether nftables is
// usable: they are created and then deleted again. If both steps succeed
// nftables can be used; otherwise the system is assumed not to support it.
// See createDummyPostroutingChains.
const (
	// tsDummyChainName is the name of the probe chain.
	tsDummyChainName = "ts-test-postrouting"
	// tsDummyTableName is the name of the probe table.
	tsDummyTableName = "ts-test-nat"
)
|
|
|
|
// createDummyPostroutingChains creates dummy postrouting chains in netfilter
|
|
// via netfilter via nftables, as a last resort measure to detect that nftables
|
|
// can be used. It cleans up the dummy chains after creation.
|
|
func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
|
|
polAccept := ptr.To(nftables.ChainPolicyAccept)
|
|
for _, table := range n.getNATTables() {
|
|
nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
|
|
if err != nil {
|
|
return fmt.Errorf("create nat table: %w", err)
|
|
}
|
|
defer func(fm nftables.TableFamily) {
|
|
if err := deleteTableIfExists(n.conn, fm, tsDummyTableName); err != nil && retErr == nil {
|
|
retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err)
|
|
}
|
|
}(table.Proto)
|
|
|
|
table.Nat = nat
|
|
if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil {
|
|
return fmt.Errorf("create %q chain: %w", tsDummyChainName, err)
|
|
}
|
|
if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil {
|
|
return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// deleteChainIfExists deletes a chain if it exists.
|
|
func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
|
|
chain, err := getChainFromTable(c, table, name)
|
|
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) {
|
|
return fmt.Errorf("get chain: %w", err)
|
|
} else if err != nil {
|
|
// If the chain doesn't exist, we don't need to delete it.
|
|
return nil
|
|
}
|
|
|
|
c.FlushChain(chain)
|
|
c.DelChain(chain)
|
|
|
|
if err := c.Flush(); err != nil {
|
|
return fmt.Errorf("flush and delete chain: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DelChains removes the custom Tailscale chains from netfilter via nftables.
|
|
func (n *nftablesRunner) DelChains() error {
|
|
for _, table := range n.getTables() {
|
|
if err := deleteChainIfExists(n.conn, table.Filter, chainNameForward); err != nil {
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
}
|
|
if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil {
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil {
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
}
|
|
|
|
if n.v6NATAvailable {
|
|
if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := n.conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// createHookRule creates a rule to jump from a hooked chain to a regular chain.
|
|
func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule {
|
|
exprs := []expr.Any{
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictJump,
|
|
Chain: toChainName,
|
|
},
|
|
}
|
|
|
|
rule := &nftables.Rule{
|
|
Table: table,
|
|
Chain: fromChain,
|
|
Exprs: exprs,
|
|
}
|
|
|
|
return rule
|
|
}
|
|
|
|
// addHookRule adds a rule to jump from a hooked chain to a regular chain at top of the hooked chain.
|
|
func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
|
|
rule := createHookRule(table, fromChain, toChainName)
|
|
_ = conn.InsertRule(rule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush add rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING"
|
|
// in tables and jump from those chains to tailscale chains.
|
|
func (n *nftablesRunner) AddHooks() error {
|
|
conn := n.conn
|
|
|
|
for _, table := range n.getTables() {
|
|
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
|
if err != nil {
|
|
return fmt.Errorf("get INPUT chain: %w", err)
|
|
}
|
|
err = addHookRule(conn, table.Filter, inputChain, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("Addhook: %w", err)
|
|
}
|
|
forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
|
|
if err != nil {
|
|
return fmt.Errorf("get FORWARD chain: %w", err)
|
|
}
|
|
err = addHookRule(conn, table.Filter, forwardChain, chainNameForward)
|
|
if err != nil {
|
|
return fmt.Errorf("Addhook: %w", err)
|
|
}
|
|
}
|
|
|
|
for _, table := range n.getNATTables() {
|
|
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
|
|
if err != nil {
|
|
return fmt.Errorf("get INPUT chain: %w", err)
|
|
}
|
|
err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
|
|
if err != nil {
|
|
return fmt.Errorf("Addhook: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// delHookRule deletes a rule that jumps from a hooked chain to a regular chain.
|
|
func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error {
|
|
rule := createHookRule(table, fromChain, toChainName)
|
|
existingRule, err := findRule(conn, rule)
|
|
if err != nil {
|
|
return fmt.Errorf("Failed to find hook rule: %w", err)
|
|
}
|
|
|
|
if existingRule == nil {
|
|
return nil
|
|
}
|
|
|
|
_ = conn.DelRule(existingRule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush del hook rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DelHooks is deleting the rules added to conventional chains to jump to tailscale chains.
|
|
func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
|
|
conn := n.conn
|
|
|
|
for _, table := range n.getTables() {
|
|
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
|
if err != nil {
|
|
return fmt.Errorf("get INPUT chain: %w", err)
|
|
}
|
|
err = delHookRule(conn, table.Filter, inputChain, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("delhook: %w", err)
|
|
}
|
|
forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD")
|
|
if err != nil {
|
|
return fmt.Errorf("get FORWARD chain: %w", err)
|
|
}
|
|
err = delHookRule(conn, table.Filter, forwardChain, chainNameForward)
|
|
if err != nil {
|
|
return fmt.Errorf("delhook: %w", err)
|
|
}
|
|
}
|
|
|
|
for _, table := range n.getNATTables() {
|
|
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
|
|
if err != nil {
|
|
return fmt.Errorf("get INPUT chain: %w", err)
|
|
}
|
|
err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting)
|
|
if err != nil {
|
|
return fmt.Errorf("delhook: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// maskof returns the network mask of the given prefix in big endian bytes.
// The result is 4 bytes long for an IPv4 prefix and 16 bytes long for an IPv6
// prefix. It returns nil for an invalid prefix (for which Bits reports -1)
// instead of panicking on a negative shift count.
func maskof(pfx netip.Prefix) []byte {
	if !pfx.IsValid() {
		return nil
	}
	return net.CIDRMask(pfx.Bits(), pfx.Addr().BitLen())
}
|
|
|
|
// createRangeRule creates a rule that matches packets with source IP from the give
|
|
// range (like CGNAT range or ChromeOSVM range) and the interface is not the tunname,
|
|
// and makes the given decision. Only IPv4 is supported.
|
|
func createRangeRule(
|
|
table *nftables.Table, chain *nftables.Chain,
|
|
tunname string, rng netip.Prefix, decision expr.VerdictKind,
|
|
) (*nftables.Rule, error) {
|
|
if rng.Addr().Is6() {
|
|
return nil, errors.New("IPv6 is not supported")
|
|
}
|
|
saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("newLoadSaddrExpr: %w", err)
|
|
}
|
|
netip := rng.Addr().AsSlice()
|
|
mask := maskof(rng)
|
|
rule := &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpNeq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
saddrExpr,
|
|
&expr.Bitwise{
|
|
SourceRegister: 1,
|
|
DestRegister: 1,
|
|
Len: 4,
|
|
Mask: mask,
|
|
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: netip,
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: decision,
|
|
},
|
|
},
|
|
}
|
|
return rule, nil
|
|
|
|
}
|
|
|
|
// addReturnChromeOSVMRangeRule adds a rule to return if the source IP
|
|
// is in the ChromeOS VM range.
|
|
func addReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
|
rule, err := createRangeRule(table, chain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn)
|
|
if err != nil {
|
|
return fmt.Errorf("create rule: %w", err)
|
|
}
|
|
_ = c.AddRule(rule)
|
|
if err = c.Flush(); err != nil {
|
|
return fmt.Errorf("add rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// addDropCGNATRangeRule adds a rule to drop if the source IP is in the
|
|
// CGNAT range.
|
|
func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
|
rule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop)
|
|
if err != nil {
|
|
return fmt.Errorf("create rule: %w", err)
|
|
}
|
|
_ = c.AddRule(rule)
|
|
if err = c.Flush(); err != nil {
|
|
return fmt.Errorf("add rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// createSetSubnetRouteMarkRule creates a rule to set the subnet route
|
|
// mark if the packet is from the given interface.
|
|
func createSetSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
|
|
hexTsFwmarkMaskNeg := getTailscaleFwmarkMaskNeg()
|
|
hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
|
|
|
|
rule := &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
|
|
&expr.Bitwise{
|
|
SourceRegister: 1,
|
|
DestRegister: 1,
|
|
Len: 4,
|
|
Mask: hexTsFwmarkMaskNeg,
|
|
Xor: hexTSSubnetRouteMark,
|
|
},
|
|
&expr.Meta{
|
|
Key: expr.MetaKeyMARK,
|
|
SourceRegister: true,
|
|
Register: 1,
|
|
},
|
|
},
|
|
}
|
|
return rule, nil
|
|
}
|
|
|
|
// addSetSubnetRouteMarkRule adds a rule to set the subnet route mark
|
|
// if the packet is from the given interface.
|
|
func addSetSubnetRouteMarkRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
|
rule, err := createSetSubnetRouteMarkRule(table, chain, tunname)
|
|
if err != nil {
|
|
return fmt.Errorf("create rule: %w", err)
|
|
}
|
|
_ = c.AddRule(rule)
|
|
|
|
if err := c.Flush(); err != nil {
|
|
return fmt.Errorf("add rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// createDropOutgoingPacketFromCGNATRangeRuleWithTunname creates a rule to drop
|
|
// outgoing packets from the CGNAT range.
|
|
func createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) {
|
|
_, ipNet, err := net.ParseCIDR(tsaddr.CGNATRange().String())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse cidr: %v", err)
|
|
}
|
|
mask, err := hex.DecodeString(ipNet.Mask.String())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode mask: %v", err)
|
|
}
|
|
netip := ipNet.IP.Mask(ipNet.Mask).To4()
|
|
saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("newLoadSaddrExpr: %v", err)
|
|
}
|
|
rule := &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
saddrExpr,
|
|
&expr.Bitwise{
|
|
SourceRegister: 1,
|
|
DestRegister: 1,
|
|
Len: 4,
|
|
Mask: mask,
|
|
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: netip,
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictDrop,
|
|
},
|
|
},
|
|
}
|
|
return rule, nil
|
|
}
|
|
|
|
// addDropOutgoingPacketFromCGNATRangeRuleWithTunname adds a rule to drop
|
|
// outgoing packets from the CGNAT range.
|
|
func addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
|
rule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table, chain, tunname)
|
|
if err != nil {
|
|
return fmt.Errorf("create rule: %w", err)
|
|
}
|
|
_ = conn.AddRule(rule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("add rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// createAcceptOutgoingPacketRule creates a rule to accept outgoing packets
|
|
// from the given interface.
|
|
func createAcceptOutgoingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
|
|
return &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictAccept,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// addAcceptOutgoingPacketRule adds a rule to accept outgoing packets
|
|
// from the given interface.
|
|
func addAcceptOutgoingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
|
rule := createAcceptOutgoingPacketRule(table, chain, tunname)
|
|
_ = conn.AddRule(rule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush add rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// createAcceptOnPortRule creates a rule to accept incoming packets to
|
|
// a given destination UDP port.
|
|
func createAcceptOnPortRule(table *nftables.Table, chain *nftables.Chain, port uint16) *nftables.Rule {
|
|
portBytes := make([]byte, 2)
|
|
binary.BigEndian.PutUint16(portBytes, port)
|
|
return &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{
|
|
Key: expr.MetaKeyL4PROTO,
|
|
Register: 1,
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte{unix.IPPROTO_UDP},
|
|
},
|
|
newLoadDportExpr(1),
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: portBytes,
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictAccept,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// addAcceptOnPortRule adds a rule to accept incoming packets to
|
|
// a given destination UDP port.
|
|
func addAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error {
|
|
rule := createAcceptOnPortRule(table, chain, port)
|
|
_ = conn.AddRule(rule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush add rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// addAcceptOnPortRule removes a rule to accept incoming packets to
|
|
// a given destination UDP port.
|
|
func removeAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error {
|
|
rule := createAcceptOnPortRule(table, chain, port)
|
|
rule, err := findRule(conn, rule)
|
|
if err != nil {
|
|
return fmt.Errorf("find rule: %v", err)
|
|
}
|
|
|
|
_ = conn.DelRule(rule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush del rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddMagicsockPortRule adds a rule to nftables to allow incoming traffic on
|
|
// the specified UDP port, so magicsock can accept incoming connections.
|
|
// network must be either "udp4" or "udp6" - this determines whether the rule
|
|
// is added for IPv4 or IPv6.
|
|
func (n *nftablesRunner) AddMagicsockPortRule(port uint16, network string) error {
|
|
var filterTable *nftables.Table
|
|
switch network {
|
|
case "udp4":
|
|
filterTable = n.nft4.Filter
|
|
case "udp6":
|
|
filterTable = n.nft6.Filter
|
|
default:
|
|
return fmt.Errorf("unsupported network %s", network)
|
|
}
|
|
|
|
inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("get input chain: %v", err)
|
|
}
|
|
|
|
err = addAcceptOnPortRule(n.conn, filterTable, inputChain, port)
|
|
if err != nil {
|
|
return fmt.Errorf("add accept on port rule: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DelMagicsockPortRule removes a rule added by AddMagicsockPortRule to accept
|
|
// incoming traffic on a particular UDP port.
|
|
// network must be either "udp4" or "udp6" - this determines whether the rule
|
|
// is removed for IPv4 or IPv6.
|
|
func (n *nftablesRunner) DelMagicsockPortRule(port uint16, network string) error {
|
|
var filterTable *nftables.Table
|
|
switch network {
|
|
case "udp4":
|
|
filterTable = n.nft4.Filter
|
|
case "udp6":
|
|
filterTable = n.nft6.Filter
|
|
default:
|
|
return fmt.Errorf("unsupported network %s", network)
|
|
}
|
|
|
|
inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("get input chain: %v", err)
|
|
}
|
|
|
|
err = removeAcceptOnPortRule(n.conn, filterTable, inputChain, port)
|
|
if err != nil {
|
|
return fmt.Errorf("add accept on port rule: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// createAcceptIncomingPacketRule creates a rule to accept incoming packets to
|
|
// the given interface.
|
|
func createAcceptIncomingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule {
|
|
return &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictAccept,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func addAcceptIncomingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error {
|
|
rule := createAcceptIncomingPacketRule(table, chain, tunname)
|
|
_ = conn.AddRule(rule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush add rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddBase adds some basic processing rules.
|
|
func (n *nftablesRunner) AddBase(tunname string) error {
|
|
if err := n.addBase4(tunname); err != nil {
|
|
return fmt.Errorf("add base v4: %w", err)
|
|
}
|
|
if n.HasIPV6() {
|
|
if err := n.addBase6(tunname); err != nil {
|
|
return fmt.Errorf("add base v6: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// addBase4 adds some basic IPv4 processing rules.
|
|
func (n *nftablesRunner) addBase4(tunname string) error {
|
|
conn := n.conn
|
|
|
|
inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("get input chain v4: %v", err)
|
|
}
|
|
if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
|
|
return fmt.Errorf("add return chromeos vm range rule v4: %w", err)
|
|
}
|
|
if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
|
|
return fmt.Errorf("add drop cgnat range rule v4: %w", err)
|
|
}
|
|
if err = addAcceptIncomingPacketRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
|
|
return fmt.Errorf("add accept incoming packet rule v4: %w", err)
|
|
}
|
|
|
|
forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward)
|
|
if err != nil {
|
|
return fmt.Errorf("get forward chain v4: %v", err)
|
|
}
|
|
|
|
if err = addSetSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
|
|
return fmt.Errorf("add set subnet route mark rule v4: %w", err)
|
|
}
|
|
|
|
if err = addMatchSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, Accept); err != nil {
|
|
return fmt.Errorf("add match subnet route mark rule v4: %w", err)
|
|
}
|
|
|
|
if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
|
|
return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err)
|
|
}
|
|
|
|
if err = addAcceptOutgoingPacketRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
|
|
return fmt.Errorf("add accept outgoing packet rule v4: %w", err)
|
|
}
|
|
|
|
if err = conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush base v4: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// addBase6 adds some basic IPv6 processing rules.
|
|
func (n *nftablesRunner) addBase6(tunname string) error {
|
|
conn := n.conn
|
|
|
|
inputChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("get input chain v4: %v", err)
|
|
}
|
|
if err = addAcceptIncomingPacketRule(conn, n.nft6.Filter, inputChain, tunname); err != nil {
|
|
return fmt.Errorf("add accept incoming packet rule v6: %w", err)
|
|
}
|
|
|
|
forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward)
|
|
if err != nil {
|
|
return fmt.Errorf("get forward chain v6: %w", err)
|
|
}
|
|
|
|
if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
|
|
return fmt.Errorf("add set subnet route mark rule v6: %w", err)
|
|
}
|
|
|
|
if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil {
|
|
return fmt.Errorf("add match subnet route mark rule v6: %w", err)
|
|
}
|
|
|
|
if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
|
|
return fmt.Errorf("add accept outgoing packet rule v6: %w", err)
|
|
}
|
|
|
|
if err = conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush base v6: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DelBase empties, but does not remove, custom Tailscale chains from
|
|
// netfilter via iptables.
|
|
func (n *nftablesRunner) DelBase() error {
|
|
conn := n.conn
|
|
|
|
for _, table := range n.getTables() {
|
|
inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput)
|
|
if err != nil {
|
|
return fmt.Errorf("get input chain: %v", err)
|
|
}
|
|
conn.FlushChain(inputChain)
|
|
forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward)
|
|
if err != nil {
|
|
return fmt.Errorf("get forward chain: %v", err)
|
|
}
|
|
conn.FlushChain(forwardChain)
|
|
}
|
|
|
|
for _, table := range n.getNATTables() {
|
|
postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
|
if err != nil {
|
|
return fmt.Errorf("get postrouting chain v4: %v", err)
|
|
}
|
|
conn.FlushChain(postrouteChain)
|
|
}
|
|
|
|
return conn.Flush()
|
|
}
|
|
|
|
// createMatchSubnetRouteMarkRule creates a rule that matches packets
|
|
// with the subnet route mark and takes the specified action.
|
|
func createMatchSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, action MatchDecision) (*nftables.Rule, error) {
|
|
hexTSFwmarkMask := getTailscaleFwmarkMask()
|
|
hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
|
|
|
|
var endAction expr.Any
|
|
endAction = &expr.Verdict{Kind: expr.VerdictAccept}
|
|
if action == Masq {
|
|
endAction = &expr.Masq{}
|
|
}
|
|
|
|
exprs := []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
|
|
&expr.Bitwise{
|
|
SourceRegister: 1,
|
|
DestRegister: 1,
|
|
Len: 4,
|
|
Mask: hexTSFwmarkMask,
|
|
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: hexTSSubnetRouteMark,
|
|
},
|
|
&expr.Counter{},
|
|
endAction,
|
|
}
|
|
|
|
rule := &nftables.Rule{
|
|
Table: table,
|
|
Chain: chain,
|
|
Exprs: exprs,
|
|
}
|
|
return rule, nil
|
|
}
|
|
|
|
// addMatchSubnetRouteMarkRule adds a rule that matches packets with
|
|
// the subnet route mark and takes the specified action.
|
|
func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, action MatchDecision) error {
|
|
rule, err := createMatchSubnetRouteMarkRule(table, chain, action)
|
|
if err != nil {
|
|
return fmt.Errorf("create match subnet route mark rule: %w", err)
|
|
}
|
|
_ = conn.AddRule(rule)
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush add rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddSNATRule adds a netfilter rule to SNAT traffic destined for
|
|
// local subnets.
|
|
func (n *nftablesRunner) AddSNATRule() error {
|
|
conn := n.conn
|
|
|
|
for _, table := range n.getNATTables() {
|
|
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
|
if err != nil {
|
|
return fmt.Errorf("get postrouting chain v4: %w", err)
|
|
}
|
|
|
|
if err = addMatchSubnetRouteMarkRule(conn, table.Nat, chain, Masq); err != nil {
|
|
return fmt.Errorf("add match subnet route mark rule v4: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush add SNAT rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DelSNATRule removes the netfilter rule to SNAT traffic destined for
|
|
// local subnets. An error is returned if the rule does not exist.
|
|
func (n *nftablesRunner) DelSNATRule() error {
|
|
conn := n.conn
|
|
|
|
hexTSFwmarkMask := getTailscaleFwmarkMask()
|
|
hexTSSubnetRouteMark := getTailscaleSubnetRouteMark()
|
|
|
|
exprs := []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyMARK, Register: 1},
|
|
&expr.Bitwise{
|
|
SourceRegister: 1,
|
|
DestRegister: 1,
|
|
Len: 4,
|
|
Mask: hexTSFwmarkMask,
|
|
},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: hexTSSubnetRouteMark,
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Masq{},
|
|
}
|
|
|
|
for _, table := range n.getNATTables() {
|
|
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
|
if err != nil {
|
|
return fmt.Errorf("get postrouting chain v4: %w", err)
|
|
}
|
|
|
|
rule := &nftables.Rule{
|
|
Table: table.Nat,
|
|
Chain: chain,
|
|
Exprs: exprs,
|
|
}
|
|
|
|
SNATRule, err := findRule(conn, rule)
|
|
if err != nil {
|
|
return fmt.Errorf("find SNAT rule v4: %w", err)
|
|
}
|
|
|
|
if SNATRule != nil {
|
|
_ = conn.DelRule(SNATRule)
|
|
}
|
|
}
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush del SNAT rule: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// nativeUint32 returns v encoded as 4 bytes in the byte order of the machine
// the program runs on. A new slice is returned on every call.
func nativeUint32(v uint32) []byte {
	var buf [4]byte
	binary.NativeEndian.PutUint32(buf[:], v)
	return buf[:]
}
|
|
|
|
func makeStatefulRuleExprs(tunname string) []expr.Any {
|
|
return []expr.Any{
|
|
// Check if the output interface is the Tailscale interface by
|
|
// first loding the OIFNAME into register 1 and comparing it
|
|
// against our tunname.
|
|
//
|
|
// 'cmp' implicitly breaks from a rule if a comparison fails,
|
|
// so if we continue past this rule we know that the packet is
|
|
// going to our TUN.
|
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpNeq,
|
|
Register: 1,
|
|
Data: []byte(tunname),
|
|
},
|
|
|
|
// Store the conntrack state in register 1
|
|
&expr.Ct{
|
|
Register: 1,
|
|
Key: expr.CtKeySTATE,
|
|
},
|
|
// Mask the state in register 1 to "hide" the ESTABLISHED and
|
|
// RELATED bits (which are expected and fine); if there are any
|
|
// other bits, we want them to remain.
|
|
//
|
|
// This operation is, in the kernel:
|
|
// dst[i] = (src[i] & mask[i]) ^ xor[i]
|
|
//
|
|
// So, we can mask by setting the inverse of the bits we want
|
|
// to remove; i.e. ESTABLISHED = 0b00000010, RELATED =
|
|
// 0b00000100, so, if we assume an 8-bit state (in reality,
|
|
// it's 32-bit), we can mask with 0b11111001 to clear those
|
|
// bits and keep everything else (e.g. the INVALID bit which is
|
|
// 0b00000001).
|
|
//
|
|
// TODO(andrew-d): for now, let's also allow
|
|
// CtStateBitUNTRACKED, which is a state for packets that are not
|
|
// tracked (marked so explicitly with an iptables rule using
|
|
// --notrack); we should figure out if we want to allow this or not.
|
|
&expr.Bitwise{
|
|
SourceRegister: 1,
|
|
DestRegister: 1,
|
|
Len: 4,
|
|
Mask: nativeUint32(^(0 |
|
|
expr.CtStateBitESTABLISHED |
|
|
expr.CtStateBitRELATED |
|
|
expr.CtStateBitUNTRACKED)),
|
|
|
|
// Xor is unused but must be specified
|
|
Xor: nativeUint32(0),
|
|
},
|
|
// Compare against the expected state (0, i.e. no bits set
|
|
// other than maybe ESTABLISHED and RELATED). We want this
|
|
// comparison to fail if there are no bits set, so that this
|
|
// rule's evaluation stops and we don't fall through to the
|
|
// "Drop" verdict.
|
|
//
|
|
// For example, if the state is ESTABLISHED (and we want to
|
|
// break from this rule/accept this packet):
|
|
// state = ESTABLISHED
|
|
// register1 = 0b0 (since the bitwise operation cleared the ESTABLISHED bit)
|
|
//
|
|
// compare register1 (0b0) != 0: false
|
|
// -> comparison implicitly breaks
|
|
// -> continue to the next rule
|
|
//
|
|
// For example, if the state is NEW (and we want to continue to
|
|
// the next expression and thus drop this packet):
|
|
// state = NEW
|
|
// register1 = 0b1000
|
|
//
|
|
// compare register1 (0b1000) != 0: true
|
|
// -> comparison continues to next expr
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpNeq,
|
|
Register: 1,
|
|
Data: []byte{0, 0, 0, 0},
|
|
},
|
|
// If we get here, we know that this packet is going to our TUN
|
|
// device, and has a conntrack state set other than ESTABLISHED
|
|
// or RELATED. We thus count and drop the packet.
|
|
&expr.Counter{},
|
|
&expr.Verdict{Kind: expr.VerdictDrop},
|
|
}
|
|
|
|
// TODO(andrew-d): iptables-nft writes a rule that dumps as:
|
|
//
|
|
// match name conntrack rev 3
|
|
//
|
|
// I think this is using expr.Match against the following struct
|
|
// (xt_conntrack_mtinfo3):
|
|
//
|
|
// https://github.com/torvalds/linux/blob/master/include/uapi/linux/netfilter/xt_conntrack.h#L64-L77
|
|
//
|
|
// We could probably do something similar here, but I'm not sure if
|
|
// there's any advantage. Below is an example Match statement if we
|
|
// decide to do that, based on dumping the rule that iptables-nft
|
|
// generates:
|
|
//
|
|
// _ = expr.Match{
|
|
// Name: "conntrack",
|
|
// Rev: 3,
|
|
// Info: &xt.ConntrackMtinfo3{
|
|
// ConntrackMtinfo2: xt.ConntrackMtinfo2{
|
|
// ConntrackMtinfoBase: xt.ConntrackMtinfoBase{
|
|
// MatchFlags: xt.ConntrackState,
|
|
// InvertFlags: xt.ConntrackState,
|
|
// },
|
|
// // Mask the state to remove ESTABLISHED and
|
|
// // RELATED before comparing.
|
|
// StateMask: expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED,
|
|
// },
|
|
// },
|
|
// }
|
|
}
|
|
|
|
// AddStatefulRule adds a netfilter rule for stateful packet filtering using
|
|
// conntrack.
|
|
func (n *nftablesRunner) AddStatefulRule(tunname string) error {
|
|
conn := n.conn
|
|
|
|
exprs := makeStatefulRuleExprs(tunname)
|
|
for _, table := range n.getTables() {
|
|
chain, err := getChainFromTable(conn, table.Filter, chainNameForward)
|
|
if err != nil {
|
|
return fmt.Errorf("get forward chain: %w", err)
|
|
}
|
|
|
|
// First, find the 'accept' rule that we want to insert our rule before.
|
|
acceptRule := createAcceptOutgoingPacketRule(table.Filter, chain, tunname)
|
|
rule, err := findRule(conn, acceptRule)
|
|
if err != nil {
|
|
return fmt.Errorf("find accept rule: %w", err)
|
|
}
|
|
|
|
conn.InsertRule(&nftables.Rule{
|
|
Table: table.Filter,
|
|
Chain: chain,
|
|
Exprs: exprs,
|
|
|
|
// Specifying Position in an Insert operation means to
|
|
// insert this rule before the specified rule.
|
|
Position: rule.Handle,
|
|
})
|
|
}
|
|
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush add stateful rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DelStatefulRule removes the netfilter rule for stateful packet filtering
|
|
// using conntrack.
|
|
func (n *nftablesRunner) DelStatefulRule(tunname string) error {
|
|
conn := n.conn
|
|
|
|
exprs := makeStatefulRuleExprs(tunname)
|
|
for _, table := range n.getTables() {
|
|
chain, err := getChainFromTable(conn, table.Filter, chainNameForward)
|
|
if err != nil {
|
|
return fmt.Errorf("get forward chain: %w", err)
|
|
}
|
|
rule, err := findRule(conn, &nftables.Rule{
|
|
Table: table.Nat,
|
|
Chain: chain,
|
|
Exprs: exprs,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("find stateful rule: %w", err)
|
|
}
|
|
|
|
if rule != nil {
|
|
conn.DelRule(rule)
|
|
}
|
|
}
|
|
if err := conn.Flush(); err != nil {
|
|
return fmt.Errorf("flush del stateful rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// cleanupChain removes a jump rule from hookChainName to tsChainName, and then
|
|
// the entire chain tsChainName. Errors are logged, but attempts to remove both
|
|
// the jump rule and chain continue even if one errors.
|
|
func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) {
|
|
// remove the jump first, before removing the jump destination.
|
|
defaultChain, err := getChainFromTable(conn, table, hookChainName)
|
|
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
|
|
logf("cleanup: did not find default chain: %s", err)
|
|
}
|
|
if !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) {
|
|
// delete hook in convention chain
|
|
_ = delHookRule(conn, table, defaultChain, tsChainName)
|
|
}
|
|
|
|
tsChain, err := getChainFromTable(conn, table, tsChainName)
|
|
if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) {
|
|
logf("cleanup: did not find ts-chain: %s", err)
|
|
}
|
|
|
|
if tsChain != nil {
|
|
// flush and delete ts-chain
|
|
conn.FlushChain(tsChain)
|
|
conn.DelChain(tsChain)
|
|
err = conn.Flush()
|
|
logf("cleanup: delete and flush chain %s: %s", tsChainName, err)
|
|
}
|
|
}
|
|
|
|
// NfTablesCleanUp removes all Tailscale added nftables rules.
|
|
// Any errors that occur are logged to the provided logf.
|
|
func NfTablesCleanUp(logf logger.Logf) {
|
|
conn, err := nftables.New()
|
|
if err != nil {
|
|
logf("cleanup: nftables connection: %s", err)
|
|
}
|
|
|
|
tables, err := conn.ListTables() // both v4 and v6
|
|
if err != nil {
|
|
logf("cleanup: list tables: %s", err)
|
|
}
|
|
|
|
for _, table := range tables {
|
|
// These table names were used briefly in 1.48.0.
|
|
if table.Name == "ts-filter" || table.Name == "ts-nat" {
|
|
conn.DelTable(table)
|
|
if err := conn.Flush(); err != nil {
|
|
logf("cleanup: flush delete table %s: %s", table.Name, err)
|
|
}
|
|
}
|
|
|
|
if table.Name == "filter" {
|
|
cleanupChain(logf, conn, table, "INPUT", chainNameInput)
|
|
cleanupChain(logf, conn, table, "FORWARD", chainNameForward)
|
|
}
|
|
if table.Name == "nat" {
|
|
cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting)
|
|
}
|
|
}
|
|
}
|