147 lines
4.0 KiB
Go
147 lines
4.0 KiB
Go
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
// Package codegen contains shared utilities for generating code.
|
|
package codegen
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/format"
|
|
"go/token"
|
|
"go/types"
|
|
"os"
|
|
|
|
"golang.org/x/tools/go/packages"
|
|
)
|
|
|
|
// WriteFormatted writes code to path.
// It runs gofmt on it before writing;
// if gofmt fails, it writes code unchanged.
// Errors can include I/O errors and gofmt errors.
// An I/O error takes precedence over a gofmt error.
// A gofmt error is returned as "path:err" and is wrapped, so callers
// can still inspect it with errors.Is and errors.As.
//
// The advantage of always writing code to path,
// even if gofmt fails, is that it makes debugging easier.
// The code can be long, but you need it in order to debug.
// It is nicer to work with it in a file than a terminal.
// It is also easier to interpret gofmt errors
// with an editor providing file and line numbers.
func WriteFormatted(code []byte, path string) error {
	out, fmterr := format.Source(code)
	if fmterr != nil {
		// Fall back to the raw input so it can still be inspected on disk.
		out = code
	}
	ioerr := os.WriteFile(path, out, 0644)
	// Prefer I/O errors. They're usually easier to fix,
	// and until they're fixed you can't do much else.
	if ioerr != nil {
		return ioerr
	}
	if fmterr != nil {
		// %w formats exactly like %v but keeps fmterr reachable
		// through errors.Unwrap / errors.As.
		return fmt.Errorf("%s:%w", path, fmterr)
	}
	return nil
}
|
|
|
|
// NamedTypes returns all named types in pkg, keyed by their type name.
|
|
func NamedTypes(pkg *packages.Package) map[string]*types.Named {
|
|
nt := make(map[string]*types.Named)
|
|
for _, file := range pkg.Syntax {
|
|
for _, d := range file.Decls {
|
|
decl, ok := d.(*ast.GenDecl)
|
|
if !ok || decl.Tok != token.TYPE {
|
|
continue
|
|
}
|
|
for _, s := range decl.Specs {
|
|
spec, ok := s.(*ast.TypeSpec)
|
|
if !ok {
|
|
continue
|
|
}
|
|
typeNameObj := pkg.TypesInfo.Defs[spec.Name]
|
|
typ, ok := typeNameObj.Type().(*types.Named)
|
|
if !ok {
|
|
continue
|
|
}
|
|
nt[spec.Name.Name] = typ
|
|
}
|
|
}
|
|
}
|
|
return nt
|
|
}
|
|
|
|
// AssertStructUnchanged generates code that asserts at compile time that type t is unchanged.
// thisPkg is the package containing t.
// tname is the named type corresponding to t.
// ctx is a single-word context for this assertion, such as "Clone".
// If non-nil, AssertStructUnchanged will add elements to imports
// for each package path that the caller must import for the returned code to compile.
//
// The generated code converts a zero-valued anonymous struct literal, spelling
// out t's current fields, to tname. That conversion only compiles while the
// two struct types still match, so editing t's fields breaks the build until
// the code is regenerated.
func AssertStructUnchanged(t *types.Struct, thisPkg *types.Package, tname, ctx string, imports map[string]struct{}) []byte {
	buf := new(bytes.Buffer)
	w := func(format string, args ...interface{}) {
		fmt.Fprintf(buf, format+"\n", args...)
	}
	// qual leaves types from thisPkg unqualified and qualifies all others by
	// package name. It records every package it is asked about, because one
	// field type (e.g. map[a.K]*b.V) can need several imports.
	qual := func(pkg *types.Package) string {
		if pkg == thisPkg {
			return ""
		}
		if imports != nil {
			imports[pkg.Path()] = struct{}{}
		}
		return pkg.Name()
	}
	w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
	w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)

	for i := 0; i < t.NumFields(); i++ {
		f := t.Field(i)
		qname := types.TypeString(f.Type(), qual)
		if f.Embedded() {
			// Embedded fields must stay embedded: struct types whose fields
			// differ in embeddedness are not identical, so writing "Name Type"
			// here would make the generated conversion fail to compile.
			w("\t%s", qname)
			continue
		}
		w("\t%s %s", f.Name(), qname)
	}

	w("}{})\n")
	return buf.Bytes()
}

// importedName returns how type t is spelled from code in thisPkg. Types
// declared in thisPkg are unqualified; all others are qualified by package name.
// importPkg is the import path of a package the name refers to, or "" if none.
// If t mentions several other packages (e.g. map[a.K]b.V), only the last one
// visited by the qualifier is reported.
func importedName(t types.Type, thisPkg *types.Package) (qualifiedName, importPkg string) {
	qual := func(pkg *types.Package) string {
		if thisPkg == pkg {
			return ""
		}
		importPkg = pkg.Path()
		return pkg.Name()
	}
	// Call TypeString before returning importPkg. In
	// "return types.TypeString(t, qual), importPkg" the language does not
	// specify whether importPkg is read before or after the call sets it.
	qualifiedName = types.TypeString(t, qual)
	return qualifiedName, importPkg
}
|
|
|
|
// ContainsPointers reports whether typ contains any pointers,
// either explicitly or implicitly.
// It has special handling for some types that contain pointers
// that we know are free from memory aliasing/mutation concerns.
//
// Arrays and structs are inspected element by element. Channels, interfaces,
// maps, pointers, slices and unsafe.Pointer always count as pointers.
// NOTE(review): function (signature) types are reported as pointer-free;
// confirm that is intended for callers.
func ContainsPointers(typ types.Type) bool {
	switch typ.String() {
	case "time.Time":
		// time.Time contains a pointer that does not need copying
		return false
	case "inet.af/netaddr.IP":
		return false
	}
	switch ft := typ.Underlying().(type) {
	case *types.Array:
		return ContainsPointers(ft.Elem())
	case *types.Basic:
		// unsafe.Pointer is a basic type but is still a raw pointer.
		return ft.Kind() == types.UnsafePointer
	case *types.Chan:
		return true
	case *types.Interface:
		return true // a little too broad
	case *types.Map:
		return true
	case *types.Pointer:
		return true
	case *types.Slice:
		return true
	case *types.Struct:
		for i := 0; i < ft.NumFields(); i++ {
			if ContainsPointers(ft.Field(i).Type()) {
				return true
			}
		}
	}
	return false
}
|