tailscale/util/linuxfw/iptables.go

826 lines
18 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux && !(386 || loong64 || arm || armbe)
package linuxfw
import (
"encoding/hex"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"unsafe"
"github.com/josharian/native"
"golang.org/x/sys/unix"
linuxabi "gvisor.dev/gvisor/pkg/abi/linux"
"tailscale.com/net/netaddr"
"tailscale.com/types/logger"
)
// sockLen is the integer type of the option-length value that this file
// passes by pointer to the getsockopt system call.
type sockLen uint32
var (
	// iptablesChainNames maps a netfilter hook number to the name of the
	// built-in iptables chain for that hook.
	iptablesChainNames = map[int]string{
		linuxabi.NF_INET_PRE_ROUTING:  "PREROUTING",
		linuxabi.NF_INET_LOCAL_IN:     "INPUT",
		linuxabi.NF_INET_FORWARD:      "FORWARD",
		linuxabi.NF_INET_LOCAL_OUT:    "OUTPUT",
		linuxabi.NF_INET_POST_ROUTING: "POSTROUTING",
	}

	// iptablesStandardChains is the set of chain names that appear as
	// values in iptablesChainNames.
	iptablesStandardChains = func() map[string]bool {
		names := make(map[string]bool, len(iptablesChainNames))
		for _, name := range iptablesChainNames {
			names[name] = true
		}
		return names
	}()
)
// DebugIptables prints debug information about iptables rules to the
// provided log function.
//
// The "filter", "nat" and "raw" tables are dumped in that order. For each
// table every entry is collected first, recording the offset at which each
// chain starts, so that a standard verdict of the form "JUMP(<offset>)" can
// be rewritten to "JUMP(<chain name>)" when the offset is the start of a
// known chain. The first error returned while enumerating a table is
// returned and stops the dump.
func DebugIptables(logf logger.Logf) error {
	type chainEntry struct {
		chain string
		ent   *entry
	}

	for _, table := range []string{"filter", "nat", "raw"} {
		var (
			prevChain   string
			collected   []chainEntry
			chainStarts = make(map[int]string) // offset of a chain's first entry -> chain name
		)

		// Collect all entries first so we can resolve jumps
		collect := func(chain string, e *entry) error {
			if chain != prevChain {
				chainStarts[e.Offset] = chain
				prevChain = chain
			}
			collected = append(collected, chainEntry{chain: chain, ent: e})
			return nil
		}
		if err := enumerateIptablesTable(logf, table, collect); err != nil {
			return err
		}

		prevChain = ""
		for _, ce := range collected {
			if ce.chain != prevChain {
				logf("iptables: table=%s chain=%s", table, ce.chain)
				prevChain = ce.chain
			}

			// Fixup jump: replace a numeric jump offset with the chain name
			// when we know which chain starts there.
			std, isStd := ce.ent.Target.Data.(standardTarget)
			if isStd && strings.HasPrefix(std.Verdict, "JUMP(") {
				var off int
				if _, err := fmt.Sscanf(std.Verdict, "JUMP(%d)", &off); err == nil {
					if name, known := chainStarts[off]; known {
						std.Verdict = "JUMP(" + name + ")"
						ce.ent.Target.Data = std
					}
				}
			}

			logf("iptables: entry=%+v", ce.ent)
		}
	}
	return nil
}
// DetectIptables returns the number of iptables rules that are present in the
// system, ignoring the default "ACCEPT" rule present in the standard iptables
// chains.
//
// It only returns an error when the kernel returns an error (i.e. when a
// syscall fails); when there are no iptables rules, it is valid for this
// function to return 0, nil.
func DetectIptables() (int, error) {
dummyLog := func(string, ...any) {}
var (
validRules int
firstErr error
)
for _, table := range []string{"filter", "nat", "raw"} {
err := enumerateIptablesTable(dummyLog, table, func(chain string, entry *entry) error {
// If we have any rules other than basic 'ACCEPT' entries in a
// standard chain, then we consider this a valid rule.
switch {
case !iptablesStandardChains[chain]:
validRules++
case entry.Target.Name != "standard":
validRules++
case entry.Target.Name == "standard" && entry.Target.Data.(standardTarget).Verdict != "ACCEPT":
validRules++
}
return nil
})
if err != nil && firstErr == nil {
firstErr = err
}
}
return validRules, firstErr
}
// enumerateIptablesTable fetches the entries of the named iptables table from
// the kernel and calls cb once per rule, passing the name of the chain the
// rule belongs to.
//
// The table is read with two getsockopt(SOL_IP) calls issued on a temporary
// TCP listener: IPT_SO_GET_INFO for the table size and hook entry offsets,
// then IPT_SO_GET_ENTRIES for the entries themselves. "ERROR" entries are
// never passed to cb: one that ends exactly at the end of the data stops the
// walk, and any other one supplies the chain name for the entries after it.
// An entry located at one of the table's hook entry offsets starts the
// corresponding built-in chain.
//
// It returns an error if the listener cannot be set up, if either getsockopt
// call fails, or if cb returns an error. A failure to parse an entry is only
// logged through logf and ends the walk without an error.
func enumerateIptablesTable(logf logger.Logf, table string, cb func(string, *entry) error) error {
	ln, err := net.Listen("tcp4", ":0")
	if err != nil {
		return err
	}
	defer ln.Close()

	rawConn, err := ln.(*net.TCPListener).SyscallConn()
	if err != nil {
		return err
	}

	// getsockopt runs a SOL_IP getsockopt for option opt on the listener's
	// file descriptor. A failure of Control itself takes precedence over the
	// errno reported by the system call.
	getsockopt := func(opt uintptr, val unsafe.Pointer, optLen *sockLen) error {
		var sysErr error
		ctlErr := rawConn.Control(func(fd uintptr) {
			_, _, errno := unix.Syscall6(
				unix.SYS_GETSOCKOPT,
				fd,
				uintptr(unix.SOL_IP),
				opt,
				uintptr(val),
				uintptr(unsafe.Pointer(optLen)),
				0,
			)
			if errno != 0 {
				sysErr = errno
			}
		})
		if ctlErr != nil {
			return ctlErr
		}
		return sysErr
	}

	var name linuxabi.TableName
	copy(name[:], table)

	info := linuxabi.IPTGetinfo{Name: name}
	optLen := sockLen(linuxabi.SizeOfIPTGetinfo)
	if err := getsockopt(linuxabi.IPT_SO_GET_INFO, unsafe.Pointer(&info), &optLen); err != nil {
		return err
	}
	if info.Size < 1 {
		return nil
	}

	// Allocate enough space to be able to get all iptables information:
	// the request header followed by info.Size bytes of entries.
	raw := make([]byte, linuxabi.SizeOfIPTGetEntries+info.Size)
	hdr := (*linuxabi.IPTGetEntries)(unsafe.Pointer(&raw[0]))
	hdr.Name = name
	hdr.Size = info.Size
	optLen = sockLen(len(raw))
	if err := getsockopt(linuxabi.IPT_SO_GET_ENTRIES, unsafe.Pointer(&raw[0]), &optLen); err != nil {
		return err
	}

	// Skip header
	remaining := raw[linuxabi.SizeOfIPTGetEntries:]

	var (
		base  int    // offset of 'remaining' from the start of the entries
		chain string // chain that the current entry belongs to
	)
	for len(remaining) > 0 {
		// Each entry gets a fresh parser, so parser offsets are relative to
		// the start of that entry.
		parser := entryParser{
			buf:             remaining,
			logf:            logf,
			checkExtraBytes: true,
		}
		ent, err := parser.parseEntry(remaining)
		if err != nil {
			logf("iptables: err=%v", err)
			break
		}
		ent.Offset += base

		if ent.Target.Name == "ERROR" {
			// Don't pass 'ERROR' nodes to our caller
			if parser.offset == len(remaining) {
				// all done
				break
			}
			// New user-defined chain
			chain = ent.Target.Data.(errorTarget).ErrorName
		} else {
			// Detect if we're at a new chain based on the hook
			// offsets we fetched earlier.
			for hook, start := range info.HookEntry {
				if int(start) == base {
					chain = iptablesChainNames[hook]
				}
			}

			// Now that we have everything, call our callback.
			if err := cb(chain, &ent); err != nil {
				return err
			}
		}

		remaining = remaining[parser.offset:]
		base += parser.offset
	}
	return nil
}
// entryParser is a cursor over a byte buffer; its methods read values at the
// current offset and advance past them.
//
// TODO(andrew): convert to use cstruct
type entryParser struct {
	// buf is the data being parsed.
	buf []byte
	// offset is the index into buf of the next unread byte.
	offset int
	// logf receives the debug messages controlled by checkExtraBytes.
	logf logger.Logf
	// Set to 'true' to print debug messages about unused bytes returned
	// from the kernel
	checkExtraBytes bool
}
// haveLen reports whether at least ln unread bytes remain in the buffer.
func (p *entryParser) haveLen(ln int) bool {
	remaining := len(p.buf) - p.offset
	return remaining >= ln
}
// assertLen returns nil when at least ln unread bytes remain, and otherwise
// an error wrapping errBufferTooSmall that states how many bytes were needed.
func (p *entryParser) assertLen(ln int) error {
	if p.haveLen(ln) {
		return nil
	}
	return fmt.Errorf("need %d bytes: %w", ln, errBufferTooSmall)
}
// getBytes returns the next amt bytes and advances past them. The returned
// slice shares memory with the parser's buffer; no bounds check is done
// beyond the slice expression itself.
func (p *entryParser) getBytes(amt int) []byte {
	start := p.offset
	end := start + amt
	chunk := p.buf[start:end]
	p.offset = end
	return chunk
}
// getByte returns the byte at the current offset and advances by one.
func (p *entryParser) getByte() byte {
	b := p.buf[p.offset]
	p.offset++
	return b
}
// get4 returns the next four bytes as an array and advances past them.
func (p *entryParser) get4() (ret [4]byte) {
	for i := range ret {
		ret[i] = p.buf[p.offset+i]
	}
	p.offset += len(ret)
	return ret
}
// setOffset moves the parser forward to the absolute buffer offset off.
//
// It returns an error, leaving the offset unchanged, when off is behind the
// current offset, when max is non-negative and off is not strictly below
// max, or when off lies past the end of the buffer. Pass a negative max to
// disable the upper-limit check. When checkExtraBytes is set, the bytes that
// are skipped over are reported through logf.
func (p *entryParser) setOffset(off, max int) error {
	// We can't go back
	if off < p.offset {
		return fmt.Errorf("invalid target offset (%d < %d): %w", off, p.offset, errMalformed)
	}

	// Ensure we don't go beyond our maximum, if given
	if max >= 0 && off >= max {
		return fmt.Errorf("invalid target offset (%d >= %d): %w", off, max, errMalformed)
	}

	// Never move past the end of the buffer: doing so would make the slice
	// expression below, or a later re-slice by the caller using the new
	// offset, panic. Landing exactly on the end is fine.
	if off > len(p.buf) {
		return fmt.Errorf("offset %d is past the end of the buffer (%d bytes): %w", off, len(p.buf), errBufferTooSmall)
	}

	// If we aren't already at this offset, move forward
	if p.offset < off {
		if p.checkExtraBytes {
			extraData := p.buf[p.offset:off]
			diff := off - p.offset
			p.logf("%d bytes (%d, %d) are unused: %s", diff, p.offset, off, hex.EncodeToString(extraData))
		}
		p.offset = off
	}
	return nil
}
var (
	// errBufferTooSmall is wrapped by errors reporting that fewer bytes
	// remain in the buffer than a parsing step needs.
	errBufferTooSmall = errors.New("buffer too small")
	// errMalformed is wrapped by errors reporting inconsistent data, such
	// as an offset or size that does not fit its surroundings.
	errMalformed = errors.New("data malformed")
)
// entry is a single parsed iptables rule.
type entry struct {
	// Offset is the position at which the entry starts. parseEntry sets it
	// relative to the parser's buffer; enumerateIptablesTable then adds the
	// number of bytes consumed by earlier entries.
	Offset      int
	IP          iptip    // address, interface and protocol selectors
	NFCache     uint32   // first 32-bit field following the selectors
	PacketCount uint64   // packet counter read from the entry
	ByteCount   uint64   // byte counter read from the entry
	Matches     []match  // matches found between the fixed header and the target
	Target      target   // target of the rule
}

