1package common
2
3import "errors"
4import "expvar"
5import "fmt"
6import "io"
7import "io/ioutil"
8import "path/filepath"
9import "net"
10import "net/url"
11import "os"
12import "strconv"
13import "strings"
14import "net/http"
15import "net/http/pprof"
16import "runtime"
17import "hash/crc64"
18import "reflect"
19import "unsafe"
20import "regexp"
21import "time"
22import "math/big"
23
24import "github.com/couchbase/cbauth"
25import "github.com/couchbase/cbauth/cbauthimpl"
26import "github.com/couchbase/indexing/secondary/dcp"
27import "github.com/couchbase/indexing/secondary/dcp/transport/client"
28import "github.com/couchbase/indexing/secondary/logging"
29import "github.com/couchbase/indexing/secondary/security"
30
// IndexNamePattern is the regular expression that IsValidIndexName matches
// index names against.
const IndexNamePattern = "^[A-Za-z0-9#_-]+$"

const (
	// MAX_AUTH_RETRIES is passed as the first argument to NewRetryHelper by
	// the CbAuthHandler credential lookups.
	MAX_AUTH_RETRIES = 10
)

// ErrInvalidIndexName is returned by IsValidIndexName when a name does not
// match IndexNamePattern.
var ErrInvalidIndexName = fmt.Errorf("Invalid index name")
38
// ExcludeStrings returns the elements of `strs` that do not occur in
// `excludes`, in the same relative order they have in `strs`. The result is
// always a freshly allocated, non-nil slice.
func ExcludeStrings(strs []string, excludes []string) []string {
	skip := make(map[string]bool)
	for i := range excludes {
		skip[excludes[i]] = true
	}

	kept := make([]string, 0, len(strs))
	for _, candidate := range strs {
		if _, found := skip[candidate]; found {
			continue
		}
		kept = append(kept, candidate)
	}
	return kept
}
54
// CommonStrings returns the elements of `ys` that also occur in `xs`. The
// result follows the order of `ys`, and an element repeated in `ys` is
// repeated in the result.
func CommonStrings(xs []string, ys []string) []string {
	seen := make(map[string]bool)
	for i := 0; i < len(xs); i++ {
		seen[xs[i]] = true
	}

	shared := make([]string, 0, len(xs))
	for _, candidate := range ys {
		if _, hit := seen[candidate]; !hit {
			continue
		}
		shared = append(shared, candidate)
	}
	return shared
}
69
// HasString reports whether `str` is an element of `strs`.
func HasString(str string, strs []string) bool {
	for i := 0; i < len(strs); i++ {
		if strs[i] == str {
			return true
		}
	}
	return false
}
79
// ExcludeUint32 returns the items of `from` that are not present in `xs`,
// preserving the order (and any duplicates) they have in `from`. The result
// is always a freshly allocated, non-nil slice.
func ExcludeUint32(xs []uint32, from []uint32) []uint32 {
	// Index `xs` once so each membership check is a map lookup instead of a
	// linear scan over `xs` for every element of `from`.
	drop := make(map[uint32]struct{}, len(xs))
	for _, x := range xs {
		drop[x] = struct{}{}
	}

	remaining := make([]uint32, 0, len(from))
	for _, num := range from {
		if _, found := drop[num]; !found {
			remaining = append(remaining, num)
		}
	}
	return remaining
}
90
// ExcludeUint64 returns the items of `from` that are not present in `xs`,
// preserving the order (and any duplicates) they have in `from`. The result
// is always a freshly allocated, non-nil slice.
func ExcludeUint64(xs []uint64, from []uint64) []uint64 {
	// Index `xs` once so each membership check is a map lookup instead of a
	// linear scan over `xs` for every element of `from`.
	drop := make(map[uint64]struct{}, len(xs))
	for _, x := range xs {
		drop[x] = struct{}{}
	}

	remaining := make([]uint64, 0, len(from))
	for _, num := range from {
		if _, found := drop[num]; !found {
			remaining = append(remaining, num)
		}
	}
	return remaining
}
101
// RemoveUint32 returns a new slice holding the elements of `xs` with every
// occurrence of `item` left out. `xs` itself is not modified.
func RemoveUint32(item uint32, xs []uint32) []uint32 {
	kept := make([]uint32, 0, len(xs))
	for i := range xs {
		if xs[i] != item {
			kept = append(kept, xs[i])
		}
	}
	return kept
}
113
// RemoveUint16 returns a new slice holding the elements of `xs` with every
// occurrence of `item` left out. `xs` itself is not modified.
func RemoveUint16(item uint16, xs []uint16) []uint16 {
	kept := make([]uint16, 0, len(xs))
	for i := range xs {
		if xs[i] != item {
			kept = append(kept, xs[i])
		}
	}
	return kept
}
125
// RemoveString returns a new slice holding the elements of `xs` with every
// occurrence of `item` left out. `xs` itself is not modified.
func RemoveString(item string, xs []string) []string {
	kept := make([]string, 0, len(xs))
	for i := range xs {
		if xs[i] != item {
			kept = append(kept, xs[i])
		}
	}
	return kept
}
137
// HasUint32 reports whether `item` is an element of `xs`.
func HasUint32(item uint32, xs []uint32) bool {
	for i := 0; i < len(xs); i++ {
		if xs[i] == item {
			return true
		}
	}
	return false
}
147
// HasUint64 reports whether the uint64 `item` is an element of `xs`.
func HasUint64(item uint64, xs []uint64) bool {
	for i := 0; i < len(xs); i++ {
		if xs[i] == item {
			return true
		}
	}
	return false
}
157
158// FailsafeOp can be used by gen-server implementors to avoid infinitely
159// blocked API calls.
160func FailsafeOp(
161	reqch, respch chan []interface{},
162	cmd []interface{},
163	finch chan bool) ([]interface{}, error) {
164
165	select {
166	case reqch <- cmd:
167		if respch != nil {
168			select {
169			case resp := <-respch:
170				return resp, nil
171			case <-finch:
172				return nil, ErrorClosed
173			}
174		}
175	case <-finch:
176		return nil, ErrorClosed
177	}
178	return nil, nil
179}
180
181// FailsafeOpAsync is same as FailsafeOp that can be used for
182// asynchronous operation, that is, caller does not wait for response.
183func FailsafeOpAsync(
184	reqch chan []interface{}, cmd []interface{}, finch chan bool) error {
185
186	select {
187	case reqch <- cmd:
188	case <-finch:
189		return ErrorClosed
190	}
191	return nil
192}
193
194// FailsafeOpNoblock is same as FailsafeOpAsync that can be used for
195// non-blocking operation, that is, if `reqch` is full caller does not block.
196func FailsafeOpNoblock(
197	reqch chan []interface{}, cmd []interface{}, finch chan bool) error {
198
199	select {
200	case reqch <- cmd:
201	case <-finch:
202		return ErrorClosed
203	default:
204		return ErrorChannelFull
205	}
206	return nil
207}
208
// OpError supplements FailsafeOp for gen-servers by picking the error to
// report for a call. A non-nil `err` wins. Otherwise, when `vals` is non-nil,
// the value at `vals[idx]` is returned as an error (nil if that slot is nil).
// The slot must hold an error and `idx` must be in range, or the call panics.
func OpError(err error, vals []interface{}, idx int) error {
	switch {
	case err != nil:
		return err
	case vals == nil:
		return nil
	case vals[idx] == nil:
		return nil
	}
	return vals[idx].(error)
}
222
// CbAuthHandler is the cbauth admin authentication helper.
// Uses default cbauth env variables internally to provide auth creds.
type CbAuthHandler struct {
	// Hostport is the address GetCredentials passes to
	// cbauth.GetHTTPServiceAuth.
	Hostport string
	// Bucket is the bucket AuthenticateMemcachedConn selects on the
	// connection.
	Bucket   string
}
229
230func (ah *CbAuthHandler) GetCredentials() (string, string) {
231
232	var u, p string
233
234	fn := func(r int, err error) error {
235		if r > 0 {
236			logging.Warnf("CbAuthHandler::GetCredentials error=%v Retrying (%d)", err, r)
237		}
238
239		u, p, err = cbauth.GetHTTPServiceAuth(ah.Hostport)
240		return err
241	}
242
243	rh := NewRetryHelper(MAX_AUTH_RETRIES, time.Second, 2, fn)
244	err := rh.Run()
245	if err != nil {
246		panic(err)
247	}
248
249	return u, p
250}
251
252func (ah *CbAuthHandler) AuthenticateMemcachedConn(host string, conn *memcached.Client) error {
253
254	var u, p string
255
256	fn := func(r int, err error) error {
257		if r > 0 {
258			logging.Warnf("CbAuthHandler::AuthenticateMemcachedConn error=%v Retrying (%d)", err, r)
259		}
260
261		u, p, err = cbauth.GetMemcachedServiceAuth(host)
262		return err
263	}
264
265	rh := NewRetryHelper(MAX_AUTH_RETRIES, time.Second*3, 1, fn)
266	err := rh.Run()
267	if err != nil {
268		return err
269	}
270
271	_, err = conn.Auth(u, p)
272	_, err = conn.SelectBucket(ah.Bucket)
273	return err
274}
275
276// GetKVAddrs gather the list of kvnode-address based on the latest vbmap.
277func GetKVAddrs(cluster, pooln, bucketn string) ([]string, error) {
278	b, err := ConnectBucket(cluster, pooln, bucketn)
279	if err != nil {
280		return nil, err
281	}
282	defer b.Close()
283
284	if err := b.Refresh(); err != nil {
285		logging.Errorf("GetKVAddrs, error during bucket.Refresh() for bucket: %v, err: %v", b.Name, err)
286		return nil, err
287	}
288
289	m, err := b.GetVBmap(nil)
290	if err != nil {
291		return nil, err
292	}
293
294	kvaddrs := make([]string, 0, len(m))
295	for kvaddr := range m {
296		kvaddrs = append(kvaddrs, kvaddr)
297	}
298	return kvaddrs, nil
299}
300
301// IsIPLocal return whether `ip` address is loopback address or
302// compares equal with local-IP-address.
303func IsIPLocal(ip string) bool {
304	netIP := net.ParseIP(ip)
305
306	// if loopback address, return true
307	if netIP.IsLoopback() {
308		return true
309	}
310
311	// compare with the local ip
312	if localIP, err := GetLocalIP(); err == nil {
313		if localIP.Equal(netIP) {
314			return true
315		}
316	}
317	return false
318}
319
// GetLocalIP walks the network interfaces in the order net.Interfaces
// reports them and returns the first non-loopback IPv4 address found on an
// interface that is up and is not a loopback interface. It returns an error
// if the interfaces or their addresses cannot be listed, or if no such
// address exists.
func GetLocalIP() (net.IP, error) {
	ifaces, err := net.Interfaces()
	if err != nil {
		return nil, err
	}

	for _, iface := range ifaces {
		isUp := iface.Flags&net.FlagUp != 0
		isLoopback := iface.Flags&net.FlagLoopback != 0
		if !isUp || isLoopback {
			continue
		}

		addrs, err := iface.Addrs()
		if err != nil {
			return nil, err
		}

		for _, addr := range addrs {
			var candidate net.IP
			if ipnet, ok := addr.(*net.IPNet); ok {
				candidate = ipnet.IP
			} else if ipaddr, ok := addr.(*net.IPAddr); ok {
				candidate = ipaddr.IP
			}

			if candidate == nil || candidate.IsLoopback() {
				continue
			}
			// To4 returns nil for addresses that are not IPv4.
			if v4 := candidate.To4(); v4 != nil {
				return v4, nil
			}
		}
	}
	return nil, errors.New("cannot find local IP address")
}
355
// ExitOnStdinClose is the exit handler to be used with ns-server. It keeps
// reading from stdin; on io.EOF it sleeps one second and exits the process
// with status 0. Any other read error causes a panic.
func ExitOnStdinClose() {
	scratch := make([]byte, 4)
	for {
		_, err := os.Stdin.Read(scratch)
		switch {
		case err == nil:
			continue
		case err == io.EOF:
			time.Sleep(1 * time.Second)
			os.Exit(0)
		}

		panic(fmt.Sprintf("Stdin: Unexpected error occured %v", err))
	}
}
371
372// GetColocatedHost find the server addr for localhost and return the same.
373func GetColocatedHost(cluster string) (string, error) {
374	// get vbmap from bucket connection.
375	bucket, err := ConnectBucket(cluster, "default", "default")
376	if err != nil {
377		return "", err
378	}
379	defer bucket.Close()
380
381	hostports := bucket.NodeAddresses()
382	serversM := make(map[string]string)
383	servers := make([]string, 0)
384	for _, hostport := range hostports {
385		host, _, err := net.SplitHostPort(hostport)
386		if err != nil {
387			return "", err
388		}
389		serversM[host] = hostport
390		servers = append(servers, host)
391	}
392
393	for _, server := range servers {
394		addrs, err := net.LookupIP(server)
395		if err != nil {
396			return "", err
397		}
398		for _, addr := range addrs {
399			if IsIPLocal(addr.String()) {
400				return serversM[server], nil
401			}
402		}
403	}
404	return "", errors.New("unknown host")
405}
406
// CrashOnError panics with `err` when it is non-nil and does nothing
// otherwise.
func CrashOnError(err error) {
	if err == nil {
		return
	}
	panic(err)
}
412
413func ClusterAuthUrl(cluster string) (string, error) {
414
415	if strings.HasPrefix(cluster, "http") {
416		u, err := url.Parse(cluster)
417		if err != nil {
418			return "", err
419		}
420		cluster = u.Host
421	}
422
423	adminUser, adminPasswd, err := cbauth.GetHTTPServiceAuth(cluster)
424	if err != nil {
425		return "", err
426	}
427
428	clusterUrl := url.URL{
429		Scheme: "http",
430		Host:   cluster,
431		User:   url.UserPassword(adminUser, adminPasswd),
432	}
433
434	return clusterUrl.String(), nil
435}
436
// ClusterUrl returns "http://<host>" for `cluster`. When `cluster` starts
// with "http" it is parsed as a URL and only its host part is kept (user
// info, path and the original scheme are dropped); otherwise `cluster` is
// used as the host verbatim. It panics if that parse fails.
func ClusterUrl(cluster string) string {
	target := url.URL{Scheme: "http", Host: cluster}

	if strings.HasPrefix(cluster, "http") {
		parsed, err := url.Parse(cluster)
		if err != nil {
			panic(err) // TODO: should we panic ?
		}
		target.Host = parsed.Host
	}

	return target.String()
}
453
// MaybeSetEnv returns the current value of the environment variable `key`
// when it is non-empty. Otherwise it sets the variable to `value` and
// returns `value`.
func MaybeSetEnv(key, value string) string {
	current := os.Getenv(key)
	if current == "" {
		os.Setenv(key, value)
		return value
	}
	return current
}
461
462func EquivalentIP(
463	raddr string,
464	raddrs []string) (this string, other string, err error) {
465
466	host, port, err := net.SplitHostPort(raddr)
467	if err != nil {
468		return "", "", err
469	}
470
471	if host == "localhost" {
472		host = GetLocalIpAddr(IsIpv6())
473	}
474
475	netIP := net.ParseIP(host)
476
477	for _, raddr1 := range raddrs {
478		host1, port1, err := net.SplitHostPort(raddr1)
479		if err != nil {
480			return "", "", err
481		}
482
483		if host1 == "localhost" {
484			host1 = GetLocalIpAddr(IsIpv6())
485		}
486		netIP1 := net.ParseIP(host1)
487		// check whether ports are same.
488		if port != port1 {
489			continue
490		}
491		// check whether both are local-ip.
492		if IsIPLocal(host) && IsIPLocal(host1) {
493			return net.JoinHostPort(host, port),
494				net.JoinHostPort(host1, port), nil // raddr => raddr1
495		}
496		// check whether they are coming from the same remote.
497		if netIP.Equal(netIP1) {
498			return net.JoinHostPort(host, port),
499				net.JoinHostPort(host1, port1), nil // raddr == raddr1
500		}
501	}
502	return net.JoinHostPort(host, port),
503		net.JoinHostPort(host, port), nil
504}
505
506//---------------------
507// SDK bucket operation
508//---------------------
509
510// ConnectBucket will instantiate a couchbase-bucket instance with cluster.
511// caller's responsibility to close the bucket.
512func ConnectBucket(cluster, pooln, bucketn string) (*couchbase.Bucket, error) {
513	if strings.HasPrefix(cluster, "http") {
514		u, err := url.Parse(cluster)
515		if err != nil {
516			return nil, err
517		}
518		cluster = u.Host
519	}
520
521	ah := &CbAuthHandler{
522		Hostport: cluster,
523		Bucket:   bucketn,
524	}
525
526	couch, err := couchbase.ConnectWithAuth("http://"+cluster, ah)
527	if err != nil {
528		return nil, err
529	}
530	pool, err := couch.GetPool(pooln)
531	if err != nil {
532		return nil, err
533	}
534	bucket, err := pool.GetBucket(bucketn)
535	if err != nil {
536		return nil, err
537	}
538	return bucket, err
539}
540
541// MaxVbuckets return the number of vbuckets in bucket.
542func MaxVbuckets(bucket *couchbase.Bucket) (int, error) {
543	count := 0
544	m, err := bucket.GetVBmap(nil)
545	if err == nil {
546		for _, vbnos := range m {
547			count += len(vbnos)
548		}
549	}
550	return count, err
551}
552
553// BucketTs return bucket timestamp for all vbucket.
554func BucketTs(bucket *couchbase.Bucket, maxvb int) (seqnos, vbuuids []uint64, err error) {
555	seqnos = make([]uint64, maxvb)
556	vbuuids = make([]uint64, maxvb)
557	stats, err := bucket.GetStats("vbucket-details")
558	// for all nodes in cluster
559	for _, nodestat := range stats {
560		// for all vbuckets
561		for i := 0; i < maxvb; i++ {
562			vbno_str := strconv.Itoa(i)
563			vbstatekey := "vb_" + vbno_str
564			vbhseqkey := "vb_" + vbno_str + ":high_seqno"
565			vbuuidkey := "vb_" + vbno_str + ":uuid"
566			vbstate, ok := nodestat[vbstatekey]
567			highseqno_s, hseq_ok := nodestat[vbhseqkey]
568			vbuuid_s, uuid_ok := nodestat[vbuuidkey]
569			if ok && hseq_ok && uuid_ok && vbstate == "active" {
570				if uuid, err := strconv.ParseUint(vbuuid_s, 10, 64); err == nil {
571					vbuuids[i] = uuid
572				}
573				if s, err := strconv.ParseUint(highseqno_s, 10, 64); err == nil {
574					if s > seqnos[i] {
575						seqnos[i] = s
576					}
577				}
578			}
579		}
580	}
581	return seqnos, vbuuids, err
582}
583
584func IsAuthValid(r *http.Request) (cbauth.Creds, bool, error) {
585
586	creds, err := cbauth.AuthWebCreds(r)
587	if err != nil {
588		if strings.Contains(err.Error(), cbauthimpl.ErrNoAuth.Error()) {
589			return nil, false, nil
590		}
591		return nil, false, err
592	}
593
594	return creds, true, nil
595}
596
// SetNumCPUs calls runtime.GOMAXPROCS with percent/100 (integer division),
// substituting runtime.NumCPU() when that quotient is zero, and returns the
// value it passed.
func SetNumCPUs(percent int) int {
	wanted := percent / 100
	if wanted == 0 {
		wanted = runtime.NumCPU()
	}
	runtime.GOMAXPROCS(wanted)
	return wanted
}
605
// IndexStatement renders index definition `def` as a CREATE INDEX (or
// CREATE PRIMARY INDEX) statement string.
//
// The statement is assembled from the index name and bucket, the secondary
// expressions (each suffixed with " DESC" where def.Desc marks it), an
// optional PARTITION BY hash(...) clause built from def.PartitionKeys, an
// optional WHERE clause (secondary indexes only), and an optional WITH {...}
// clause.
//
// Parameters:
//   - numPartitions: emitted as "num_partition" when
//     IsPartitioned(def.PartitionScheme) is true.
//   - numReplica: emitted as "num_replica" when non-zero; -1 means "use
//     def.GetNumReplica()".
//   - printNodes: when true and def.Nodes is non-empty, the node list is
//     emitted as "nodes".
func IndexStatement(def IndexDefn, numPartitions int, numReplica int, printNodes bool) string {
	var stmt string
	// Format templates for the individual clauses.
	primCreate := "CREATE PRIMARY INDEX `%s` ON `%s`"
	secCreate := "CREATE INDEX `%s` ON `%s`(%s)"
	where := " WHERE %s"
	partition := " PARTITION BY hash(%s)"

	// getPartnStmt renders the PARTITION BY clause from def.PartitionKeys,
	// or returns "" when there are no partition keys.
	getPartnStmt := func() string {
		if len(def.PartitionKeys) > 0 {
			exprs := ""
			for _, exp := range def.PartitionKeys {
				// A separator is added only once something has been written.
				if exprs != "" {
					exprs += ","
				}
				exprs += exp
			}
			return fmt.Sprintf(partition, exprs)
		}
		return ""
	}

	if def.IsPrimary {
		stmt = fmt.Sprintf(primCreate, def.Name, def.Bucket)

		stmt += getPartnStmt()

	} else {
		// Comma-join the secondary expressions, marking descending ones.
		exprs := ""
		for i, exp := range def.SecExprs {
			if exprs != "" {
				exprs += ","
			}
			exprs += exp
			if def.Desc != nil && def.Desc[i] {
				exprs += " DESC"
			}
		}
		stmt = fmt.Sprintf(secCreate, def.Name, def.Bucket, exprs)
		stmt += getPartnStmt()

		if def.WhereExpr != "" {
			stmt += fmt.Sprintf(where, def.WhereExpr)
		}
	}

	// withExpr accumulates the comma-separated entries of the WITH clause;
	// each section below adds a leading comma if an entry already exists.
	withExpr := ""
	/*
		if def.Immutable {
			withExpr += "\"immutable\":true"
		}
	*/

	if def.Deferred {
		if len(withExpr) != 0 {
			withExpr += ","
		}

		withExpr += " \"defer_build\":true"
	}

	if def.RetainDeletedXATTR {
		if len(withExpr) != 0 {
			withExpr += ","
		}

		withExpr += " \"retain_deleted_xattr\":true"
	}

	if printNodes && len(def.Nodes) != 0 {
		if len(withExpr) != 0 {
			withExpr += ","
		}
		withExpr += " \"nodes\":[ "

		for i, node := range def.Nodes {
			withExpr += "\"" + node + "\""
			if i < len(def.Nodes)-1 {
				withExpr += ","
			}
		}

		withExpr += " ]"
	}

	// -1 asks for the replica count stored in the definition.
	if numReplica == -1 {
		numReplica = def.GetNumReplica()
	}

	if numReplica != 0 {
		if len(withExpr) != 0 {
			withExpr += ","
		}

		withExpr += fmt.Sprintf(" \"num_replica\":%v", numReplica)
	}

	if IsPartitioned(def.PartitionScheme) {
		if len(withExpr) != 0 {
			withExpr += ","
		}

		withExpr += fmt.Sprintf(" \"num_partition\":%v", numPartitions)
	}

	// The WITH clause is omitted entirely when no option was emitted.
	if len(withExpr) != 0 {
		stmt += fmt.Sprintf(" WITH { %s }", withExpr)
	}

	return stmt
}
716
// LogRuntime returns a one-line description of the Go runtime: architecture,
// OS, number of CPUs, the current GOMAXPROCS setting and the Go version.
func LogRuntime() string {
	return fmt.Sprintf(
		"%v %v; cpus: %v; GOMAXPROCS: %v; version: %v",
		runtime.GOARCH,
		runtime.GOOS,
		runtime.NumCPU(),
		runtime.GOMAXPROCS(-1), // a value below 1 only queries the setting
		runtime.Version(),
	)
}
724
// LogOs returns a one-line description of the process's uid, gid and the
// host name. If the host name cannot be determined, whatever os.Hostname
// returned alongside the error is printed.
func LogOs() string {
	host, _ := os.Hostname()
	return fmt.Sprintf("uid: %v; gid: %v; hostname: %v", os.Getuid(), os.Getgid(), host)
}
731
732//
733// This method fetch the bucket UUID.  If this method return an error,
734// then it means that the node is not able to connect in order to fetch
735// bucket UUID.
736//
737func GetBucketUUID(cluster, bucket string) (string, error) {
738
739	url, err := ClusterAuthUrl(cluster)
740	if err != nil {
741		return BUCKET_UUID_NIL, err
742	}
743
744	cinfo, err := NewClusterInfoCache(url, "default")
745	if err != nil {
746		return BUCKET_UUID_NIL, err
747	}
748	cinfo.SetUserAgent("GetBucketUUID")
749
750	cinfo.Lock()
751	defer cinfo.Unlock()
752
753	if err := cinfo.Fetch(); err != nil {
754		return BUCKET_UUID_NIL, err
755	}
756
757	return cinfo.GetBucketUUID(bucket), nil
758}
759
// FileSize opens the file `name` and returns its size in bytes as reported
// by Stat on the open handle. It returns 0 and the error if the file cannot
// be opened or stat'ed.
func FileSize(name string) (int64, error) {
	file, err := os.Open(name)
	if err != nil {
		return 0, err
	}
	defer file.Close()

	var size int64
	info, err := file.Stat()
	if err == nil {
		size = info.Size()
	}
	return size, err
}
774
// HashVbuuid return crc64 value of list of 64-bit vbuuids.
//
// The checksum uses the crc64.ECMA table and is computed over the memory
// backing `vbuuids`, viewed as a byte slice without copying.
// NOTE(review): because the in-memory representation is hashed, the result
// presumably depends on the host byte order — confirm this is acceptable
// wherever hashes are compared across machines.
func HashVbuuid(vbuuids []uint64) uint64 {
	var bytes []byte
	// View both slice variables through their headers so that `bytes` can be
	// pointed at the storage of `vbuuids`.
	vbuuids_sl := (*reflect.SliceHeader)(unsafe.Pointer(&vbuuids))
	bytes_sl := (*reflect.SliceHeader)(unsafe.Pointer(&bytes))
	bytes_sl.Data = vbuuids_sl.Data
	// Each uint64 element occupies 8 bytes.
	bytes_sl.Len = vbuuids_sl.Len * 8
	bytes_sl.Cap = vbuuids_sl.Cap * 8
	return crc64.Checksum(bytes, crc64.MakeTable(crc64.ECMA))
}
785
786func IsValidIndexName(n string) error {
787	valid, _ := regexp.MatchString(IndexNamePattern, n)
788	if !valid {
789		return ErrInvalidIndexName
790	}
791
792	return nil
793}
794
// ComputeAvg folds the change between two samples into a running average.
//
// It returns 0 when `lastValue` is 0. Otherwise, with d = currValue -
// lastValue, it returns d when `lastAvg` is 0 (the first average) and
// (d + lastAvg) / 2 in all other cases, using integer division.
func ComputeAvg(lastAvg, lastValue, currValue int64) int64 {
	switch delta := currValue - lastValue; {
	case lastValue == 0:
		return 0
	case lastAvg == 0:
		return delta
	default:
		return (delta + lastAvg) / 2
	}
}
808
809// Write to the admin console
810func Console(clusterAddr string, format string, v ...interface{}) error {
811	msg := fmt.Sprintf(format, v...)
812	values := url.Values{"message": {msg}, "logLevel": {"info"}, "component": {"indexing"}}
813	reader := strings.NewReader(values.Encode())
814
815	if !strings.HasPrefix(clusterAddr, "http://") {
816		clusterAddr = "http://" + clusterAddr
817	}
818	clusterAddr += "/_log"
819
820	params := &security.RequestParams{Timeout: time.Duration(10) * time.Second}
821	res, err := security.PostWithAuth(clusterAddr, "application/x-www-form-urlencoded", reader, params)
822	res.Body.Close()
823
824	return err
825}
826
827func CopyFile(dest, source string) (err error) {
828	var sf, df *os.File
829
830	defer func() {
831		if sf != nil {
832			sf.Close()
833		}
834		if df != nil {
835			df.Close()
836		}
837	}()
838
839	if sf, err = os.Open(source); err != nil {
840		return err
841	} else if IsPathExist(dest) {
842		return nil
843	} else if df, err = os.Create(dest); err != nil {
844		return err
845	} else if _, err = io.Copy(df, sf); err != nil {
846		return err
847	}
848
849	var info os.FileInfo
850	if info, err = os.Stat(source); err != nil {
851		return err
852	} else if err = os.Chmod(dest, info.Mode()); err != nil {
853		return err
854	}
855	return
856}
857
858// CopyDir compose destination path based on source and,
859//   - if dest is file, and path is reachable, it is a no-op.
860//   - if dest is file, and path is not reachable, create and copy.
861//   - if dest is dir, and path is reachable, recurse into the dir.
862//   - if dest is dir, and path is not reachable, create and recurse into the dir.
863func CopyDir(dest, source string) error {
864	var created bool
865
866	if fi, err := os.Stat(source); err != nil {
867		return err
868	} else if !fi.IsDir() {
869		return fmt.Errorf("source not a directory")
870	} else if IsPathExist(dest) == false {
871		created = true
872		if err := os.MkdirAll(dest, fi.Mode()); err != nil {
873			return err
874		}
875	}
876
877	var err error
878	defer func() {
879		// if copy failed in the middle and directory was created by us, clean.
880		if err != nil && created {
881			os.RemoveAll(dest)
882		}
883	}()
884
885	var entries []os.FileInfo
886	if entries, err = ioutil.ReadDir(source); err != nil {
887		return err
888	} else {
889		for _, entry := range entries {
890			s := filepath.Join(source, entry.Name())
891			d := filepath.Join(dest, entry.Name())
892			if entry.IsDir() {
893				if err = CopyDir(d, s); err != nil {
894					return err
895				}
896			} else if err = CopyFile(d, s); err != nil {
897				return err
898			}
899		}
900	}
901	return nil
902}
903
// IsPathExist reports whether `path` exists. A Stat failure other than
// "does not exist" (for example a permission error) still counts as
// existing.
func IsPathExist(path string) bool {
	_, err := os.Stat(path)
	return err == nil || !os.IsNotExist(err)
}
910
// DiskUsage walks the tree rooted at `dir` and returns the summed size in
// bytes of every non-directory entry. The first error reported by the walk
// aborts it, in which case 0 and that error are returned.
func DiskUsage(dir string) (int64, error) {
	var total int64

	visit := func(_ string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() {
			return nil
		}
		total += info.Size()
		return nil
	}

	if err := filepath.Walk(dir, visit); err != nil {
		return 0, err
	}
	return total, nil
}
930
// GenNextBiggerKey interprets `b` as a big-endian unsigned integer, adds
// one, and returns the big-endian bytes of the result.
//
// For a non-primary key the last byte of `b` (the 1 byte terminator
// encoding) is dropped first, so `b` must be non-empty in that case. The
// result carries no leading zero bytes, so it may be shorter than the input.
func GenNextBiggerKey(b []byte, isPrimary bool) []byte {
	prefix := b
	if !isPrimary {
		// Remove last 1 byte terminator encoding
		prefix = b[:len(b)-1]
	}

	next := new(big.Int).SetBytes(prefix)
	return next.Add(next, big.NewInt(1)).Bytes()
}
942
943func IsAllowed(creds cbauth.Creds, permissions []string, w http.ResponseWriter) bool {
944
945	allow := false
946	err := error(nil)
947
948	for _, permission := range permissions {
949		allow, err = creds.IsAllowed(permission)
950		if allow && err == nil {
951			break
952		}
953	}
954
955	if err != nil {
956		w.WriteHeader(http.StatusInternalServerError)
957		w.Write([]byte(err.Error()))
958		return false
959	}
960
961	if !allow {
962		w.WriteHeader(http.StatusUnauthorized)
963		w.Write([]byte(http.StatusText(http.StatusUnauthorized)))
964		return false
965	}
966
967	return true
968}
969
970func IsAllAllowed(creds cbauth.Creds, permissions []string, w http.ResponseWriter) bool {
971
972	allow := true
973	err := error(nil)
974
975	for _, permission := range permissions {
976		allow, err = creds.IsAllowed(permission)
977		if !allow || err != nil {
978			break
979		}
980	}
981
982	if err != nil {
983		w.WriteHeader(http.StatusInternalServerError)
984		w.Write([]byte(err.Error()))
985		return false
986	}
987
988	if !allow {
989		w.WriteHeader(http.StatusUnauthorized)
990		w.Write([]byte(http.StatusText(http.StatusUnauthorized)))
991		return false
992	}
993
994	return true
995}
996
// ComputePercent returns `a` as an integer percentage of a+b, i.e.
// a*100/(a+b) with integer division. It returns 0 when a+b is not positive.
func ComputePercent(a, b int64) int64 {
	total := a + b
	if total <= 0 {
		return 0
	}
	return a * 100 / total
}
1004
// SetIpv6 forwards `isIpv6` to security.SetIpv6.
func SetIpv6(isIpv6 bool) {
	security.SetIpv6(isIpv6)
}
1008
// IsIpv6 returns the value reported by security.IsIpv6.
func IsIpv6() bool {
	return security.IsIpv6()
}
1012
1013func validateAuth(w http.ResponseWriter, r *http.Request) bool {
1014	_, valid, err := IsAuthValid(r)
1015	if err != nil {
1016		w.WriteHeader(http.StatusBadRequest)
1017		w.Write([]byte(err.Error() + "\n"))
1018	} else if valid == false {
1019		w.WriteHeader(401)
1020		w.Write([]byte("401 Unauthorized\n"))
1021	}
1022	return valid
1023}
1024
1025func GrHandler(rw http.ResponseWriter, r *http.Request) {
1026
1027	valid := validateAuth(rw, r)
1028	if !valid {
1029		return
1030	}
1031
1032	hndlr := pprof.Handler("goroutine")
1033	hndlr.ServeHTTP(rw, r)
1034}
1035
1036func BlockHandler(rw http.ResponseWriter, r *http.Request) {
1037
1038	valid := validateAuth(rw, r)
1039	if !valid {
1040		return
1041	}
1042
1043	hndlr := pprof.Handler("block")
1044	hndlr.ServeHTTP(rw, r)
1045}
1046
1047func HeapHandler(rw http.ResponseWriter, r *http.Request) {
1048
1049	valid := validateAuth(rw, r)
1050	if !valid {
1051		return
1052	}
1053
1054	hndlr := pprof.Handler("heap")
1055	hndlr.ServeHTTP(rw, r)
1056}
1057
1058func TCHandler(rw http.ResponseWriter, r *http.Request) {
1059
1060	valid := validateAuth(rw, r)
1061	if !valid {
1062		return
1063	}
1064
1065	hndlr := pprof.Handler("threadcreate")
1066	hndlr.ServeHTTP(rw, r)
1067}
1068
1069func PProfHandler(rw http.ResponseWriter, r *http.Request) {
1070
1071	valid := validateAuth(rw, r)
1072	if !valid {
1073		return
1074	}
1075
1076	pprof.Index(rw, r)
1077}
1078
1079func ProfileHandler(rw http.ResponseWriter, r *http.Request) {
1080
1081	valid := validateAuth(rw, r)
1082	if !valid {
1083		return
1084	}
1085
1086	pprof.Profile(rw, r)
1087}
1088
1089func CmdlineHandler(rw http.ResponseWriter, r *http.Request) {
1090
1091	valid := validateAuth(rw, r)
1092	if !valid {
1093		return
1094	}
1095
1096	pprof.Cmdline(rw, r)
1097}
1098
1099func SymbolHandler(rw http.ResponseWriter, r *http.Request) {
1100
1101	valid := validateAuth(rw, r)
1102	if !valid {
1103		return
1104	}
1105
1106	pprof.Symbol(rw, r)
1107}
1108
1109func TraceHandler(rw http.ResponseWriter, r *http.Request) {
1110
1111	valid := validateAuth(rw, r)
1112	if !valid {
1113		return
1114	}
1115
1116	pprof.Trace(rw, r)
1117}
1118
1119func ExpvarHandler(rw http.ResponseWriter, r *http.Request) {
1120
1121	valid := validateAuth(rw, r)
1122	if !valid {
1123		return
1124	}
1125
1126	rw.Header().Set("Content-Type", "application/json; charset=utf-8")
1127	fmt.Fprintf(rw, "{\n")
1128	first := true
1129	expvar.Do(func(kv expvar.KeyValue) {
1130		if !first {
1131			fmt.Fprintf(rw, ",\n")
1132		}
1133		first = false
1134		fmt.Fprintf(rw, "%q: %s", kv.Key, kv.Value)
1135	})
1136	fmt.Fprintf(rw, "\n}\n")
1137}
1138