1//  Copyright (c) 2014 Couchbase, Inc.
2//  Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
3//  except in compliance with the License. You may obtain a copy of the License at
4//    http://www.apache.org/licenses/LICENSE-2.0
5//  Unless required by applicable law or agreed to in writing, software distributed under the
6//  License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
7//  either express or implied. See the License for the specific language governing permissions
8//  and limitations under the License.
9
10package security
11
12import (
13	"crypto/tls"
14	"fmt"
15	"io/ioutil"
16	"net"
17	"sync"
18	"sync/atomic"
19	"unsafe"
20
21	"github.com/couchbase/cbauth"
22	"github.com/couchbase/indexing/secondary/logging"
23)
24
25//////////////////////////////////////////////////////
26// CBAUTH security/encryption setting
27//////////////////////////////////////////////////////
28
// SecuritySetting is a snapshot of the cluster security configuration used by
// this process. A snapshot is published with UpdateSecuritySetting and read
// with GetSecuritySetting. The refresh paths in this file copy the current
// snapshot, modify the copy, and publish the copy.
type SecuritySetting struct {
	// encryptionEnabled mirrors cbauth ClusterEncryptionConfig.EncryptData.
	encryptionEnabled bool
	// disableNonSSLPort mirrors ClusterEncryptionConfig.DisableNonSSLPorts.
	disableNonSSLPort bool
	// certificate is the key pair built from the configured cert/key files.
	certificate       *tls.Certificate
	// certInBytes holds the raw contents of the certificate file.
	certInBytes       []byte
	// tlsPreference is the TLS configuration obtained from cbauth.
	tlsPreference     *cbauth.TLSConfig
}
36
37var pSecuritySetting unsafe.Pointer = unsafe.Pointer(new(SecuritySetting))
38
39func GetSecuritySetting() *SecuritySetting {
40	return (*SecuritySetting)(atomic.LoadPointer(&pSecuritySetting))
41}
42
43func UpdateSecuritySetting(s *SecuritySetting) {
44	atomic.StorePointer(&pSecuritySetting, unsafe.Pointer(s))
45}
46
47func EncryptionEnabled() bool {
48	setting := GetSecuritySetting()
49	if setting == nil {
50		return false
51	}
52	return setting.encryptionEnabled
53}
54
55func DisableNonSSLPort() bool {
56	setting := GetSecuritySetting()
57	if setting == nil {
58		return false
59	}
60	return setting.disableNonSSLPort
61}
62
63//////////////////////////////////////////////////////
64// Security Context
65//////////////////////////////////////////////////////
66
// ConsoleLogger receives errors raised while refreshing security settings or
// notifying registered callbacks. These errors are also written to the regular
// log. A nil ConsoleLogger is allowed and is simply skipped.
type ConsoleLogger func(error)
68
// SecurityContext is the process-wide state used to decide when and how
// connections are encrypted. It also fans out security changes to
// registered callbacks.
type SecurityContext struct {
	//initialization
	initializer   sync.Once     // guards the single close of initializedCh
	initializedCh chan bool     // closed by update; InitSecurityContext waits on it
	logger        ConsoleLogger // optional sink for refresh/notify errors
	isInitialized int32         // set to 1 (atomically) by setInitialized

	// certificate
	certFile string // certificate path reloaded on cert/TLS config changes
	keyFile  string // private key path reloaded alongside certFile

	// encryption for localhost
	encryptLocalHost bool            // when true, local hosts are encrypted too
	localhosts       map[string]bool // addresses/hostnames treated as local by IsLocal

	// TLS port mapping
	encryptPortMap unsafe.Pointer // *map[string]string: non-SSL port -> SSL port (atomic)
	encryptPorts   unsafe.Pointer // *map[string]bool: set of SSL ports (atomic)

	// notifier
	mutex     sync.RWMutex                      // guards notifiers
	notifiers map[string]SecurityChangeNotifier // callbacks keyed by registrant
}
92
93var pSecurityContext *SecurityContext
94var pContextInitializer sync.Once
95
96func init() {
97	pSecurityContext = &SecurityContext{
98		initializedCh: make(chan bool),
99		notifiers:     make(map[string]SecurityChangeNotifier),
100		localhosts:    make(map[string]bool),
101	}
102
103	emptyMap1 := make(map[string]string, 0)
104	atomic.StorePointer(&pSecurityContext.encryptPortMap, unsafe.Pointer(&emptyMap1))
105
106	emptyMap2 := make(map[string]bool, 0)
107	atomic.StorePointer(&pSecurityContext.encryptPorts, unsafe.Pointer(&emptyMap2))
108}
109
110func InitSecurityContext(logger ConsoleLogger, localhost string, certFile string, keyFile string, encryptLocalHost bool) (err error) {
111
112	pContextInitializer.Do(func() {
113		var ips map[string]bool
114		ips, err = buildLocalAddr(localhost)
115		if err != nil {
116			return
117		}
118
119		pSecurityContext = &SecurityContext{
120			logger:           logger,
121			certFile:         certFile,
122			keyFile:          keyFile,
123			initializedCh:    make(chan bool),
124			notifiers:        make(map[string]SecurityChangeNotifier),
125			encryptLocalHost: encryptLocalHost,
126			localhosts:       ips,
127		}
128
129		emptyMap1 := make(map[string]string, 0)
130		atomic.StorePointer(&pSecurityContext.encryptPortMap, unsafe.Pointer(&emptyMap1))
131
132		emptyMap2 := make(map[string]bool, 0)
133		atomic.StorePointer(&pSecurityContext.encryptPorts, unsafe.Pointer(&emptyMap2))
134
135		cbauth.RegisterConfigRefreshCallback(pSecurityContext.refresh)
136
137		<-pSecurityContext.initializedCh
138
139		logging.Infof("security context:  encryptLocalHost %v Local IP's: %v", encryptLocalHost, ips)
140	})
141
142	return
143}
144
145func InitSecurityContextForClient(logger ConsoleLogger, localhost string, certFile string, keyFile string, encryptLocalHost bool) (err error) {
146
147	pContextInitializer.Do(func() {
148		var ips map[string]bool
149		ips, err = buildLocalAddr(localhost)
150		if err != nil {
151			return
152		}
153
154		pSecurityContext.logger = logger
155		pSecurityContext.certFile = certFile
156		pSecurityContext.keyFile = keyFile
157		pSecurityContext.encryptLocalHost = encryptLocalHost
158		pSecurityContext.localhosts = ips
159	})
160
161	return
162}
163
164func Refresh(tlsConfig cbauth.TLSConfig, encryptConfig cbauth.ClusterEncryptionConfig, certFile string, keyFile string) {
165
166	logging.Infof("Recieve security change notification. encryption=%v", encryptConfig.EncryptData)
167
168	newSetting := &SecuritySetting{}
169
170	oldSetting := GetSecuritySetting()
171	if oldSetting != nil {
172		temp := *oldSetting
173		newSetting = &temp
174	}
175
176	newSetting.tlsPreference = &tlsConfig
177	newSetting.encryptionEnabled = encryptConfig.EncryptData
178	newSetting.disableNonSSLPort = encryptConfig.DisableNonSSLPorts
179
180	if err := pSecurityContext.refreshCert(certFile, keyFile, newSetting); err != nil {
181		logging.Errorf("error in reading certifcate %v", err)
182		return
183	}
184
185	if err := pSecurityContext.update(newSetting, true); err != nil {
186		logging.Errorf("Fail to update security setting %v", err)
187		return
188	}
189}
190
// buildLocalAddr returns the set of addresses treated as local. The set holds
// the string form of every IP on a local network interface plus the host part
// of localhost, which must be in "host:port" form.
//
// It returns an error if localhost cannot be split or if the interface
// addresses cannot be listed.
func buildLocalAddr(localhost string) (map[string]bool, error) {
	host, _, splitErr := net.SplitHostPort(localhost)
	if splitErr != nil {
		return nil, splitErr
	}

	ifaceAddrs, listErr := net.InterfaceAddrs()
	if listErr != nil {
		return nil, listErr
	}

	local := make(map[string]bool, len(ifaceAddrs)+1)
	for _, a := range ifaceAddrs {
		var ip net.IP
		switch typed := a.(type) {
		case *net.IPNet:
			ip = typed.IP
		case *net.IPAddr:
			ip = typed.IP
		case *net.TCPAddr:
			ip = typed.IP
		case *net.UDPAddr:
			ip = typed.IP
		}
		if ip == nil {
			continue
		}
		local[ip.String()] = true
	}
	local[host] = true

	return local, nil
}
225
226func EncryptionRequired(host string, port string) bool {
227	if !pSecurityContext.initialized() {
228		return false
229	}
230
231	if !EncryptionEnabled() {
232		return false
233	}
234
235	if !encryptLocalHost() && IsLocal(host) {
236		// If it is local IP, then do not encrypt
237		return false
238	}
239
240	// encrypt only if port is a known TLS port or it has a TLS port mapping
241	return isTLSPort(port) || hasTLSPort(port)
242}
243
244func isTLSPort(port string) bool {
245	ports := GetEncryptPorts()
246	_, ok := ports[port]
247	return ok
248}
249
250func hasTLSPort(port string) bool {
251	mapping := GetEncryptPortMapping()
252	_, ok := mapping[port]
253	return ok
254}
255
256func (p *SecurityContext) initialized() bool {
257	return atomic.LoadInt32(&pSecurityContext.isInitialized) == 1
258}
259
260func (p *SecurityContext) setInitialized() {
261	atomic.StoreInt32(&pSecurityContext.isInitialized, 1)
262	logging.Infof("security context initialized")
263}
264
265//////////////////////////////////////////////////////
266// Handle Security Change
267//////////////////////////////////////////////////////
268
269func (p *SecurityContext) refresh(code uint64) error {
270
271	logging.Infof("Recieve security change notification. code %v", code)
272
273	newSetting := &SecuritySetting{}
274
275	oldSetting := GetSecuritySetting()
276	if oldSetting != nil {
277		temp := *oldSetting
278		newSetting = &temp
279	}
280
281	if code&cbauth.CFG_CHANGE_CERTS_TLSCONFIG != 0 {
282		if err := p.refreshConfig(newSetting); err != nil {
283			return err
284		}
285
286		if err := p.refreshCert(p.certFile, p.keyFile, newSetting); err != nil {
287			return err
288		}
289	}
290
291	if code&cbauth.CFG_CHANGE_CLUSTER_ENCRYPTION != 0 {
292		if err := p.refreshEncryption(newSetting); err != nil {
293			return err
294		}
295	}
296
297	return p.update(newSetting, code&cbauth.CFG_CHANGE_CERTS_TLSCONFIG != 0)
298}
299
300func (p *SecurityContext) update(newSetting *SecuritySetting, refreshCert bool) error {
301
302	hasEnabled := false
303	oldSetting := GetSecuritySetting()
304	if oldSetting != nil {
305		hasEnabled = oldSetting.encryptionEnabled
306	}
307	refreshEncrypt := hasEnabled || hasEnabled != newSetting.encryptionEnabled
308
309	UpdateSecuritySetting(newSetting)
310
311	if !refreshEncrypt && !refreshCert {
312		logging.Infof("encryption is not enabled or no certificate refresh.   Do not notify security change")
313		return nil
314	}
315
316	p.mutex.RLock()
317	defer p.mutex.RUnlock()
318
319	for key, notifier := range p.notifiers {
320		logging.Infof("Notify security setting change for %v", key)
321		if err := notifier(refreshCert, refreshEncrypt); err != nil {
322			err1 := fmt.Errorf("Fail to refresh security setting for %v: %v", key, err)
323			if p.logger != nil {
324				p.logger(err1)
325			}
326			logging.Fatalf(err1.Error())
327		}
328	}
329
330	p.initializer.Do(func() {
331		close(p.initializedCh)
332	})
333
334	return nil
335}
336
337func (p *SecurityContext) refreshConfig(setting *SecuritySetting) error {
338
339	newConfig, err := cbauth.GetTLSConfig()
340	if err != nil {
341		err1 := fmt.Errorf("Fail to refresh TLSConfig due to error: %v", err)
342		if p.logger != nil {
343			p.logger(err1)
344		}
345		logging.Fatalf(err1.Error())
346		return err
347	}
348
349	setting.tlsPreference = &newConfig
350
351	logging.Infof("TLS config refreshed successfully")
352
353	return nil
354}
355
356func (p *SecurityContext) refreshCert(certFile string, keyFile string, setting *SecuritySetting) error {
357
358	if len(certFile) == 0 || len(keyFile) == 0 {
359		logging.Warnf("certifcate location is missing.  Cannot refresh certifcate")
360		return nil
361	}
362
363	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
364	if err != nil {
365		err1 := fmt.Errorf("Fail to due generate SSL certificate: %v", err)
366		if p.logger != nil {
367			p.logger(err1)
368		}
369		logging.Fatalf(err1.Error())
370		return err
371	}
372
373	certInBytes, err := ioutil.ReadFile(certFile)
374	if err != nil {
375		err1 := fmt.Errorf("Fail to due load SSL certificate from file: %v", err)
376		if p.logger != nil {
377			p.logger(err1)
378		}
379		logging.Fatalf(err1.Error())
380		return err
381	}
382
383	setting.certificate = &cert
384	setting.certInBytes = certInBytes
385
386	logging.Infof("Certificate refreshed successfully")
387
388	return nil
389}
390
391func (p *SecurityContext) refreshEncryption(setting *SecuritySetting) error {
392
393	cfg, err := cbauth.GetClusterEncryptionConfig()
394	if err != nil {
395		err1 := fmt.Errorf("Fail to due load encryption config: %v", err)
396		if p.logger != nil {
397			p.logger(err1)
398		}
399		logging.Fatalf(err1.Error())
400		return err
401	}
402
403	setting.encryptionEnabled = cfg.EncryptData
404	setting.disableNonSSLPort = cfg.DisableNonSSLPorts
405
406	logging.Infof("Encryption config refresh successfully.   Encryption enabled=%v", setting.encryptionEnabled)
407
408	return nil
409}
410
411//////////////////////////////////////////////////////
412// Security Change Notifier
413//////////////////////////////////////////////////////
414
// SecurityChangeNotifier is called by update after a new security setting is
// published.
//
// refreshCert is true when the certificate/TLS config was refreshed.
// refreshEncrypt is true when encryption was enabled before the change or is
// enabled by it. A non-nil error is reported to the context's ConsoleLogger
// and logged at fatal level; other notifiers are still called.
type SecurityChangeNotifier func(refreshCert bool, refreshEncrypt bool) error
416
417func RegisterCallback(key string, cb SecurityChangeNotifier) {
418	pSecurityContext.mutex.Lock()
419	defer pSecurityContext.mutex.Unlock()
420
421	pSecurityContext.notifiers[key] = cb
422}
423
424//////////////////////////////////////////////////////
425// Encrypt Port Mapping
426// - provide mapping from non-SSL port to SSL port
427//////////////////////////////////////////////////////
428
429func SetEncryptPortMapping(mapping map[string]string) {
430	ports := make(map[string]bool)
431	for _, encrypted := range mapping {
432		ports[encrypted] = true
433	}
434
435	atomic.StorePointer(&pSecurityContext.encryptPortMap, unsafe.Pointer(&mapping))
436	atomic.StorePointer(&pSecurityContext.encryptPorts, unsafe.Pointer(&ports))
437
438	pSecurityContext.setInitialized()
439
440	logging.Infof("security port mapping updated : %v", mapping)
441}
442
443func GetEncryptPortMapping() map[string]string {
444	return *(*map[string]string)(atomic.LoadPointer(&pSecurityContext.encryptPortMap))
445}
446
447func GetEncryptPorts() map[string]bool {
448	return *(*map[string]bool)(atomic.LoadPointer(&pSecurityContext.encryptPorts))
449}
450
451func EncryptPortFromAddr(addr string) (string, string, string, error) {
452
453	host, port, err := net.SplitHostPort(addr)
454	if err != nil {
455		return addr, "", "", err
456	}
457
458	port = EncryptPort(host, port)
459	return net.JoinHostPort(host, port), host, port, nil
460}
461
462func EncryptPort(host string, port string) string {
463
464	if EncryptionRequired(host, port) {
465		mapping := GetEncryptPortMapping()
466		for port1, port2 := range mapping {
467			if port == port1 {
468				return port2
469			}
470		}
471	}
472
473	return port
474}
475
476//////////////////////////////////////////////////////
477// Skip Encryption on Localhost
478//////////////////////////////////////////////////////
479
480func encryptLocalHost() bool {
481
482	if pSecurityContext.encryptLocalHost {
483		return true
484	}
485
486	return DisableNonSSLPort()
487}
488
489func IsLocal(host string) bool {
490
491	// empty host is treated as unknown host, rather than localhost
492	if len(host) == 0 {
493		return false
494	}
495
496	localhosts := pSecurityContext.localhosts
497	if match, ok := localhosts[host]; ok {
498		return match
499	}
500
501	ips, err := net.LookupIP(host)
502	if err == nil {
503		for _, ip := range ips {
504			if match, ok := localhosts[ip.String()]; ok {
505				return match
506			}
507		}
508	}
509
510	return false
511}
512