// String formats the entry for debug output. NFCache is not printed, and
// Matches is printed only when there is at least one match.
func (e entry) String() string {
	var out strings.Builder
	out.WriteString("{")
	fmt.Fprintf(&out, "Offset:%d IP:%v PacketCount:%d ByteCount:%d", e.Offset, e.IP, e.PacketCount, e.ByteCount)
	if len(e.Matches) != 0 {
		fmt.Fprintf(&out, " Matches:%v", e.Matches)
	}
	fmt.Fprintf(&out, " Target:%v", e.Target)
	out.WriteString("}")
	return out.String()
}
// parseEntry decodes one entry starting at the parser's current offset: the
// IP selectors, a fixed 28-byte block (cache, target offset, next offset, an
// unused field and the two counters), the matches, and the target. On
// success the parser is positioned at the entry's next-offset.
//
// The b parameter is not used; all data is read from the parser's own
// buffer. The target and next offsets read from the entry are interpreted as
// offsets into that buffer. Any failure yields a zero entry and an error.
func (p *entryParser) parseEntry(b []byte) (entry, error) {
	var (
		parsed entry
		err    error
	)

	parsed.Offset = p.offset
	if parsed.IP, err = p.parseIPTIP(); err != nil {
		return entry{}, fmt.Errorf("parsing IPTIP: %w", err)
	}

	// Must have space for the rest of the members
	const fixedLen = 4 + 2 + 2 + 4 + 8 + 8 // 28
	if err := p.assertLen(fixedLen); err != nil {
		return entry{}, err
	}
	parsed.NFCache = native.Endian.Uint32(p.getBytes(4))
	targetOff := int(native.Endian.Uint16(p.getBytes(2)))
	nextOff := int(native.Endian.Uint16(p.getBytes(2)))
	p.getBytes(4) // unused field: Comeback
	parsed.PacketCount = native.Endian.Uint64(p.getBytes(8))
	parsed.ByteCount = native.Endian.Uint64(p.getBytes(8))

	// Must have at least enough space in our buffer to get to the target;
	// doing this here means we can avoid bounds checks in parseMatches
	if err := p.assertLen(targetOff - p.offset); err != nil {
		return entry{}, err
	}

	// Matches are stored between the end of the entry structure and the
	// start of the 'targets' structure.
	if parsed.Matches, err = p.parseMatches(targetOff); err != nil {
		return entry{}, err
	}

	if targetOff > 0 {
		if err := p.setOffset(targetOff, nextOff); err != nil {
			return entry{}, err
		}
		if parsed.Target, err = p.parseTarget(nextOff); err != nil {
			return entry{}, fmt.Errorf("parsing target: %w", err)
		}
	}

	// Skip whatever is left of this entry.
	if err := p.setOffset(nextOff, -1); err != nil {
		return entry{}, err
	}
	return parsed, nil
}
// iptip holds the selector fields decoded by parseIPTIP from the first 84
// bytes of an entry.
type iptip struct {
	// Src and Dst are the IPv4 source and destination addresses.
	Src netip.Addr
	Dst netip.Addr
	// SrcMask and DstMask are the masks for Src and Dst, stored as IPv4
	// addresses.
	SrcMask netip.Addr
	DstMask netip.Addr
	// InputInterface and OutputInterface are the interface names, decoded
	// from 16-byte fields.
	InputInterface  string
	OutputInterface string
	// InputInterfaceMask and OutputInterfaceMask are the raw 16-byte masks
	// for the interface names; they share memory with the parsed buffer.
	InputInterfaceMask  []byte
	OutputInterfaceMask []byte
	// Protocol is the protocol number; see protocolNames for the values
	// that are printed by name.
	Protocol uint16
	// Flags and InverseFlags are copied from the entry without being
	// interpreted.
	Flags        uint8
	InverseFlags uint8
}
// protocolNames maps IP protocol numbers to the short names that
// iptip.String prints; protocols not listed here are printed as numbers.
var protocolNames = map[uint16]string{
	unix.IPPROTO_ESP:    "esp",
	unix.IPPROTO_GRE:    "gre",
	unix.IPPROTO_ICMP:   "icmp",
	unix.IPPROTO_ICMPV6: "icmpv6",
	unix.IPPROTO_IGMP:   "igmp",
	unix.IPPROTO_IP:     "ip",
	unix.IPPROTO_IPIP:   "ipip",
	unix.IPPROTO_IPV6:   "ip6",
	unix.IPPROTO_RAW:    "raw",
	unix.IPPROTO_TCP:    "tcp",
	unix.IPPROTO_UDP:    "udp",
}
// String formats the selectors for debug output. Address/mask pairs are
// printed as a prefix when netaddr.FromStdIPNet accepts them and as
// "addr/mask" otherwise. Interface names are printed only when non-empty,
// followed by their mask rendered with 'X' for each non-zero byte and '.'
// for each zero byte. Flags and InverseFlags are printed only when non-zero.
func (ip iptip) String() string {
	prefixOrRaw := func(addr, mask netip.Addr) string {
		ipNet := net.IPNet{
			IP:   addr.AsSlice(),
			Mask: mask.AsSlice(),
		}
		if pfx, ok := netaddr.FromStdIPNet(&ipNet); ok {
			return fmt.Sprint(pfx)
		}
		return fmt.Sprintf("%s/%s", addr, mask)
	}

	maskPattern := func(mask []byte) string {
		pattern := make([]byte, len(mask))
		for i, b := range mask {
			if b != 0 {
				pattern[i] = 'X'
			} else {
				pattern[i] = '.'
			}
		}
		return string(pattern)
	}

	var out strings.Builder
	out.WriteString("{")
	fmt.Fprintf(&out, "Src:%s", prefixOrRaw(ip.Src, ip.SrcMask))
	fmt.Fprintf(&out, ", Dst:%s", prefixOrRaw(ip.Dst, ip.DstMask))

	if ip.InputInterface != "" {
		fmt.Fprintf(&out, ", InputInterface:%s/%s", ip.InputInterface, maskPattern(ip.InputInterfaceMask))
	}
	if ip.OutputInterface != "" {
		fmt.Fprintf(&out, ", OutputInterface:%s/%s", ip.OutputInterface, maskPattern(ip.OutputInterfaceMask))
	}

	if name, known := protocolNames[ip.Protocol]; known {
		fmt.Fprintf(&out, ", Protocol:%s", name)
	} else {
		fmt.Fprintf(&out, ", Protocol:%d", ip.Protocol)
	}

	if ip.Flags != 0 {
		fmt.Fprintf(&out, ", Flags:%d", ip.Flags)
	}
	if ip.InverseFlags != 0 {
		fmt.Fprintf(&out, ", InverseFlags:%d", ip.InverseFlags)
	}
	out.WriteString("}")
	return out.String()
}
// parseIPTIP decodes the 84-byte selector block at the current offset: four
// IPv4 values (source, destination, source mask, destination mask), two
// 16-byte interface names, two 16-byte interface masks, a 16-bit protocol
// and two flag bytes, in that order. It fails with a zero value when fewer
// than 84 bytes remain. The mask slices share memory with the buffer.
func (p *entryParser) parseIPTIP() (iptip, error) {
	const (
		ifNameLen = 16                            // IFNAMSIZ
		totalLen  = 4*4 + 4*ifNameLen + 2 + 1 + 1 // 84
	)
	if err := p.assertLen(totalLen); err != nil {
		return iptip{}, err
	}

	nextAddr := func() netip.Addr {
		return netip.AddrFrom4(p.get4())
	}
	nextIfName := func() string {
		return unix.ByteSliceToString(p.getBytes(ifNameLen))
	}

	// The fields are read strictly in wire order.
	var ip iptip
	ip.Src = nextAddr()
	ip.Dst = nextAddr()
	ip.SrcMask = nextAddr()
	ip.DstMask = nextAddr()
	ip.InputInterface = nextIfName()
	ip.OutputInterface = nextIfName()
	ip.InputInterfaceMask = p.getBytes(ifNameLen)
	ip.OutputInterfaceMask = p.getBytes(ifNameLen)
	ip.Protocol = native.Endian.Uint16(p.getBytes(2))
	ip.Flags = p.getByte()
	ip.InverseFlags = p.getByte()
	return ip, nil
}
// match is one match extension attached to an entry.
type match struct {
	Name     string // extension name, e.g. "tcp"
	Revision int    // revision byte from the match header
	Data     any    // decoded form of RawData when the extension is understood, else nil
	RawData  []byte // extension payload following the header
}

