1// Copyright 2017 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build !plan9
6
7package main
8
9import (
10	"bufio"
11	"context"
12	"fmt"
13	"os"
14	"os/user"
15	"path/filepath"
16	"runtime"
17	"strings"
18)
19
// Names of the shell startup files, relative to the user's home directory,
// that shellConfigFile can select for persistEnvVar to append export lines to.
const (
	// zshConfig is used when the current shell looks like zsh.
	zshConfig = ".zshrc"
	// bashConfig is used when the current shell looks like bash.
	bashConfig = ".bash_profile"
)
24
25// appendToPATH adds the given path to the PATH environment variable and
26// persists it for future sessions.
27func appendToPATH(value string) error {
28	if isInPATH(value) {
29		return nil
30	}
31	return persistEnvVar("PATH", pathVar+envSeparator+value)
32}
33
34func isInPATH(dir string) bool {
35	p := os.Getenv("PATH")
36
37	paths := strings.Split(p, envSeparator)
38	for _, d := range paths {
39		if d == dir {
40			return true
41		}
42	}
43
44	return false
45}
46
47func getHomeDir() (string, error) {
48	home := os.Getenv(homeKey)
49	if home != "" {
50		return home, nil
51	}
52
53	u, err := user.Current()
54	if err != nil {
55		return "", err
56	}
57	return u.HomeDir, nil
58}
59
// checkStringExistsFile reports whether filename contains a line exactly
// equal to value.
//
// Lines are split the way bufio.ScanLines splits them, so a trailing "\r"
// before the newline is not part of the compared line, and a final line
// without a newline is still considered. A file that does not exist yields
// (false, nil). Any other open or read error is returned.
func checkStringExistsFile(filename, value string) (bool, error) {
	file, err := os.Open(filename)
	if err != nil {
		if os.IsNotExist(err) {
			return false, nil
		}
		return false, err
	}
	defer file.Close()

	// By default bufio.Scanner stops with bufio.ErrTooLong on any line longer
	// than bufio.MaxScanTokenSize (64 KiB). A single long, unrelated line in a
	// shell rc file must not make the whole lookup fail, so raise the limit to
	// cover the entire file. The cap keeps the int conversion safe on 32-bit
	// platforms.
	limit := bufio.MaxScanTokenSize
	if info, statErr := file.Stat(); statErr == nil {
		if size := info.Size(); size >= int64(limit) && size < 1<<30 {
			limit = int(size) + 1
		}
	}

	scanner := bufio.NewScanner(file)
	scanner.Buffer(nil, limit)
	for scanner.Scan() {
		if scanner.Text() == value {
			return true, nil
		}
	}

	return false, scanner.Err()
}
80
81func appendToFile(filename, value string) error {
82	verbosef("Adding %q to %s", value, filename)
83
84	ok, err := checkStringExistsFile(filename, value)
85	if err != nil {
86		return err
87	}
88	if ok {
89		// Nothing to do.
90		return nil
91	}
92
93	f, err := os.OpenFile(filename, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
94	if err != nil {
95		return err
96	}
97	defer f.Close()
98
99	_, err = f.WriteString(lineEnding + value + lineEnding)
100	return err
101}
102
103func isShell(name string) bool {
104	return strings.Contains(currentShell(), name)
105}
106
107// persistEnvVarWindows sets an environment variable in the Windows
108// registry.
109func persistEnvVarWindows(name, value string) error {
110	_, err := runCommand(context.Background(), "powershell", "-command",
111		fmt.Sprintf(`[Environment]::SetEnvironmentVariable("%s", "%s", "User")`, name, value))
112	return err
113}
114
115func persistEnvVar(name, value string) error {
116	if runtime.GOOS == "windows" {
117		if err := persistEnvVarWindows(name, value); err != nil {
118			return err
119		}
120
121		if isShell("cmd.exe") || isShell("powershell.exe") {
122			return os.Setenv(strings.ToUpper(name), value)
123		}
124		// User is in bash, zsh, etc.
125		// Also set the environment variable in their shell config.
126	}
127
128	rc, err := shellConfigFile()
129	if err != nil {
130		return err
131	}
132
133	line := fmt.Sprintf("export %s=%s", strings.ToUpper(name), value)
134	if err := appendToFile(rc, line); err != nil {
135		return err
136	}
137
138	return os.Setenv(strings.ToUpper(name), value)
139}
140
141func shellConfigFile() (string, error) {
142	home, err := getHomeDir()
143	if err != nil {
144		return "", err
145	}
146
147	switch {
148	case isShell("bash"):
149		return filepath.Join(home, bashConfig), nil
150	case isShell("zsh"):
151		return filepath.Join(home, zshConfig), nil
152	default:
153		return "", fmt.Errorf("%q is not a supported shell", currentShell())
154	}
155}
156