1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Bundle creates a single-source-file version of a source package
6// suitable for inclusion in a particular target package.
7//
8// Usage:
9//
10//	bundle [-o file] [-dst path] [-pkg name] [-prefix p] [-import old=new] <src>
11//
12// The src argument specifies the import path of the package to bundle.
13// The bundling of a directory of source files into a single source file
14// necessarily imposes a number of constraints.
15// The package being bundled must not use cgo; must not use conditional
16// file compilation, whether with build tags or system-specific file names
17// like code_amd64.go; must not depend on any special comments, which
18// may not be preserved; must not use any assembly sources;
19// must not use renaming imports; and must not use reflection-based APIs
20// that depend on the specific names of types or struct fields.
21//
22// By default, bundle writes the bundled code to standard output.
23// If the -o argument is given, bundle writes to the named file
24// and also includes a ``//go:generate'' comment giving the exact
25// command line used, for regenerating the file with ``go generate.''
26//
27// Bundle customizes its output for inclusion in a particular package, the destination package.
28// By default bundle assumes the destination is the package in the current directory,
29// but the destination package can be specified explicitly using the -dst option,
30// which takes an import path as its argument.
31// If the source package imports the destination package, bundle will remove
32// those imports and rewrite any references to use direct references to the
33// corresponding symbols.
34// Bundle also must write a package declaration in the output and must
35// choose a name to use in that declaration.
// If the -pkg option is given, bundle uses that name.
37// Otherwise, if the -dst option is given, bundle uses the last
38// element of the destination import path.
39// Otherwise, by default bundle uses the package name found in the
40// package sources in the current directory.
41//
42// To avoid collisions, bundle inserts a prefix at the beginning of
43// every package-level const, func, type, and var identifier in src's code,
44// updating references accordingly. The default prefix is the package name
45// of the source package followed by an underscore. The -prefix option
46// specifies an alternate prefix.
47//
48// Occasionally it is necessary to rewrite imports during the bundling
49// process. The -import option, which may be repeated, specifies that
50// an import of "old" should be rewritten to import "new" instead.
51//
52// Example
53//
54// Bundle archive/zip for inclusion in cmd/dist:
55//
56//	cd $GOROOT/src/cmd/dist
57//	bundle -o zip.go archive/zip
58//
59// Bundle golang.org/x/net/http2 for inclusion in net/http,
60// prefixing all identifiers by "http2" instead of "http2_",
61// and rewriting the import "golang.org/x/net/http2/hpack"
62// to "internal/golang.org/x/net/http2/hpack":
63//
64//	cd $GOROOT/src/net/http
65//	bundle -o h2_bundle.go \
66//		-prefix http2 \
67//		-import golang.org/x/net/http2/hpack=internal/golang.org/x/net/http2/hpack \
68//		golang.org/x/net/http2
69//
70// Two ways to update the http2 bundle:
71//
72//	go generate net/http
73//
74//	cd $GOROOT/src/net/http
75//	go generate
76//
77// Update both bundles, restricting ``go generate'' to running bundle commands:
78//
79//	go generate -run bundle cmd/dist net/http
80//
81package main
82
83import (
84	"bytes"
85	"flag"
86	"fmt"
87	"go/ast"
88	"go/build"
89	"go/format"
90	"go/parser"
91	"go/printer"
92	"go/token"
93	"go/types"
94	"io/ioutil"
95	"log"
96	"os"
97	"path"
98	"strconv"
99	"strings"
100
101	"golang.org/x/tools/go/loader"
102)
103
var (
	// Command-line flags. Each usage string below describes its flag;
	// -underscore is a temporary workaround not covered by the package
	// documentation.
	outputFile = flag.String("o", "", "write output to `file` (default standard output)")
	dstPath    = flag.String("dst", "", "set destination import `path` (default taken from current directory)")
	pkgName    = flag.String("pkg", "", "set destination package `name` (default taken from current directory)")
	prefix     = flag.String("prefix", "", "set bundled identifier prefix to `p` (default source package name + \"_\")")
	underscore = flag.Bool("underscore", false, "rewrite golang.org to golang_org in imports; temporary workaround for golang.org/issue/16333")

	// importMap holds the -import rewrites: original import path -> replacement
	// import path. It is filled in by addImportMap.
	importMap = map[string]string{}
)
113
114func init() {
115	flag.Var(flagFunc(addImportMap), "import", "rewrite import using `map`, of form old=new (can be repeated)")
116}
117
118func addImportMap(s string) {
119	if strings.Count(s, "=") != 1 {
120		log.Fatal("-import argument must be of the form old=new")
121	}
122	i := strings.Index(s, "=")
123	old, new := s[:i], s[i+1:]
124	if old == "" || new == "" {
125		log.Fatal("-import argument must be of the form old=new; old and new must be non-empty")
126	}
127	importMap[old] = new
128}
129
// usage writes a one-line synopsis followed by the flag defaults to
// standard error.
func usage() {
	const synopsis = "Usage: bundle [options] <src>\n"
	fmt.Fprint(os.Stderr, synopsis)
	flag.PrintDefaults()
}
134
135func main() {
136	log.SetPrefix("bundle: ")
137	log.SetFlags(0)
138
139	flag.Usage = usage
140	flag.Parse()
141	args := flag.Args()
142	if len(args) != 1 {
143		usage()
144		os.Exit(2)
145	}
146
147	if *dstPath != "" {
148		if *pkgName == "" {
149			*pkgName = path.Base(*dstPath)
150		}
151	} else {
152		wd, _ := os.Getwd()
153		pkg, err := build.ImportDir(wd, 0)
154		if err != nil {
155			log.Fatalf("cannot find package in current directory: %v", err)
156		}
157		*dstPath = pkg.ImportPath
158		if *pkgName == "" {
159			*pkgName = pkg.Name
160		}
161	}
162
163	code, err := bundle(args[0], *dstPath, *pkgName, *prefix)
164	if err != nil {
165		log.Fatal(err)
166	}
167	if *outputFile != "" {
168		err := ioutil.WriteFile(*outputFile, code, 0666)
169		if err != nil {
170			log.Fatal(err)
171		}
172	} else {
173		_, err := os.Stdout.Write(code)
174		if err != nil {
175			log.Fatal(err)
176		}
177	}
178}
179
// isStandardImportPath reports whether path looks like a standard-library
// import path: its first slash-separated element contains no dot.
// It is copied from cmd/go in the standard library.
func isStandardImportPath(path string) bool {
	first := path
	if slash := strings.Index(path, "/"); slash >= 0 {
		first = path[:slash]
	}
	return strings.Index(first, ".") < 0
}
189
190var ctxt = &build.Default
191
192func bundle(src, dst, dstpkg, prefix string) ([]byte, error) {
193	// Load the initial package.
194	conf := loader.Config{ParserMode: parser.ParseComments, Build: ctxt}
195	conf.TypeCheckFuncBodies = func(p string) bool { return p == src }
196	conf.Import(src)
197
198	lprog, err := conf.Load()
199	if err != nil {
200		return nil, err
201	}
202
203	// Because there was a single Import call and Load succeeded,
204	// InitialPackages is guaranteed to hold the sole requested package.
205	info := lprog.InitialPackages()[0]
206	if prefix == "" {
207		pkgName := info.Files[0].Name.Name
208		prefix = pkgName + "_"
209	}
210
211	objsToUpdate := make(map[types.Object]bool)
212	var rename func(from types.Object)
213	rename = func(from types.Object) {
214		if !objsToUpdate[from] {
215			objsToUpdate[from] = true
216
217			// Renaming a type that is used as an embedded field
218			// requires renaming the field too. e.g.
219			// 	type T int // if we rename this to U..
220			// 	var s struct {T}
221			// 	print(s.T) // ...this must change too
222			if _, ok := from.(*types.TypeName); ok {
223				for id, obj := range info.Uses {
224					if obj == from {
225						if field := info.Defs[id]; field != nil {
226							rename(field)
227						}
228					}
229				}
230			}
231		}
232	}
233
234	// Rename each package-level object.
235	scope := info.Pkg.Scope()
236	for _, name := range scope.Names() {
237		rename(scope.Lookup(name))
238	}
239
240	var out bytes.Buffer
241
242	fmt.Fprintf(&out, "// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.\n")
243	if *outputFile != "" {
244		fmt.Fprintf(&out, "//go:generate bundle %s\n", strings.Join(os.Args[1:], " "))
245	} else {
246		fmt.Fprintf(&out, "//   $ bundle %s\n", strings.Join(os.Args[1:], " "))
247	}
248	fmt.Fprintf(&out, "\n")
249
250	// Concatenate package comments from all files...
251	for _, f := range info.Files {
252		if doc := f.Doc.Text(); strings.TrimSpace(doc) != "" {
253			for _, line := range strings.Split(doc, "\n") {
254				fmt.Fprintf(&out, "// %s\n", line)
255			}
256		}
257	}
258	// ...but don't let them become the actual package comment.
259	fmt.Fprintln(&out)
260
261	fmt.Fprintf(&out, "package %s\n\n", dstpkg)
262
263	// BUG(adonovan,shurcooL): bundle may generate incorrect code
264	// due to shadowing between identifiers and imported package names.
265	//
266	// The generated code will either fail to compile or
267	// (unlikely) compile successfully but have different behavior
268	// than the original package. The risk of this happening is higher
269	// when the original package has renamed imports (they're typically
270	// renamed in order to resolve a shadow inside that particular .go file).
271
272	// TODO(adonovan,shurcooL):
273	// - detect shadowing issues, and either return error or resolve them
274	// - preserve comments from the original import declarations.
275
276	// pkgStd and pkgExt are sets of printed import specs. This is done
277	// to deduplicate instances of the same import name and path.
278	var pkgStd = make(map[string]bool)
279	var pkgExt = make(map[string]bool)
280	for _, f := range info.Files {
281		for _, imp := range f.Imports {
282			path, err := strconv.Unquote(imp.Path.Value)
283			if err != nil {
284				log.Fatalf("invalid import path string: %v", err) // Shouldn't happen here since conf.Load succeeded.
285			}
286			if path == dst {
287				continue
288			}
289			if newPath, ok := importMap[path]; ok {
290				path = newPath
291			}
292
293			var name string
294			if imp.Name != nil {
295				name = imp.Name.Name
296			}
297			spec := fmt.Sprintf("%s %q", name, path)
298			if isStandardImportPath(path) {
299				pkgStd[spec] = true
300			} else {
301				if *underscore {
302					spec = strings.Replace(spec, "golang.org/", "golang_org/", 1)
303				}
304				pkgExt[spec] = true
305			}
306		}
307	}
308
309	// Print a single declaration that imports all necessary packages.
310	fmt.Fprintln(&out, "import (")
311	for p := range pkgStd {
312		fmt.Fprintf(&out, "\t%s\n", p)
313	}
314	if len(pkgExt) > 0 {
315		fmt.Fprintln(&out)
316	}
317	for p := range pkgExt {
318		fmt.Fprintf(&out, "\t%s\n", p)
319	}
320	fmt.Fprint(&out, ")\n\n")
321
322	// Modify and print each file.
323	for _, f := range info.Files {
324		// Update renamed identifiers.
325		for id, obj := range info.Defs {
326			if objsToUpdate[obj] {
327				id.Name = prefix + obj.Name()
328			}
329		}
330		for id, obj := range info.Uses {
331			if objsToUpdate[obj] {
332				id.Name = prefix + obj.Name()
333			}
334		}
335
336		// For each qualified identifier that refers to the
337		// destination package, remove the qualifier.
338		// The "@@@." strings are removed in postprocessing.
339		ast.Inspect(f, func(n ast.Node) bool {
340			if sel, ok := n.(*ast.SelectorExpr); ok {
341				if id, ok := sel.X.(*ast.Ident); ok {
342					if obj, ok := info.Uses[id].(*types.PkgName); ok {
343						if obj.Imported().Path() == dst {
344							id.Name = "@@@"
345						}
346					}
347				}
348			}
349			return true
350		})
351
352		last := f.Package
353		if len(f.Imports) > 0 {
354			imp := f.Imports[len(f.Imports)-1]
355			last = imp.End()
356			if imp.Comment != nil {
357				if e := imp.Comment.End(); e > last {
358					last = e
359				}
360			}
361		}
362
363		// Pretty-print package-level declarations.
364		// but no package or import declarations.
365		var buf bytes.Buffer
366		for _, decl := range f.Decls {
367			if decl, ok := decl.(*ast.GenDecl); ok && decl.Tok == token.IMPORT {
368				continue
369			}
370
371			beg, end := sourceRange(decl)
372
373			printComments(&out, f.Comments, last, beg)
374
375			buf.Reset()
376			format.Node(&buf, lprog.Fset, &printer.CommentedNode{Node: decl, Comments: f.Comments})
377			// Remove each "@@@." in the output.
378			// TODO(adonovan): not hygienic.
379			out.Write(bytes.Replace(buf.Bytes(), []byte("@@@."), nil, -1))
380
381			last = printSameLineComment(&out, f.Comments, lprog.Fset, end)
382
383			out.WriteString("\n\n")
384		}
385
386		printLastComments(&out, f.Comments, last)
387	}
388
389	// Now format the entire thing.
390	result, err := format.Source(out.Bytes())
391	if err != nil {
392		log.Fatalf("formatting failed: %v", err)
393	}
394
395	return result, nil
396}
397
// sourceRange reports the half-open interval [beg, end) of source text
// covered by decl. The interval is widened to start at the declaration's
// doc comment, if any. For a GenDecl it is also widened to end after the
// trailing line comment of its final value or type spec, when that comment
// extends past the declaration.
func sourceRange(decl ast.Decl) (beg, end token.Pos) {
	var lead, trail *ast.CommentGroup
	if fn, ok := decl.(*ast.FuncDecl); ok {
		lead = fn.Doc
	} else if gen, ok := decl.(*ast.GenDecl); ok {
		lead = gen.Doc
		if n := len(gen.Specs); n > 0 {
			final := gen.Specs[n-1]
			if vs, ok := final.(*ast.ValueSpec); ok {
				trail = vs.Comment
			} else if ts, ok := final.(*ast.TypeSpec); ok {
				trail = ts.Comment
			}
		}
	}

	start, stop := decl.Pos(), decl.End()
	if lead != nil {
		start = lead.Pos()
	}
	if trail != nil && trail.End() > stop {
		stop = trail.End()
	}
	return start, stop
}
430
// printComments copies to out every comment group in comments whose first
// comment begins within the half-open range [pos, end). Each comment's raw
// text goes on its own line, and every copied group is followed by a blank
// line.
func printComments(out *bytes.Buffer, comments []*ast.CommentGroup, pos, end token.Pos) {
	for i := range comments {
		group := comments[i]
		start := group.Pos()
		if start < pos || start >= end {
			continue
		}
		for _, c := range group.List {
			out.WriteString(c.Text)
			out.WriteByte('\n')
		}
		out.WriteByte('\n')
	}
}
441
442const infinity = 1 << 30
443
444func printLastComments(out *bytes.Buffer, comments []*ast.CommentGroup, pos token.Pos) {
445	printComments(out, comments, pos, infinity)
446}
447
// printSameLineComment searches comments for the first group that starts at
// or after pos on the same source line as pos. If one is found, its comments
// are written to out, one per line, and the group's end position is
// returned. Otherwise out is untouched and pos is returned.
//
// fset must contain the file holding pos. The comments are presumably from
// that same file; confirm at the call site.
func printSameLineComment(out *bytes.Buffer, comments []*ast.CommentGroup, fset *token.FileSet, pos token.Pos) token.Pos {
	file := fset.File(pos)
	for _, group := range comments {
		start := group.Pos()
		if start < pos {
			continue
		}
		if file.Line(start) != file.Line(pos) {
			continue
		}
		for _, c := range group.List {
			out.WriteString(c.Text)
			out.WriteByte('\n')
		}
		return group.End()
	}
	return pos
}
460
// flagFunc adapts a plain func(string) to the flag.Value interface, so that
// every occurrence of a flag invokes the function with its argument.
type flagFunc func(string)

// Set forwards value to the wrapped function and never reports an error.
func (fn flagFunc) Set(value string) error {
	fn(value)
	return nil
}

// String returns the empty string; a flagFunc has no printable default.
func (fn flagFunc) String() string {
	return ""
}
469