cmd/cloner, tailcfg: fix nil vs len 0 issues, add tests, use for Hostinfo

Also use go:generate and https://golang.org/s/generatedcode header style.

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2020-07-27 10:40:34 -07:00 committed by Brad Fitzpatrick
parent 41d0c81859
commit ec4feaf31c
4 changed files with 66 additions and 24 deletions

View File

@ -123,7 +123,7 @@ const header = `// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserve
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// AUTO-GENERATED by: tailscale.com/cmd/cloner -type %s
// Code generated by tailscale.com/cmd/cloner -type %s; DO NOT EDIT.
package %s
@ -168,8 +168,8 @@ func gen(buf *bytes.Buffer, imports map[string]struct{}, name string, typ *types
}
switch ft := ft.Underlying().(type) {
case *types.Slice:
n := importedName(ft.Elem())
if containsPointers(ft.Elem()) {
n := importedName(ft.Elem())
writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname)
writef("for i := range dst.%s {", fname)
if _, isPtr := ft.Elem().(*types.Pointer); isPtr {
@ -179,7 +179,7 @@ func gen(buf *bytes.Buffer, imports map[string]struct{}, name string, typ *types
}
writef("}")
} else {
writef("dst.%s = append([]%s(nil), src.%s...)", fname, n, fname)
writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname)
}
case *types.Pointer:
if named, _ := ft.Elem().(*types.Named); named != nil && containsPointers(ft.Elem()) {

View File

@ -4,6 +4,8 @@
package tailcfg
//go:generate go run tailscale.com/cmd/cloner -type=User,Node,Hostinfo,NetInfo -output=tailcfg_clone.go
import (
"bytes"
"errors"
@ -371,20 +373,6 @@ func (ni *NetInfo) BasicallyEqual(ni2 *NetInfo) bool {
ni.LinkType == ni2.LinkType
}
// Clone makes a deep copy of Hostinfo.
// The result aliases no memory with the original.
//
// TODO: use cmd/cloner, reconcile len(0) vs. nil.
func (h *Hostinfo) Clone() (res *Hostinfo) {
res = new(Hostinfo)
*res = *h
res.RoutableIPs = append([]wgcfg.CIDR{}, h.RoutableIPs...)
res.Services = append([]Service{}, h.Services...)
res.NetInfo = h.NetInfo.Clone()
return res
}
// Equal reports whether h and h2 are equal.
func (h *Hostinfo) Equal(h2 *Hostinfo) bool {
if h == nil && h2 == nil {

View File

@ -2,12 +2,11 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// AUTO-GENERATED by tailscale.com/cmd/cloner -type User,Node,NetInfo
// Code generated by tailscale.com/cmd/cloner -type User,Node,Hostinfo,NetInfo; DO NOT EDIT.
package tailcfg
import (
"github.com/tailscale/wireguard-go/wgcfg"
"time"
)
@ -19,8 +18,8 @@ func (src *User) Clone() *User {
}
dst := new(User)
*dst = *src
dst.Logins = append([]LoginID(nil), src.Logins...)
dst.Roles = append([]RoleID(nil), src.Roles...)
dst.Logins = append(src.Logins[:0:0], src.Logins...)
dst.Roles = append(src.Roles[:0:0], src.Roles...)
return dst
}
@ -32,9 +31,9 @@ func (src *Node) Clone() *Node {
}
dst := new(Node)
*dst = *src
dst.Addresses = append([]wgcfg.CIDR(nil), src.Addresses...)
dst.AllowedIPs = append([]wgcfg.CIDR(nil), src.AllowedIPs...)
dst.Endpoints = append([]string(nil), src.Endpoints...)
dst.Addresses = append(src.Addresses[:0:0], src.Addresses...)
dst.AllowedIPs = append(src.AllowedIPs[:0:0], src.AllowedIPs...)
dst.Endpoints = append(src.Endpoints[:0:0], src.Endpoints...)
dst.Hostinfo = *src.Hostinfo.Clone()
if dst.LastSeen != nil {
dst.LastSeen = new(time.Time)
@ -43,6 +42,21 @@ func (src *Node) Clone() *Node {
return dst
}
// Clone makes a deep copy of Hostinfo.
// The result aliases no memory with the original.
func (src *Hostinfo) Clone() *Hostinfo {
if src == nil {
return nil
}
dst := new(Hostinfo)
*dst = *src
dst.RoutableIPs = append(src.RoutableIPs[:0:0], src.RoutableIPs...)
dst.RequestTags = append(src.RequestTags[:0:0], src.RequestTags...)
dst.Services = append(src.Services[:0:0], src.Services...)
dst.NetInfo = src.NetInfo.Clone()
return dst
}
// Clone makes a deep copy of NetInfo.
// The result aliases no memory with the original.
func (src *NetInfo) Clone() *NetInfo {

View File

@ -390,3 +390,43 @@ func testKey(t *testing.T, prefix string, in keyIn, out encoding.TextUnmarshaler
t.Errorf("mismatch after unmarshal")
}
}
func TestCloneUser(t *testing.T) {
tests := []struct {
name string
u *User
}{
{"nil_logins", &User{}},
{"zero_logins", &User{Logins: make([]LoginID, 0)}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u2 := tt.u.Clone()
if !reflect.DeepEqual(tt.u, u2) {
t.Errorf("not equal")
}
})
}
}
func TestCloneNode(t *testing.T) {
tests := []struct {
name string
v *Node
}{
{"nil_fields", &Node{}},
{"zero_fields", &Node{
Addresses: make([]wgcfg.CIDR, 0),
AllowedIPs: make([]wgcfg.CIDR, 0),
Endpoints: make([]string, 0),
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v2 := tt.v.Clone()
if !reflect.DeepEqual(tt.v, v2) {
t.Errorf("not equal")
}
})
}
}