// String formats the match for debug output; only Name and Data are shown.
func (m match) String() string {
	name, data := m.Name, m.Data
	return fmt.Sprintf("{Name:%s, Data:%v}", name, data)
}
// matchTCP is the decoded payload of a "tcp" match.
type matchTCP struct {
	SourcePortRange [2]uint16 // source port range
	DestPortRange   [2]uint16 // destination port range
	Option          byte      // TCP option, if non-zero
	FlagMask        byte      // TCP flags mask byte
	FlagCompare     byte      // TCP flags compare byte
	InverseFlags    byte      // inverse flags
}

// String formats the match for debug output. Both port ranges are always
// printed; each of the remaining fields is printed only when non-zero.
func (m matchTCP) String() string {
	src := formatPortRange(m.SourcePortRange)
	dst := formatPortRange(m.DestPortRange)

	var out strings.Builder
	out.WriteString("{")
	fmt.Fprintf(&out, "SrcPort:%s, DstPort:%s", src, dst)

	// TODO(andrew): format semantically
	optional := []struct {
		format string
		value  byte
	}{
		{", Option:%d", m.Option},
		{", FlagMask:%d", m.FlagMask},
		{", FlagCompare:%d", m.FlagCompare},
		{", InverseFlags:%d", m.InverseFlags},
	}
	for _, field := range optional {
		if field.value != 0 {
			fmt.Fprintf(&out, field.format, field.value)
		}
	}

	out.WriteString("}")
	return out.String()
}
// parseMatches decodes consecutive match structures starting at the current
// offset and stops as soon as a complete match header no longer fits before
// maxOffset. Each match is a 32-byte header (16-bit size, 29-byte name,
// revision byte) followed by size-32 bytes of data.
//
// The caller must make sure that maxOffset does not exceed the buffer
// length; the header reads are not bounds-checked against the buffer here.
// An error wrapping errMalformed is returned when a match declares a size
// smaller than its own header or when its data would run past maxOffset or
// the end of the buffer. RawData of the returned matches shares memory with
// the buffer. Only "tcp" matches with at least 12 bytes of data get a
// decoded Data value.
func (p *entryParser) parseMatches(maxOffset int) ([]match, error) {
	const XT_EXTENSION_MAXNAMELEN = 29
	const structSize = 2 + XT_EXTENSION_MAXNAMELEN + 1

	var ret []match
	// If we don't have space for a single match structure, we're done
	for p.offset+structSize <= maxOffset {
		var curr match
		matchSize := int(native.Endian.Uint16(p.getBytes(2)))
		curr.Name = unix.ByteSliceToString(p.getBytes(XT_EXTENSION_MAXNAMELEN))
		curr.Revision = int(p.getByte())

		// The declared size includes the header we just consumed, so it can
		// never be smaller than the header. Without this check a negative
		// data length would reach getBytes and panic.
		if matchSize < structSize {
			return nil, fmt.Errorf("match size %d is smaller than header size %d: %w", matchSize, structSize, errMalformed)
		}

		// The data size is the total match size minus what we've already consumed.
		dataLen := matchSize - structSize
		dataEnd := p.offset + dataLen

		// If we don't have space for the match data, then there's something wrong
		if dataEnd > maxOffset {
			return nil, fmt.Errorf("out of space for match (%d > max %d): %w", dataEnd, maxOffset, errMalformed)
		} else if dataEnd > len(p.buf) {
			return nil, fmt.Errorf("out of space for match (%d > buf %d): %w", dataEnd, len(p.buf), errMalformed)
		}
		curr.RawData = p.getBytes(dataLen)

		// TODO(andrew): more here; UDP, etc.
		switch curr.Name {
		case "tcp":
			/*
			   struct xt_tcp {
			       __u16 spts[2];   // Source port range.
			       __u16 dpts[2];   // Destination port range.
			       __u8 option;     // TCP Option iff non-zero
			       __u8 flg_mask;   // TCP flags mask byte
			       __u8 flg_cmp;    // TCP flags compare byte
			       __u8 invflags;   // Inverse flags
			   };
			*/
			if len(curr.RawData) >= 12 {
				curr.Data = matchTCP{
					SourcePortRange: [...]uint16{
						native.Endian.Uint16(curr.RawData[0:2]),
						native.Endian.Uint16(curr.RawData[2:4]),
					},
					DestPortRange: [...]uint16{
						native.Endian.Uint16(curr.RawData[4:6]),
						native.Endian.Uint16(curr.RawData[6:8]),
					},
					Option:       curr.RawData[8],
					FlagMask:     curr.RawData[9],
					FlagCompare:  curr.RawData[10],
					InverseFlags: curr.RawData[11],
				}
			}
		}

		ret = append(ret, curr)
	}
	return ret, nil
}
// target is the target of an entry.
type target struct {
	Name     string // target name; parseTarget stores "standard" for an empty name
	Revision int    // revision byte from the target header
	Data     any    // decoded form of RawData when the target is understood, else nil
	RawData  []byte // target payload following the header
}

