1package main
2
import (
	"archive/zip"
	"errors"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path/filepath"
	"sort"
	"strings"
)
14
// Exit statuses handed to os.Exit: 0 when help was requested, 1 on any error.
const (
	exitSuccess = iota
	exitFailure
)
19
var (
	// Settings filled in from the command line by main.
	zipPath   string   // output archive path, passed through maybeAddExt
	paths     []string // files and directories to add, in command-line order
	prefix    string   // -prefix: path joined in front of every entry name
	recursive bool     // -recursive: descend into directories
	stripPath bool     // -strip-path: store only the base name of each path

	name string        // program name (os.Args[0]) shown in the usage text
	f    *flag.FlagSet // command-line flags; usage prints their defaults

	// errorNotRecursive is returned by the walk callback when it reaches a
	// directory without -recursive. It only stops that walk and is filtered
	// out by doCompress, but it carries a real message so that it is never
	// shown to a user as a blank line if it escapes.
	errorNotRecursive = errors.New("not descending into directory without -recursive")
)
32
33func usage(code int) {
34	fmt.Fprintln(os.Stderr, "Usage:")
35	fmt.Fprintf(os.Stderr, "  %s [options] zipfile [file ...]\n", name)
36	fmt.Fprintln(os.Stderr, "Options:")
37	f.SetOutput(os.Stderr)
38	f.PrintDefaults()
39
40	os.Exit(code)
41}
42
// maybeAddExt returns name with ".zip" appended unless its extension already
// is ".zip". The comparison ignores case, so "ARCHIVE.ZIP" is kept as is
// instead of becoming "ARCHIVE.ZIP.zip". Only the final path element is
// examined (filepath.Ext), so a ".zip" directory component does not count.
func maybeAddExt(name string) string {
	if strings.EqualFold(filepath.Ext(name), ".zip") {
		return name
	}
	return name + ".zip"
}
50
51func fatal(format string, args ...interface{}) {
52	fmt.Fprintf(os.Stderr, format+"\n", args...)
53	os.Exit(exitFailure)
54}
55
// zipifyPath turns a filesystem path into a zip entry name.
//
// The path is cleaned, stripped of any volume name (e.g. "C:" on Windows),
// converted to forward slashes and made relative by removing leading slashes
// and leading ".." components. Zip entry names are relative and
// slash-separated; a name starting with "/" or "../" would let an extractor
// write outside its destination directory. Root-like inputs produce "" (for
// "/" or "..") or "." (for "" or "."), which callers can test for.
func zipifyPath(path string) string {
	path = filepath.Clean(path)
	path = path[len(filepath.VolumeName(path)):]

	// Convert to slashes before trimming: on Windows Clean yields
	// backslashes, and trimming "/" first left a leading "\" that ToSlash
	// then turned into an absolute entry name.
	path = strings.TrimLeft(filepath.ToSlash(path), "/")

	// After Clean, ".." can only appear as leading components.
	for path == ".." || strings.HasPrefix(path, "../") {
		path = strings.TrimLeft(strings.TrimPrefix(path, ".."), "/")
	}
	return path
}
65
// walkFn is called by walk once for every path it visits, with the path, an
// open handle and its os.FileInfo (symlinks followed). The handle is open for
// reading only for regular files and directories; for any other kind of file
// (socket, named pipe, device, ...) it is nil, because opening such files can
// block or fail. A non-nil error stops the walk and is returned by walk.
type walkFn func(string, *os.File, os.FileInfo) error

// walk calls fn for root and, if root is a directory whose callback returned
// nil, for everything beneath it, depth-first with entries in lexical order.
// The first error from fn or from the filesystem is returned unchanged.
//
// NOTE(review): symlinks are followed and nothing here detects cycles, so a
// link pointing at an ancestor directory is re-entered.
func walk(root string, fn walkFn) error {
	info, err := os.Stat(root)
	if err != nil {
		return err
	}

	children, err := visit(root, info, fn)
	if err != nil {
		return err
	}

	for _, child := range children {
		if err := walk(filepath.Join(root, child), fn); err != nil {
			return err
		}
	}
	return nil
}

