diff --git a/kube/egressservices/egressservices.go b/kube/egressservices/egressservices.go index 428b476b9..04a1c362b 100644 --- a/kube/egressservices/egressservices.go +++ b/kube/egressservices/egressservices.go @@ -9,11 +9,8 @@ package egressservices import ( - "encoding" - "fmt" + "encoding/json" "net/netip" - "strconv" - "strings" ) // KeyEgressServices is name of the proxy state Secret field that contains the @@ -31,7 +28,7 @@ type Config struct { // should be proxied. TailnetTarget TailnetTarget `json:"tailnetTarget"` // Ports contains mappings for ports that can be accessed on the tailnet target. - Ports map[PortMap]struct{} `json:"ports"` + Ports PortMaps `json:"ports"` } // TailnetTarget is the tailnet target to which traffic for the egress service @@ -52,35 +49,38 @@ type PortMap struct { TargetPort uint16 `json:"targetPort"` } -// PortMap is used as a Config.Ports map key. Config needs to be serialized/deserialized to/from JSON. JSON only -// supports string map keys, so we need to implement TextMarshaler/TextUnmarshaler to convert PortMap to string and -// back. -var _ encoding.TextMarshaler = PortMap{} -var _ encoding.TextUnmarshaler = &PortMap{} +type PortMaps map[PortMap]struct{} -func (pm *PortMap) UnmarshalText(t []byte) error { - tt := string(t) - ss := strings.Split(tt, ":") - if len(ss) != 3 { - return fmt.Errorf("error unmarshalling portmap from JSON, wants a portmap in form ::, got %q", tt) +// PortMaps is a list of PortMap structs, however, we want to use it as a set +// with efficient lookups in code. It implements custom JSON marshalling +// methods to convert between being a list in JSON and a set (map with empty +// values) in code. +var _ json.Marshaler = &PortMaps{} +var _ json.Marshaler = PortMaps{} +var _ json.Unmarshaler = &PortMaps{} + +func (p *PortMaps) UnmarshalJSON(data []byte) error { + *p = make(map[PortMap]struct{}) + + var l []PortMap + if err := json.Unmarshal(data, &l); err != nil { + return err } - pm.Protocol = ss[0] - matchPort, err := strconv.ParseUint(ss[1], 10, 16) - if err != nil { - return fmt.Errorf("error converting match port %q to uint16: %w", ss[1], err) + + for _, pm := range l { + (*p)[pm] = struct{}{} } - pm.MatchPort = uint16(matchPort) - targetPort, err := strconv.ParseUint(ss[2], 10, 16) - if err != nil { - return fmt.Errorf("error converting target port %q to uint16: %w", ss[2], err) - } - pm.TargetPort = uint16(targetPort) + return nil } -func (pm PortMap) MarshalText() ([]byte, error) { - s := fmt.Sprintf("%s:%d:%d", pm.Protocol, pm.MatchPort, pm.TargetPort) - return []byte(s), nil +func (p PortMaps) MarshalJSON() ([]byte, error) { + l := make([]PortMap, 0, len(p)) + for pm := range p { + l = append(l, pm) + } + + return json.Marshal(l) } // Status represents the currently configured firewall rules for all egress @@ -94,7 +94,7 @@ type Status struct { // ServiceStatus is the currently configured firewall rules for an egress // service. type ServiceStatus struct { - Ports map[PortMap]struct{} `json:"ports"` + Ports PortMaps `json:"ports"` // TailnetTargetIPs are the tailnet target IPs that were used to // configure these firewall rules. For a TailnetTarget with IP set, this // is the same as IP. diff --git a/kube/egressservices/egressservices_test.go b/kube/egressservices/egressservices_test.go index 5e5651e77..d6f952ea0 100644 --- a/kube/egressservices/egressservices_test.go +++ b/kube/egressservices/egressservices_test.go @@ -5,8 +5,9 @@ package egressservices import ( "encoding/json" - "reflect" "testing" + + "github.com/google/go-cmp/cmp" ) func Test_jsonUnmarshalConfig(t *testing.T) { @@ -18,7 +19,7 @@ func Test_jsonUnmarshalConfig(t *testing.T) { }{ { name: "success", - bs: []byte(`{"ports":{"tcp:4003:80":{}}}`), + bs: []byte(`{"ports":[{"protocol":"tcp","matchPort":4003,"targetPort":80}]}`), wantsCfg: Config{Ports: map[PortMap]struct{}{{Protocol: "tcp", MatchPort: 4003, TargetPort: 80}: {}}}, }, { @@ -34,8 +35,8 @@ func Test_jsonUnmarshalConfig(t *testing.T) { if gotErr := json.Unmarshal(tt.bs, &cfg); (gotErr != nil) != tt.wantsErr { t.Errorf("json.Unmarshal returned error %v, wants error %v", gotErr, tt.wantsErr) } - if !reflect.DeepEqual(cfg, tt.wantsCfg) { - t.Errorf("json.Unmarshal produced Config %v, wants Config %v", cfg, tt.wantsCfg) + if diff := cmp.Diff(cfg, tt.wantsCfg); diff != "" { + t.Errorf("unexpected secrets (-got +want):\n%s", diff) } }) } @@ -54,12 +55,12 @@ func Test_jsonMarshalConfig(t *testing.T) { protocol: "tcp", matchPort: 4003, targetPort: 80, - wantsBs: []byte(`{"tailnetTarget":{"ip":"","fqdn":""},"ports":{"tcp:4003:80":{}}}`), + wantsBs: []byte(`{"tailnetTarget":{"ip":"","fqdn":""},"ports":[{"protocol":"tcp","matchPort":4003,"targetPort":80}]}`), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cfg := Config{Ports: map[PortMap]struct{}{{ + cfg := Config{Ports: PortMaps{{ Protocol: tt.protocol, MatchPort: tt.matchPort, TargetPort: tt.targetPort}: {}}} @@ -68,8 +69,8 @@ func Test_jsonMarshalConfig(t *testing.T) { if gotErr != nil { t.Errorf("json.Marshal(%+#v) returned unexpected error %v", cfg, gotErr) } - if !reflect.DeepEqual(gotBs, tt.wantsBs) { - t.Errorf("json.Marshal(%+#v) returned '%v', wants '%v'", cfg, string(gotBs), string(tt.wantsBs)) + if diff := cmp.Diff(gotBs, tt.wantsBs); diff != "" { + t.Errorf("unexpected secrets (-got +want):\n%s", diff) } }) }