// String formats the target for debug output; only Name and Data are shown.
func (t target) String() string {
	name, data := t.Name, t.Data
	return fmt.Sprintf("{Name:%s, Data:%v}", name, data)
}
// parseTarget decodes the target structure at the current offset: a 32-byte
// header (16-bit size, 29-byte name, revision byte) followed by size-32
// bytes of payload, which is kept in RawData (sharing memory with the
// buffer). A size that does not exceed the header size yields no payload.
// It fails with a zero target when the header or the payload does not fit
// in the remaining buffer.
//
// An empty name is replaced by "standard". Data is filled in for "standard"
// (payload of at least 4 bytes), "ERROR", "REJECT" (at least 4 bytes) and
// "MARK" (at least 8 bytes) targets and left nil otherwise.
//
// nextOffset is only used to recognise a standard verdict equal to it,
// which is reported as "FALLTHROUGH".
// NOTE(review): parseEntry passes the entry's own next-offset field here,
// which it treats as relative to the start of the entry. If verdict offsets
// are relative to the start of the whole table, this comparison can only
// match for an entry at table offset 0 — confirm against the kernel ABI.
func (p *entryParser) parseTarget(nextOffset int) (target, error) {
	const (
		maxNameLen = 29                 // XT_EXTENSION_MAXNAMELEN
		headerLen  = 2 + maxNameLen + 1 // size + name + revision
	)
	if err := p.assertLen(headerLen); err != nil {
		return target{}, err
	}

	size := int(native.Endian.Uint16(p.getBytes(2)))
	name := unix.ByteSliceToString(p.getBytes(maxNameLen))
	tgt := target{
		Name:     name,
		Revision: int(p.getByte()),
	}

	if payloadLen := size - headerLen; payloadLen > 0 {
		if err := p.assertLen(payloadLen); err != nil {
			return target{}, err
		}
		tgt.RawData = p.getBytes(payloadLen)
	}

	// Special case; matches what iptables does
	if tgt.Name == "" {
		tgt.Name = "standard"
	}

	payload := tgt.RawData
	switch tgt.Name {
	case "standard":
		if len(payload) < 4 {
			break
		}
		var desc string
		switch verdict := int32(native.Endian.Uint32(payload)); verdict {
		case -1:
			desc = "DROP"
		case -2:
			desc = "ACCEPT"
		case -4:
			desc = "QUEUE"
		case -5:
			desc = "RETURN"
		case int32(nextOffset):
			desc = "FALLTHROUGH"
		default:
			desc = fmt.Sprintf("JUMP(%d)", verdict)
		}
		tgt.Data = standardTarget{Verdict: desc}

	case "ERROR":
		tgt.Data = errorTarget{
			ErrorName: unix.ByteSliceToString(payload),
		}

	case "REJECT":
		if len(payload) < 4 {
			break
		}
		tgt.Data = rejectTarget{
			With: rejectWith(native.Endian.Uint32(payload)),
		}

	case "MARK":
		if len(payload) < 8 {
			break
		}
		mark := native.Endian.Uint32(payload[0:4])
		mask := native.Endian.Uint32(payload[4:8])

		// mode keeps its zero value when none of the cases applies.
		var mode markMode
		switch {
		case mark == 0:
			mode = markModeAnd
			mark = ^mask
		case mark == mask:
			mode = markModeOr
		case mask == 0:
			mode = markModeXor
		case mask == 0xffffffff:
			mode = markModeSet
		default:
			// TODO(andrew): handle xset?
		}
		tgt.Data = markTarget{
			Mark: mark,
			Mode: mode,
		}
	}
	return tgt, nil
}
// Various types for things in iptables-land follow.

