394 lines
10 KiB
Go
394 lines
10 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
// Package codegen contains shared utilities for generating code.
|
|
package codegen
|
|
|
|
import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/token"
	"go/types"
	"io"
	"os"
	"reflect"
	"slices"
	"strings"

	"golang.org/x/tools/go/packages"
	"golang.org/x/tools/imports"
	"tailscale.com/util/mak"
)
|
|
|
|
// flagCopyright is the -copyright command-line flag (default true).
// When it is set, writeHeader emits copyrightHeader at the top of the
// generated file.
var flagCopyright = flag.Bool("copyright", true, "add Tailscale copyright to generated file headers")
|
|
|
|
// LoadTypes returns all named types in pkgName, keyed by their type name.
|
|
func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]types.Type, error) {
|
|
cfg := &packages.Config{
|
|
Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
|
|
Tests: buildTags == "test",
|
|
}
|
|
if buildTags != "" && !cfg.Tests {
|
|
cfg.BuildFlags = []string{"-tags=" + buildTags}
|
|
}
|
|
|
|
pkgs, err := packages.Load(cfg, pkgName)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if cfg.Tests {
|
|
pkgs = testPackages(pkgs)
|
|
}
|
|
if len(pkgs) != 1 {
|
|
return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs))
|
|
}
|
|
pkg := pkgs[0]
|
|
return pkg, namedTypes(pkg), nil
|
|
}
|
|
|
|
func testPackages(pkgs []*packages.Package) []*packages.Package {
|
|
var testPackages []*packages.Package
|
|
for _, pkg := range pkgs {
|
|
testPackageID := fmt.Sprintf("%[1]s [%[1]s.test]", pkg.PkgPath)
|
|
if pkg.ID == testPackageID {
|
|
testPackages = append(testPackages, pkg)
|
|
}
|
|
}
|
|
return testPackages
|
|
}
|
|
|
|
// HasNoClone reports whether the "codegen" key of structTag lists "noclone"
// among its comma-separated values, as in `codegen:"noclone"`.
func HasNoClone(structTag string) bool {
	rest := reflect.StructTag(structTag).Get("codegen")
	for {
		opt, tail, more := strings.Cut(rest, ",")
		if opt == "noclone" {
			return true
		}
		if !more {
			return false
		}
		rest = tail
	}
}
|
|
|
|
const (
	// copyrightHeader is the license comment that writeHeader places at
	// the top of generated files when the -copyright flag is set.
	copyrightHeader = `// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

`

	// genAndPackageHeader is a format string for the "Code generated"
	// marker and the package clause. Its verbs are, in order, the
	// generating tool and the package name.
	genAndPackageHeader = `// Code generated by %v; DO NOT EDIT.

package %s
`
)
|
|
|
|
func NewImportTracker(thisPkg *types.Package) *ImportTracker {
|
|
return &ImportTracker{
|
|
thisPkg: thisPkg,
|
|
}
|
|
}
|
|
|
|
// ImportTracker provides a mechanism to track and build import paths.
type ImportTracker struct {
	// thisPkg is the package that qualifier compares against; types from
	// it are referenced without a package prefix.
	thisPkg *types.Package
	// packages is the set of import paths recorded by Import.
	// It is not initialized by NewImportTracker.
	packages map[string]bool
}
|
|
|
|
func (it *ImportTracker) Import(pkg string) {
|
|
if pkg != "" && !it.packages[pkg] {
|
|
mak.Set(&it.packages, pkg, true)
|
|
}
|
|
}
|
|
|
|
func (it *ImportTracker) qualifier(pkg *types.Package) string {
|
|
if it.thisPkg == pkg {
|
|
return ""
|
|
}
|
|
it.Import(pkg.Path())
|
|
// TODO(maisem): handle conflicts?
|
|
return pkg.Name()
|
|
}
|
|
|
|
// QualifiedName returns the string representation of t in the package.
|
|
func (it *ImportTracker) QualifiedName(t types.Type) string {
|
|
return types.TypeString(t, it.qualifier)
|
|
}
|
|
|
|
// PackagePrefix returns the prefix to be used when referencing named objects from pkg.
|
|
func (it *ImportTracker) PackagePrefix(pkg *types.Package) string {
|
|
if s := it.qualifier(pkg); s != "" {
|
|
return s + "."
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// Write prints all the tracked imports in a single import block to w.
|
|
func (it *ImportTracker) Write(w io.Writer) {
|
|
fmt.Fprintf(w, "import (\n")
|
|
for s := range it.packages {
|
|
fmt.Fprintf(w, "\t%q\n", s)
|
|
}
|
|
fmt.Fprintf(w, ")\n\n")
|
|
}
|
|
|
|
func writeHeader(w io.Writer, tool, pkg string) {
|
|
if *flagCopyright {
|
|
fmt.Fprint(w, copyrightHeader)
|
|
}
|
|
fmt.Fprintf(w, genAndPackageHeader, tool, pkg)
|
|
}
|
|
|
|
// WritePackageFile adds a file with the provided imports and contents to package.
|
|
// The tool param is used to identify the tool that generated package file.
|
|
func WritePackageFile(tool string, pkg *packages.Package, path string, it *ImportTracker, contents *bytes.Buffer) error {
|
|
buf := new(bytes.Buffer)
|
|
writeHeader(buf, tool, pkg.Name)
|
|
it.Write(buf)
|
|
if _, err := buf.Write(contents.Bytes()); err != nil {
|
|
return err
|
|
}
|
|
return writeFormatted(buf.Bytes(), path)
|
|
}
|
|
|
|
// writeFormatted writes code to path.
|
|
// It runs gofmt on it before writing;
|
|
// if gofmt fails, it writes code unchanged.
|
|
// Errors can include I/O errors and gofmt errors.
|
|
//
|
|
// The advantage of always writing code to path,
|
|
// even if gofmt fails, is that it makes debugging easier.
|
|
// The code can be long, but you need it in order to debug.
|
|
// It is nicer to work with it in a file than a terminal.
|
|
// It is also easier to interpret gofmt errors
|
|
// with an editor providing file and line numbers.
|
|
func writeFormatted(code []byte, path string) error {
|
|
out, fmterr := imports.Process(path, code, &imports.Options{
|
|
Comments: true,
|
|
TabIndent: true,
|
|
TabWidth: 8,
|
|
FormatOnly: true, // fancy gofmt only
|
|
})
|
|
if fmterr != nil {
|
|
out = code
|
|
}
|
|
ioerr := os.WriteFile(path, out, 0644)
|
|
// Prefer I/O errors. They're usually easier to fix,
|
|
// and until they're fixed you can't do much else.
|
|
if ioerr != nil {
|
|
return ioerr
|
|
}
|
|
if fmterr != nil {
|
|
return fmt.Errorf("%s:%v", path, fmterr)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// namedTypes returns all named types in pkg, keyed by their type name.
|
|
func namedTypes(pkg *packages.Package) map[string]types.Type {
|
|
nt := make(map[string]types.Type)
|
|
for _, file := range pkg.Syntax {
|
|
for _, d := range file.Decls {
|
|
decl, ok := d.(*ast.GenDecl)
|
|
if !ok || decl.Tok != token.TYPE {
|
|
continue
|
|
}
|
|
for _, s := range decl.Specs {
|
|
spec, ok := s.(*ast.TypeSpec)
|
|
if !ok {
|
|
continue
|
|
}
|
|
typeNameObj, ok := pkg.TypesInfo.Defs[spec.Name]
|
|
if !ok {
|
|
continue
|
|
}
|
|
switch typ := typeNameObj.Type(); typ.(type) {
|
|
case *types.Alias, *types.Named:
|
|
nt[spec.Name.Name] = typ
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nt
|
|
}
|
|
|
|
// AssertStructUnchanged generates code that asserts at compile time that type t is unchanged.
|
|
// thisPkg is the package containing t.
|
|
// tname is the named type corresponding to t.
|
|
// ctx is a single-word context for this assertion, such as "Clone".
|
|
// If non-nil, AssertStructUnchanged will add elements to imports
|
|
// for each package path that the caller must import for the returned code to compile.
|
|
func AssertStructUnchanged(t *types.Struct, tname string, params *types.TypeParamList, ctx string, it *ImportTracker) []byte {
|
|
buf := new(bytes.Buffer)
|
|
w := func(format string, args ...any) {
|
|
fmt.Fprintf(buf, format+"\n", args...)
|
|
}
|
|
w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
|
|
|
|
hasTypeParams := params != nil && params.Len() > 0
|
|
if hasTypeParams {
|
|
constraints, identifiers := FormatTypeParams(params, it)
|
|
w("func _%s%sNeedsRegeneration%s (%s%s) {", tname, ctx, constraints, tname, identifiers)
|
|
w("_%s%sNeedsRegeneration(struct {", tname, ctx)
|
|
} else {
|
|
w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
|
|
}
|
|
|
|
for i := range t.NumFields() {
|
|
st := t.Field(i)
|
|
fname := st.Name()
|
|
ft := t.Field(i).Type()
|
|
if IsInvalid(ft) {
|
|
continue
|
|
}
|
|
qname := it.QualifiedName(ft)
|
|
var tag string
|
|
if hasTypeParams {
|
|
tag = t.Tag(i)
|
|
if tag != "" {
|
|
tag = "`" + tag + "`"
|
|
}
|
|
}
|
|
if st.Anonymous() {
|
|
w("\t%s %s", qname, tag)
|
|
} else {
|
|
w("\t%s %s %s", fname, qname, tag)
|
|
}
|
|
}
|
|
|
|
if hasTypeParams {
|
|
w("}{})\n}")
|
|
} else {
|
|
w("}{})")
|
|
}
|
|
return buf.Bytes()
|
|
}
|
|
|
|
// IsInvalid reports whether the provided type is invalid. It is used to allow
// codegeneration to run even when the target files have build errors or are
// missing views.
//
// The check inspects the basic kind of t rather than its rendered string, so
// nothing is formatted. A nil t is reported as not invalid.
func IsInvalid(t types.Type) bool {
	basic, ok := t.(*types.Basic)
	return ok && basic.Kind() == types.Invalid
}
|
|
|
|
// ContainsPointers reports whether typ contains any pointers,
// either explicitly or implicitly.
//
// Channels, maps, pointers, slices and unsafe.Pointer always count.
// Arrays, structs and unions count if any element, field or term does.
// An interface counts if it is empty or a pure method set, or if any of its
// embedded types does.
//
// A few types are special-cased by name as not containing pointers even
// though their representation may: time.Time and the netip Addr, Prefix and
// AddrPort types.
func ContainsPointers(typ types.Type) bool {
	switch typ.String() {
	case "time.Time":
		// time.Time contains a pointer that does not need copying
		return false
	case "inet.af/netip.Addr", "net/netip.Addr", "net/netip.Prefix", "net/netip.AddrPort":
		return false
	}

	switch u := typ.Underlying().(type) {
	case *types.Chan, *types.Map, *types.Pointer, *types.Slice:
		return true
	case *types.Basic:
		return u.Kind() == types.UnsafePointer
	case *types.Array:
		return ContainsPointers(u.Elem())
	case *types.Struct:
		for i, n := 0, u.NumFields(); i < n; i++ {
			if ContainsPointers(u.Field(i).Type()) {
				return true
			}
		}
	case *types.Union:
		for i, n := 0, u.Len(); i < n; i++ {
			if ContainsPointers(u.Term(i).Type()) {
				return true
			}
		}
	case *types.Interface:
		if u.Empty() || u.IsMethodSet() {
			return true
		}
		for i, n := 0, u.NumEmbeddeds(); i < n; i++ {
			if ContainsPointers(u.EmbeddedType(i)) {
				return true
			}
		}
	}
	return false
}
|
|
|
|
// IsViewType reports whether typ is a View, which is recognized as a type
// whose underlying type is a struct with exactly one field, named "ж".
func IsViewType(typ types.Type) bool {
	if st, isStruct := typ.Underlying().(*types.Struct); isStruct {
		return st.NumFields() == 1 && st.Field(0).Name() == "ж"
	}
	return false
}
|
|
|
|
// FormatTypeParams formats the specified params and returns two strings:
|
|
// - constraints are comma-separated type parameters and their constraints in square brackets (e.g. [T any, V constraints.Integer])
|
|
// - names are comma-separated type parameter names in square brackets (e.g. [T, V])
|
|
//
|
|
// If params is nil or empty, both return values are empty strings.
|
|
func FormatTypeParams(params *types.TypeParamList, it *ImportTracker) (constraints, names string) {
|
|
if params == nil || params.Len() == 0 {
|
|
return "", ""
|
|
}
|
|
var constraintList, nameList []string
|
|
for i := range params.Len() {
|
|
param := params.At(i)
|
|
name := param.Obj().Name()
|
|
constraint := it.QualifiedName(param.Constraint())
|
|
nameList = append(nameList, name)
|
|
constraintList = append(constraintList, name+" "+constraint)
|
|
}
|
|
constraints = "[" + strings.Join(constraintList, ", ") + "]"
|
|
names = "[" + strings.Join(nameList, ", ") + "]"
|
|
return constraints, names
|
|
}
|
|
|
|
// LookupMethod returns the method with the specified name in t, or nil if the method does not exist.
|
|
func LookupMethod(t types.Type, name string) *types.Func {
|
|
switch t := t.(type) {
|
|
case *types.Alias:
|
|
return LookupMethod(t.Rhs(), name)
|
|
case *types.TypeParam:
|
|
return LookupMethod(t.Constraint(), name)
|
|
case *types.Pointer:
|
|
return LookupMethod(t.Elem(), name)
|
|
case *types.Named:
|
|
switch u := t.Underlying().(type) {
|
|
case *types.Interface:
|
|
return LookupMethod(u, name)
|
|
default:
|
|
for i := 0; i < t.NumMethods(); i++ {
|
|
if method := t.Method(i); method.Name() == name {
|
|
return method
|
|
}
|
|
}
|
|
}
|
|
case *types.Interface:
|
|
for i := 0; i < t.NumMethods(); i++ {
|
|
if method := t.Method(i); method.Name() == name {
|
|
return method
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// NamedTypeOf is like t.(*types.Named), but also works with type aliases.
|
|
func NamedTypeOf(t types.Type) (named *types.Named, ok bool) {
|
|
if a, ok := t.(*types.Alias); ok {
|
|
return NamedTypeOf(types.Unalias(a))
|
|
}
|
|
named, ok = t.(*types.Named)
|
|
return
|
|
}
|