1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// This file is largely based on the Go 1.10-era cmd/go/internal/test/test.go
6// testmain generation code.
7
8package packages
9
10import (
11	"errors"
12	"fmt"
13	"go/ast"
14	"go/doc"
15	"go/parser"
16	"go/token"
17	"os"
18	"sort"
19	"strings"
20	"text/template"
21	"unicode"
22	"unicode/utf8"
23)
24
25// TODO(matloob): Delete this file once Go 1.12 is released.
26
27// This file complements golist_fallback.go by providing
28// support for generating testmains.
29
30func generateTestmain(out string, testPkg, xtestPkg *Package) (extraimports, extradeps []string, err error) {
31	testFuncs, err := loadTestFuncs(testPkg, xtestPkg)
32	if err != nil {
33		return nil, nil, err
34	}
35	extraimports = []string{"testing", "testing/internal/testdeps"}
36	if testFuncs.TestMain == nil {
37		extraimports = append(extraimports, "os")
38	}
39	// Transitive dependencies of ("testing", "testing/internal/testdeps").
40	// os is part of the transitive closure so it and its transitive dependencies are
41	// included regardless of whether it's imported in the template below.
42	extradeps = []string{
43		"errors",
44		"internal/cpu",
45		"unsafe",
46		"internal/bytealg",
47		"internal/race",
48		"runtime/internal/atomic",
49		"runtime/internal/sys",
50		"runtime",
51		"sync/atomic",
52		"sync",
53		"io",
54		"unicode",
55		"unicode/utf8",
56		"bytes",
57		"math",
58		"syscall",
59		"time",
60		"internal/poll",
61		"internal/syscall/unix",
62		"internal/testlog",
63		"os",
64		"math/bits",
65		"strconv",
66		"reflect",
67		"fmt",
68		"sort",
69		"strings",
70		"flag",
71		"runtime/debug",
72		"context",
73		"runtime/trace",
74		"testing",
75		"bufio",
76		"regexp/syntax",
77		"regexp",
78		"compress/flate",
79		"encoding/binary",
80		"hash",
81		"hash/crc32",
82		"compress/gzip",
83		"path/filepath",
84		"io/ioutil",
85		"text/tabwriter",
86		"runtime/pprof",
87		"testing/internal/testdeps",
88	}
89	return extraimports, extradeps, writeTestmain(out, testFuncs)
90}
91
92// The following is adapted from the cmd/go testmain generation code.
93
// isTestFunc reports whether fn has the shape of a testing function:
// no results and exactly one parameter (named at most once) whose type is
// a pointer to arg or to some qualified pkg.arg. arg is the type name to
// look for: B, M or T.
func isTestFunc(fn *ast.FuncDecl, arg string) bool {
	sig := fn.Type
	if res := sig.Results; res != nil && len(res.List) > 0 {
		return false
	}
	params := sig.Params.List
	if len(params) != 1 || len(params[0].Names) > 1 {
		return false
	}
	star, isPtr := params[0].Type.(*ast.StarExpr)
	if !isPtr {
		return false
	}
	// We cannot tell how (or whether) the testing package was imported,
	// so settle for matching *arg or *anything.arg by name only.
	switch elem := star.X.(type) {
	case *ast.Ident:
		return elem.Name == arg
	case *ast.SelectorExpr:
		return elem.Sel.Name == arg
	}
	return false
}
119
// isTest reports whether name looks like a test (or benchmark, depending on
// prefix). The name must start with prefix and either be exactly prefix or
// continue with a character that is not a lower-case letter, so "TestFoo"
// and "Test_x" qualify while "Testify" does not.
func isTest(name, prefix string) bool {
	if !strings.HasPrefix(name, prefix) {
		return false
	}
	rest := name[len(prefix):]
	if rest == "" { // the bare prefix, e.g. "Test", is accepted
		return true
	}
	first, _ := utf8.DecodeRuneInString(rest)
	return !unicode.IsLower(first)
}
133
134// loadTestFuncs returns the testFuncs describing the tests that will be run.
135func loadTestFuncs(ptest, pxtest *Package) (*testFuncs, error) {
136	t := &testFuncs{
137		TestPackage:  ptest,
138		XTestPackage: pxtest,
139	}
140	for _, file := range ptest.GoFiles {
141		if !strings.HasSuffix(file, "_test.go") {
142			continue
143		}
144		if err := t.load(file, "_test", &t.ImportTest, &t.NeedTest); err != nil {
145			return nil, err
146		}
147	}
148	if pxtest != nil {
149		for _, file := range pxtest.GoFiles {
150			if err := t.load(file, "_xtest", &t.ImportXtest, &t.NeedXtest); err != nil {
151				return nil, err
152			}
153		}
154	}
155	return t, nil
156}
157
158// writeTestmain writes the _testmain.go file for t to the file named out.
159func writeTestmain(out string, t *testFuncs) error {
160	f, err := os.Create(out)
161	if err != nil {
162		return err
163	}
164	defer f.Close()
165
166	if err := testmainTmpl.Execute(f, t); err != nil {
167		return err
168	}
169
170	return nil
171}
172
// testFuncs is the data passed to testmainTmpl. It is created by
// loadTestFuncs and filled in, one test file at a time, by load.
type testFuncs struct {
	// Tests, Benchmarks and Examples list the functions that the generated
	// main registers with the testing package.
	Tests        []testFunc
	Benchmarks   []testFunc
	Examples     []testFunc
	// TestMain is the TestMain(*M) function, or nil if none was declared.
	TestMain     *testFunc
	// TestPackage and XTestPackage are the packages handed to
	// loadTestFuncs; XTestPackage may be nil.
	TestPackage  *Package
	XTestPackage *Package
	// ImportTest reports whether the generated main must import
	// TestPackage; NeedTest reports whether it refers to the package by
	// name (otherwise the import is blank).
	ImportTest   bool
	NeedTest     bool
	// ImportXtest and NeedXtest are the XTestPackage counterparts of
	// ImportTest and NeedTest.
	ImportXtest  bool
	NeedXtest    bool
}
185
186// Tested returns the name of the package being tested.
187func (t *testFuncs) Tested() string {
188	return t.TestPackage.Name
189}
190
// testFunc describes one function referenced from the generated test main.
// Values are built with unkeyed composite literals in load, so the order
// of the fields must not change.
type testFunc struct {
	Package   string // imported package name (_test or _xtest)
	Name      string // function name
	Output    string // output, for examples
	Unordered bool   // output is allowed to be unordered.
}
197
198func (t *testFuncs) load(filename, pkg string, doImport, seen *bool) error {
199	var fset = token.NewFileSet()
200
201	f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
202	if err != nil {
203		return errors.New("failed to parse test file " + filename)
204	}
205	for _, d := range f.Decls {
206		n, ok := d.(*ast.FuncDecl)
207		if !ok {
208			continue
209		}
210		if n.Recv != nil {
211			continue
212		}
213		name := n.Name.String()
214		switch {
215		case name == "TestMain":
216			if isTestFunc(n, "T") {
217				t.Tests = append(t.Tests, testFunc{pkg, name, "", false})
218				*doImport, *seen = true, true
219				continue
220			}
221			err := checkTestFunc(fset, n, "M")
222			if err != nil {
223				return err
224			}
225			if t.TestMain != nil {
226				return errors.New("multiple definitions of TestMain")
227			}
228			t.TestMain = &testFunc{pkg, name, "", false}
229			*doImport, *seen = true, true
230		case isTest(name, "Test"):
231			err := checkTestFunc(fset, n, "T")
232			if err != nil {
233				return err
234			}
235			t.Tests = append(t.Tests, testFunc{pkg, name, "", false})
236			*doImport, *seen = true, true
237		case isTest(name, "Benchmark"):
238			err := checkTestFunc(fset, n, "B")
239			if err != nil {
240				return err
241			}
242			t.Benchmarks = append(t.Benchmarks, testFunc{pkg, name, "", false})
243			*doImport, *seen = true, true
244		}
245	}
246	ex := doc.Examples(f)
247	sort.Slice(ex, func(i, j int) bool { return ex[i].Order < ex[j].Order })
248	for _, e := range ex {
249		*doImport = true // import test file whether executed or not
250		if e.Output == "" && !e.EmptyOutput {
251			// Don't run examples with no output.
252			continue
253		}
254		t.Examples = append(t.Examples, testFunc{pkg, "Example" + e.Name, e.Output, e.Unordered})
255		*seen = true
256	}
257	return nil
258}
259
260func checkTestFunc(fset *token.FileSet, fn *ast.FuncDecl, arg string) error {
261	if !isTestFunc(fn, arg) {
262		name := fn.Name.String()
263		pos := fset.Position(fn.Pos())
264		return fmt.Errorf("%s: wrong signature for %s, must be: func %s(%s *testing.%s)", pos, name, name, strings.ToLower(arg), arg)
265	}
266	return nil
267}
268
// testmainTmpl is the template for the generated test main program.
// writeTestmain executes it with a *testFuncs as its data.
//
// The generated file imports "os" only when no TestMain is present, since
// in that case main itself calls os.Exit(m.Run()); otherwise main hands m
// to the user's TestMain. The test packages are imported under the names
// _test and _xtest when their functions are referenced, and with a blank
// import when they only need to be linked in.
var testmainTmpl = template.Must(template.New("main").Parse(`
package main

import (
{{if not .TestMain}}
	"os"
{{end}}
	"testing"
	"testing/internal/testdeps"

{{if .ImportTest}}
	{{if .NeedTest}}_test{{else}}_{{end}} {{.TestPackage.PkgPath | printf "%q"}}
{{end}}
{{if .ImportXtest}}
	{{if .NeedXtest}}_xtest{{else}}_{{end}} {{.XTestPackage.PkgPath | printf "%q"}}
{{end}}
)

var tests = []testing.InternalTest{
{{range .Tests}}
	{"{{.Name}}", {{.Package}}.{{.Name}}},
{{end}}
}

var benchmarks = []testing.InternalBenchmark{
{{range .Benchmarks}}
	{"{{.Name}}", {{.Package}}.{{.Name}}},
{{end}}
}

var examples = []testing.InternalExample{
{{range .Examples}}
	{"{{.Name}}", {{.Package}}.{{.Name}}, {{.Output | printf "%q"}}, {{.Unordered}}},
{{end}}
}

func init() {
	testdeps.ImportPath = {{.TestPackage.PkgPath | printf "%q"}}
}

func main() {
	m := testing.MainStart(testdeps.TestDeps{}, tests, benchmarks, examples)
{{with .TestMain}}
	{{.Package}}.{{.Name}}(m)
{{else}}
	os.Exit(m.Run())
{{end}}
}

`))
319