tailscale/util/codegen/codegen.go

152 lines
4.2 KiB
Go
Raw Normal View History

// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package codegen contains shared utilities for generating code.
package codegen
import (
"bytes"
"fmt"
"go/ast"
"go/token"
"go/types"
"os"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/imports"
)
// WriteFormatted writes code to path.
// It runs gofmt on it before writing;
// if gofmt fails, it writes code unchanged.
// Errors can include I/O errors and gofmt errors.
// A gofmt error is returned wrapped and prefixed with path,
// so callers can still reach the underlying error via errors.Is/errors.As.
//
// The advantage of always writing code to path,
// even if gofmt fails, is that it makes debugging easier.
// The code can be long, but you need it in order to debug.
// It is nicer to work with it in a file than a terminal.
// It is also easier to interpret gofmt errors
// with an editor providing file and line numbers.
func WriteFormatted(code []byte, path string) error {
	out, fmterr := imports.Process(path, code, &imports.Options{
		Comments:   true,
		TabIndent:  true,
		TabWidth:   8,
		FormatOnly: true, // fancy gofmt only
	})
	if fmterr != nil {
		// Fall back to the unformatted input so the file is still written.
		out = code
	}
	ioerr := os.WriteFile(path, out, 0644)
	// Prefer I/O errors. They're usually easier to fix,
	// and until they're fixed you can't do much else.
	if ioerr != nil {
		return ioerr
	}
	if fmterr != nil {
		// %w keeps the message identical to the previous %v form
		// while preserving the error chain.
		return fmt.Errorf("%s:%w", path, fmterr)
	}
	return nil
}
// NamedTypes returns all named types in pkg, keyed by their type name.
//
// It walks the type declarations found in pkg.Syntax and resolves each
// declared name through pkg.TypesInfo.Defs, so pkg must carry both.
// Declarations that have no recorded object, or whose object's type is
// not a *types.Named, are skipped.
func NamedTypes(pkg *packages.Package) map[string]*types.Named {
	nt := make(map[string]*types.Named)
	for _, file := range pkg.Syntax {
		for _, d := range file.Decls {
			decl, ok := d.(*ast.GenDecl)
			if !ok || decl.Tok != token.TYPE {
				continue
			}
			for _, s := range decl.Specs {
				spec, ok := s.(*ast.TypeSpec)
				if !ok {
					continue
				}
				typeNameObj := pkg.TypesInfo.Defs[spec.Name]
				if typeNameObj == nil {
					// No object was recorded for this identifier;
					// calling Type on the nil interface would panic.
					continue
				}
				typ, ok := typeNameObj.Type().(*types.Named)
				if !ok {
					continue
				}
				nt[spec.Name.Name] = typ
			}
		}
	}
	return nt
}
// AssertStructUnchanged generates code that asserts at compile time that type t is unchanged.
// thisPkg is the package containing t.
// tname is the named type corresponding to t.
// ctx is a single-word context for this assertion, such as "Clone".
// If non-nil, AssertStructUnchanged will add elements to imports
// for each package path that the caller must import for the returned code to compile.
//
// The generated code converts a zero value of an anonymous struct that
// mirrors t's fields to tname. Embedded fields are emitted as embedded
// fields, since a named field and an embedded field never make identical
// struct types.
func AssertStructUnchanged(t *types.Struct, thisPkg *types.Package, tname, ctx string, imports map[string]struct{}) []byte {
	// qualify renders the package prefix for names outside thisPkg and
	// records every such package in imports. A single field type such as
	// map[a.K]b.V can reference several packages, and all of them are needed.
	qualify := func(pkg *types.Package) string {
		if pkg == thisPkg {
			return ""
		}
		if path := pkg.Path(); path != "" && imports != nil {
			imports[path] = struct{}{}
		}
		return pkg.Name()
	}

	buf := new(bytes.Buffer)
	w := func(format string, args ...any) {
		fmt.Fprintf(buf, format+"\n", args...)
	}
	w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
	w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
	for i := 0; i < t.NumFields(); i++ {
		field := t.Field(i)
		qname := types.TypeString(field.Type(), qualify)
		if field.Embedded() {
			// An embedded field is written as its type alone.
			w("\t%s", qname)
			continue
		}
		w("\t%s %s", field.Name(), qname)
	}
	w("}{})\n")
	return buf.Bytes()
}
// importedName returns the text for type t as it should be written in
// source inside thisPkg, along with the import path of a package that
// text refers to. Names from thisPkg are left unqualified; names from
// other packages are prefixed with that package's name.
//
// importPkg is empty when t references no package other than thisPkg.
// If t references several other packages, importPkg holds only the path
// of the last one encountered while rendering.
func importedName(t types.Type, thisPkg *types.Package) (qualifiedName, importPkg string) {
	qual := func(pkg *types.Package) string {
		if thisPkg == pkg {
			return ""
		}
		importPkg = pkg.Path()
		return pkg.Name()
	}
	// Render in a separate statement: qual assigns importPkg as a side
	// effect, and Go does not specify the order between a function call
	// and a variable read within a single return expression list.
	qualifiedName = types.TypeString(t, qual)
	return qualifiedName, importPkg
}
// ContainsPointers reports whether typ contains any pointers,
// either explicitly or implicitly.
// It has special handling for some types that contain pointers
// that we know are free from memory aliasing/mutation concerns.
//
// Arrays and structs are inspected recursively. Channels, interfaces,
// maps, pointers, slices and unsafe.Pointer are reported as containing
// pointers. Every other type is reported as pointer-free.
func ContainsPointers(typ types.Type) bool {
	switch typ.String() {
	case "time.Time":
		// time.Time contains a pointer that does not need copying
		return false
	case "inet.af/netaddr.IP":
		return false
	}
	switch ft := typ.Underlying().(type) {
	case *types.Array:
		return ContainsPointers(ft.Elem())
	case *types.Basic:
		// unsafe.Pointer is a basic type, but it is still a pointer.
		return ft.Kind() == types.UnsafePointer
	case *types.Chan:
		return true
	case *types.Interface:
		return true // a little too broad
	case *types.Map:
		return true
	case *types.Pointer:
		return true
	case *types.Slice:
		return true
	case *types.Struct:
		for i := 0; i < ft.NumFields(); i++ {
			if ContainsPointers(ft.Field(i).Type()) {
				return true
			}
		}
	}
	return false
}