// visit passes a single path to fn and, for directories, returns the sorted
// names of their entries. The handle is closed before visit returns, so the
// recursion in walk does not hold one descriptor per directory level.
func visit(path string, info os.FileInfo, fn walkFn) ([]string, error) {
	if !info.Mode().IsRegular() && !info.IsDir() {
		// Opening a named pipe blocks until a writer appears and opening a
		// socket fails, so special files are reported without a handle.
		return nil, fn(path, nil, info)
	}

	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	if err := fn(path, file, info); err != nil {
		return nil, err
	}
	if !info.IsDir() {
		return nil, nil
	}

	names, err := file.Readdirnames(0)
	if err != nil {
		return nil, err
	}
	// Readdirnames yields directory order; sorting makes archives reproducible.
	sort.Strings(names)
	return names, nil
}
101
// isRegular reports whether info describes a plain file: none of the
// os.ModeType bits (directory, symlink, pipe, socket, device, ...) are set.
func isRegular(info os.FileInfo) bool {
	return info.Mode().IsRegular()
}
105
106func compress() {
107	zipFile, err := os.Create(zipPath)
108	if err != nil {
109		fatal("Couldn't create output file: %s", err.Error())
110	}
111
112	defer zipFile.Close()
113
114	err = doCompress(zipFile)
115	if err != nil {
116		fatal("%s", err.Error())
117	}
118}
119
120func doCompress(zipFile *os.File) error {
121	zipInfo, err := zipFile.Stat()
122	if err != nil {
123		return err
124	}
125
126	zipWriter := zip.NewWriter(zipFile)
127	defer zipWriter.Close()
128
129	fn := func(path string, f *os.File, info os.FileInfo) error {
130		if os.SameFile(zipInfo, info) || !(isRegular(info) || info.IsDir()) {
131			fmt.Fprintf(os.Stderr, "skipping %s\n", path)
132			return nil
133		}
134
135		zippedPath := path
136
137		if stripPath {
138			zippedPath = filepath.Base(zippedPath)
139		}
140
141		if prefix != "" {
142			zippedPath = filepath.Join(prefix, zippedPath)
143		}
144		zippedPath = zipifyPath(zippedPath)
145
146		var w io.Writer
147
148		if (zippedPath != "" && zippedPath != ".") || !info.IsDir() {
149			if info.IsDir() {
150				zippedPath += "/"
151			}
152
153			fmt.Fprintf(os.Stderr, "adding: %s -> %s\n", path, zippedPath)
154
155			header, err := zip.FileInfoHeader(info)
156			if err != nil {
157				return err
158			}
159			header.Name = zippedPath
160			header.Method = zip.Deflate
161
162			w, err = zipWriter.CreateHeader(header)
163			if err != nil {
164				return err
165			}
166
167		}
168
169		if info.IsDir() {
170			if recursive {
171				return nil
172			} else {
173				return errorNotRecursive
174			}
175		}
176
177		_, err = io.Copy(w, f)
178		if err != nil {
179			return fmt.Errorf("failed to copy %s: %s", path, err.Error())
180		}
181
182		return nil
183	}
184
185	for _, path := range paths {
186		err = walk(path, fn)
187		if err != nil && err != errorNotRecursive {
188			return err
189		}
190	}
191
192	return nil
193}
194
195func main() {
196	name = os.Args[0]
197
198	f = flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
199	f.SetOutput(ioutil.Discard)
200
201	f.BoolVar(&recursive, "recursive", false, "scan directories recursively")
202	f.BoolVar(&stripPath, "strip-path", false, "store just file names")
203	f.StringVar(&prefix, "prefix", "", "prepend each path with prefix")
204
205	err := f.Parse(os.Args[1:])
206	if err == nil {
207		switch {
208		case recursive && stripPath:
209			err = fmt.Errorf("-recursive and -strip-path are mutually exclusive")
210		case f.NArg() == 0:
211			err = fmt.Errorf("output zip file is not specified")
212		default:
213			zipPath = maybeAddExt(f.Arg(0))
214			paths = f.Args()[1:]
215		}
216	}
217
218	if err != nil {
219		switch err {
220		case flag.ErrHelp:
221			usage(exitSuccess)
222		default:
223			fmt.Fprintln(os.Stderr, err)
224			usage(exitFailure)
225		}
226	}
227
228	compress()
229}
230