// standardTarget is the decoded data of a "standard" target.
type standardTarget struct {
	// Verdict is the textual verdict produced by parseTarget: "DROP",
	// "ACCEPT", "QUEUE", "RETURN", "FALLTHROUGH" or "JUMP(...)".
	Verdict string
}

// errorTarget is the decoded data of an "ERROR" target.
type errorTarget struct {
	// ErrorName is the string stored in the target's payload;
	// enumerateIptablesTable uses it as the name of a user-defined chain.
	ErrorName string
}
// rejectWith is the numeric code stored in a "REJECT" target's payload.
type rejectWith int

// The known rejectWith values, numbered from zero in declaration order.
const (
	rwIPT_ICMP_NET_UNREACHABLE rejectWith = iota
	rwIPT_ICMP_HOST_UNREACHABLE
	rwIPT_ICMP_PROT_UNREACHABLE
	rwIPT_ICMP_PORT_UNREACHABLE
	rwIPT_ICMP_ECHOREPLY
	rwIPT_ICMP_NET_PROHIBITED
	rwIPT_ICMP_HOST_PROHIBITED
	rwIPT_TCP_RESET
	rwIPT_ICMP_ADMIN_PROHIBITED
)

// String returns the textual name of the code, or "UNKNOWN" for a value
// that is not one of the constants above.
func (rw rejectWith) String() string {
	names := [...]string{
		rwIPT_ICMP_NET_UNREACHABLE:  "icmp-net-unreachable",
		rwIPT_ICMP_HOST_UNREACHABLE: "icmp-host-unreachable",
		rwIPT_ICMP_PROT_UNREACHABLE: "icmp-prot-unreachable",
		rwIPT_ICMP_PORT_UNREACHABLE: "icmp-port-unreachable",
		rwIPT_ICMP_ECHOREPLY:        "icmp-echo-reply",
		rwIPT_ICMP_NET_PROHIBITED:   "icmp-net-prohibited",
		rwIPT_ICMP_HOST_PROHIBITED:  "icmp-host-prohibited",
		rwIPT_TCP_RESET:             "tcp-reset",
		rwIPT_ICMP_ADMIN_PROHIBITED: "icmp-admin-prohibited",
	}
	if rw < 0 || int(rw) >= len(names) {
		return "UNKNOWN"
	}
	return names[rw]
}
// rejectTarget is the decoded data of a "REJECT" target.
type rejectTarget struct {
	// With is the reject code read from the first four payload bytes.
	With rejectWith
}
// markMode is the operation that parseTarget derives for a "MARK" target
// from its mark and mask values.
type markMode byte

// The known markMode values; markModeSet is the zero value.
const (
	markModeSet markMode = iota
	markModeAnd
	markModeOr
	markModeXor
)

// String returns the lower-case name of the mode, or "UNKNOWN" for a value
// that is not one of the constants above.
func (mm markMode) String() string {
	names := [...]string{
		markModeSet: "set",
		markModeAnd: "and",
		markModeOr:  "or",
		markModeXor: "xor",
	}
	if int(mm) >= len(names) {
		return "UNKNOWN"
	}
	return names[mm]
}
// markTarget is the decoded data of a "MARK" target.
type markTarget struct {
	// Mode is the operation derived by parseTarget from the mark and mask.
	Mode markMode
	// Mark is the mark value from the payload; in "and" mode parseTarget
	// stores the bitwise complement of the mask instead.
	Mark uint32
}