diff --git a/tstest/natlab/firewall.go b/tstest/natlab/firewall.go index 584d466df..4cb83849b 100644 --- a/tstest/natlab/firewall.go +++ b/tstest/natlab/firewall.go @@ -5,35 +5,93 @@ package natlab import ( + "fmt" "sync" "time" "inet.af/netaddr" ) -type session struct { +// FirewallType is the type of filtering a stateful firewall +// does. Values express different modes defined by RFC 4787. +type FirewallType int + +const ( + // AddressAndPortDependentFirewall specifies a destination + // address-and-port dependent firewall. Outbound traffic to an + // ip:port authorizes traffic from that ip:port exactly, and + // nothing else. + AddressAndPortDependentFirewall FirewallType = iota + // AddressDependentFirewall specifies a destination address + // dependent firewall. Once outbound traffic has been seen to an + // IP address, that IP address can talk back from any port. + AddressDependentFirewall + // EndpointIndependentFirewall specifies a destination endpoint + // independent firewall. Once outbound traffic has been seen from + // a source, anyone can talk back to that source. + EndpointIndependentFirewall +) + +// fwKey is the lookup key for a firewall session. While it contains a +// 4-tuple ({src,dst} {ip,port}), some FirewallTypes will zero out +// some fields, so in practice the key is either a 2-tuple (src only), +// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). +type fwKey struct { src netaddr.IPPort dst netaddr.IPPort } +// key returns an fwKey for the given src and dst, trimmed according +// to the FirewallType. fwKeys are always constructed from the +// "outbound" point of view (i.e. src is the "trusted" side of the +// world), it's the caller's responsibility to swap src and dst in the +// call to key when processing packets inbound from the "untrusted" +// world. +func (s FirewallType) key(src, dst netaddr.IPPort) fwKey { + k := fwKey{src: src} + switch s { + case EndpointIndependentFirewall: + case AddressDependentFirewall: + k.dst.IP = dst.IP + case AddressAndPortDependentFirewall: + k.dst = dst + default: + panic(fmt.Sprintf("unknown firewall selectivity %v", s)) + } + return k +} + +// DefaultSessionTimeout is the default timeout for a firewall +// session. +const DefaultSessionTimeout = 30 * time.Second + +// Firewall is a simple stateful firewall that allows all outbound +// traffic and filters inbound traffic based on recently seen outbound +// traffic. Its HandlePacket method should be attached to a Machine to +// give it a stateful firewall. type Firewall struct { - // TrustedInterface is the interface that's allowed to send - // anywhere. All other interfaces can only respond to traffic from - // TrustedInterface. - TrustedInterface *Interface // SessionTimeout is the lifetime of idle sessions in the firewall // state. Packets transiting from the TrustedInterface reset the - // session lifetime to SessionTimeout. + // session lifetime to SessionTimeout. If zero, + // DefaultSessionTimeout is used. SessionTimeout time.Duration + // Type specifies how precisely return traffic must match + // previously seen outbound traffic to be allowed. Defaults to + // AddressAndPortDependentFirewall. + Type FirewallType + // TrustedInterface is an optional interface that is considered + // trusted in addition to PacketConns local to the Machine. All + // other interfaces can only respond to traffic from + // TrustedInterface or the local host. + TrustedInterface *Interface // TimeNow is a function returning the current time. If nil, // time.Now is used. TimeNow func() time.Time - // TODO: tuple-ness pickiness: EIF, ADF, APDF // TODO: refresh directionality: outbound-only, both mu sync.Mutex - seen map[session]time.Time // session -> deadline + seen map[fwKey]time.Time // session -> deadline } func (f *Firewall) timeNow() time.Time { @@ -43,33 +101,28 @@ func (f *Firewall) timeNow() time.Time { return time.Now() } +// HandlePacket implements the PacketHandler type. func (f *Firewall) HandlePacket(p *Packet, inIf *Interface) PacketVerdict { f.mu.Lock() defer f.mu.Unlock() if f.seen == nil { - f.seen = map[session]time.Time{} + f.seen = map[fwKey]time.Time{} } if f.SessionTimeout == 0 { f.SessionTimeout = 30 * time.Second } if inIf == f.TrustedInterface || inIf == nil { - sess := session{ - src: p.Src, - dst: p.Dst, - } - f.seen[sess] = f.timeNow().Add(f.SessionTimeout) + k := f.Type.key(p.Src, p.Dst) + f.seen[k] = f.timeNow().Add(f.SessionTimeout) p.Trace("firewall out ok") return Continue } else { // reverse src and dst because the session table is from the // POV of outbound packets. - sess := session{ - src: p.Dst, - dst: p.Src, - } + k := f.Type.key(p.Dst, p.Src) now := f.timeNow() - if now.After(f.seen[sess]) { + if now.After(f.seen[k]) { p.Trace("firewall drop") return Drop } diff --git a/tstest/natlab/natlab_test.go b/tstest/natlab/natlab_test.go index 10e335faa..bfe2b5d98 100644 --- a/tstest/natlab/natlab_test.go +++ b/tstest/natlab/natlab_test.go @@ -224,8 +224,6 @@ func TestPacketHandler(t *testing.T) { } func TestFirewall(t *testing.T) { - clock := &tstest.Clock{} - wan := NewInternet() lan := &Network{ Name: "lan", @@ -235,28 +233,84 @@ func TestFirewall(t *testing.T) { trust := m.Attach("trust", lan) untrust := m.Attach("untrust", wan) - f := &Firewall{ - TrustedInterface: trust, - SessionTimeout: 30 * time.Second, - TimeNow: clock.Now, - } - client := ipp("192.168.0.2:1234") serverA := ipp("2.2.2.2:5678") - serverB := ipp("7.7.7.7:9012") - tests := []struct { - iface *Interface - src, dst netaddr.IPPort - want PacketVerdict - }{ - {trust, client, serverA, Continue}, - {untrust, serverA, client, Continue}, - {untrust, serverA, client, Continue}, - {untrust, serverB, client, Drop}, - {trust, client, serverB, Continue}, - {untrust, serverB, client, Continue}, - } + serverB1 := ipp("7.7.7.7:9012") + serverB2 := ipp("7.7.7.7:3456") + t.Run("ip_port_dependent", func(t *testing.T) { + f := &Firewall{ + TrustedInterface: trust, + SessionTimeout: 30 * time.Second, + Type: AddressAndPortDependentFirewall, + } + testFirewall(t, f, []fwTest{ + // client -> A authorizes A -> client + {trust, client, serverA, Continue}, + {untrust, serverA, client, Continue}, + {untrust, serverA, client, Continue}, + + // B1 -> client fails until client -> B1 + {untrust, serverB1, client, Drop}, + {trust, client, serverB1, Continue}, + {untrust, serverB1, client, Continue}, + + // B2 -> client still fails + {untrust, serverB2, client, Drop}, + }) + }) + t.Run("ip_dependent", func(t *testing.T) { + f := &Firewall{ + TrustedInterface: trust, + SessionTimeout: 30 * time.Second, + Type: AddressDependentFirewall, + } + testFirewall(t, f, []fwTest{ + // client -> A authorizes A -> client + {trust, client, serverA, Continue}, + {untrust, serverA, client, Continue}, + {untrust, serverA, client, Continue}, + + // B1 -> client fails until client -> B1 + {untrust, serverB1, client, Drop}, + {trust, client, serverB1, Continue}, + {untrust, serverB1, client, Continue}, + + // B2 -> client also works now + {untrust, serverB2, client, Continue}, + }) + }) + t.Run("endpoint_independent", func(t *testing.T) { + f := &Firewall{ + TrustedInterface: trust, + SessionTimeout: 30 * time.Second, + Type: EndpointIndependentFirewall, + } + testFirewall(t, f, []fwTest{ + // client -> A authorizes A -> client + {trust, client, serverA, Continue}, + {untrust, serverA, client, Continue}, + {untrust, serverA, client, Continue}, + + // B1 -> client also works + {untrust, serverB1, client, Continue}, + + // B2 -> client also works + {untrust, serverB2, client, Continue}, + }) + }) +} + +type fwTest struct { + iface *Interface + src, dst netaddr.IPPort + want PacketVerdict +} + +func testFirewall(t *testing.T, f *Firewall, tests []fwTest) { + t.Helper() + clock := &tstest.Clock{} + f.TimeNow = clock.Now for _, test := range tests { clock.Advance(time.Second) p := &Packet{