1package couchbase
2
3import (
4	"crypto/tls"
5	"errors"
6	"sync/atomic"
7	"time"
8
9	"github.com/couchbase/gomemcached"
10	"github.com/couchbase/gomemcached/client"
11	"github.com/couchbase/goutils/logging"
12)
13
// GenericMcdAuthHandler is a kind of AuthHandler that performs
// special auth exchange (like non-standard auth, possibly followed by
// select-bucket).
//
// When a pool's AuthHandler also implements this interface, defaultMkConn
// delegates authentication of every new connection to it instead of doing
// its own GetCredentials-based auth and select-bucket.
type GenericMcdAuthHandler interface {
	AuthHandler
	// AuthenticateMemcachedConn authenticates conn, a freshly dialed
	// connection to host. A non-nil error makes defaultMkConn close conn
	// and fail the connection attempt with that error.
	AuthenticateMemcachedConn(host string, conn *memcached.Client) error
}
21
var (
	// TimeoutError describes a timeout while waiting to build a connection.
	// NOTE(review): GetWithTimeout in this file reports timeouts with
	// ErrTimeout, not this value — confirm whether anything still returns it.
	TimeoutError = errors.New("timeout waiting to build connection")

	// errClosedPool is returned when a connection is requested from a pool
	// whose connection channel has already been closed.
	errClosedPool = errors.New("the connection pool is closed")

	// errNoPool is returned by pool methods invoked on a nil pool.
	errNoPool = errors.New("no connection pool")

	// ConnPoolTimeout is the timeout Get uses when retrieving a connection
	// from a pool (30 days).
	ConnPoolTimeout = 30 * 24 * time.Hour

	// ConnCloserInterval is the cycle time of the overflow connection
	// closer goroutine.
	ConnCloserInterval = 30 * time.Second

	// ConnPoolAvailWaitTime is how long to wait for an existing connection
	// to become available in the pool before considering the creation of a
	// new one.
	ConnPoolAvailWaitTime = time.Millisecond
)
37
38type connectionPool struct {
39	host        string
40	mkConn      func(host string, ah AuthHandler, tlsConfig *tls.Config, bucketName string) (*memcached.Client, error)
41	auth        AuthHandler
42	connections chan *memcached.Client
43	createsem   chan bool
44	bailOut     chan bool
45	poolSize    int
46	connCount   uint64
47	inUse       bool
48	encrypted   bool
49	tlsConfig   *tls.Config
50	bucket      string
51}
52
53func newConnectionPool(host string, ah AuthHandler, closer bool, poolSize, poolOverflow int, tlsConfig *tls.Config, bucket string, encrypted bool) *connectionPool {
54	connSize := poolSize
55	if closer {
56		connSize += poolOverflow
57	}
58	rv := &connectionPool{
59		host:        host,
60		connections: make(chan *memcached.Client, connSize),
61		createsem:   make(chan bool, poolSize+poolOverflow),
62		mkConn:      defaultMkConn,
63		auth:        ah,
64		poolSize:    poolSize,
65		bucket:      bucket,
66		encrypted:   encrypted,
67	}
68
69	if encrypted {
70		rv.tlsConfig = tlsConfig
71	}
72
73	if closer {
74		rv.bailOut = make(chan bool, 1)
75		go rv.connCloser()
76	}
77	return rv
78}
79
// ConnPoolCallback is notified whenever connections are acquired from a pool.
// When it is non-nil at the start of a GetWithTimeout call on a non-nil pool,
// it is invoked as that call returns. It receives the pool's host, a label
// for the path taken to satisfy the request, the time the request started,
// and the error returned (nil on success).
var ConnPoolCallback func(host string, source string, start time.Time, err error)
82
83// Use regular in-the-clear connection if tlsConfig is nil.
84// Use secure connection (TLS) if tlsConfig is set.
85func defaultMkConn(host string, ah AuthHandler, tlsConfig *tls.Config, bucketName string) (*memcached.Client, error) {
86	var features memcached.Features
87
88	var conn *memcached.Client
89	var err error
90	if tlsConfig == nil {
91		conn, err = memcached.Connect("tcp", host)
92	} else {
93		conn, err = memcached.ConnectTLS("tcp", host, tlsConfig)
94	}
95
96	if err != nil {
97		return nil, err
98	}
99
100	if DefaultTimeout > 0 {
101		conn.SetDeadline(getDeadline(noDeadline, DefaultTimeout))
102	}
103
104	if TCPKeepalive == true {
105		conn.SetKeepAliveOptions(time.Duration(TCPKeepaliveInterval) * time.Second)
106	}
107
108	if EnableMutationToken == true {
109		features = append(features, memcached.FeatureMutationToken)
110	}
111	if EnableDataType == true {
112		features = append(features, memcached.FeatureDataType)
113	}
114
115	if EnableXattr == true {
116		features = append(features, memcached.FeatureXattr)
117	}
118
119	if EnableCollections {
120		features = append(features, memcached.FeatureCollections)
121	}
122
123	if len(features) > 0 {
124		res, err := conn.EnableFeatures(features)
125		if err != nil && isTimeoutError(err) {
126			conn.Close()
127			return nil, err
128		}
129
130		if err != nil || res.Status != gomemcached.SUCCESS {
131			logging.Warnf("Unable to enable features %v", err)
132		}
133	}
134
135	if gah, ok := ah.(GenericMcdAuthHandler); ok {
136		err = gah.AuthenticateMemcachedConn(host, conn)
137		if err != nil {
138			conn.Close()
139			return nil, err
140		}
141
142		if DefaultTimeout > 0 {
143			conn.SetDeadline(noDeadline)
144		}
145
146		return conn, nil
147	}
148	name, pass, bucket := ah.GetCredentials()
149	if bucket == "" {
150		// Authenticator does not know specific bucket.
151		bucket = bucketName
152	}
153
154	if name != "default" {
155		_, err = conn.Auth(name, pass)
156		if err != nil {
157			conn.Close()
158			return nil, err
159		}
160		// Select bucket (Required for cb_auth creds)
161		// Required when doing auth with _admin credentials
162		if bucket != "" && bucket != name {
163			_, err = conn.SelectBucket(bucket)
164			if err != nil {
165				conn.Close()
166				return nil, err
167			}
168		}
169	}
170
171	if DefaultTimeout > 0 {
172		conn.SetDeadline(noDeadline)
173	}
174
175	return conn, nil
176}
177
178func (cp *connectionPool) Close() (err error) {
179	defer func() {
180		if recover() != nil {
181			err = errors.New("connectionPool.Close error")
182		}
183	}()
184	if cp.bailOut != nil {
185
186		// defensively, we won't wait if the channel is full
187		select {
188		case cp.bailOut <- false:
189		default:
190		}
191	}
192	close(cp.connections)
193	for c := range cp.connections {
194		c.Close()
195	}
196	return
197}
198
199func (cp *connectionPool) Node() string {
200	return cp.host
201}
202
203func (cp *connectionPool) GetWithTimeout(d time.Duration) (rv *memcached.Client, err error) {
204	if cp == nil {
205		return nil, errNoPool
206	}
207
208	path := ""
209
210	if ConnPoolCallback != nil {
211		defer func(path *string, start time.Time) {
212			ConnPoolCallback(cp.host, *path, start, err)
213		}(&path, time.Now())
214	}
215
216	path = "short-circuit"
217
218	// short-circuit available connetions.
219	select {
220	case rv, isopen := <-cp.connections:
221		if !isopen {
222			return nil, errClosedPool
223		}
224		atomic.AddUint64(&cp.connCount, 1)
225		return rv, nil
226	default:
227	}
228
229	t := time.NewTimer(ConnPoolAvailWaitTime)
230	defer t.Stop()
231
232	// Try to grab an available connection within 1ms
233	select {
234	case rv, isopen := <-cp.connections:
235		path = "avail1"
236		if !isopen {
237			return nil, errClosedPool
238		}
239		atomic.AddUint64(&cp.connCount, 1)
240		return rv, nil
241	case <-t.C:
242		// No connection came around in time, let's see
243		// whether we can get one or build a new one first.
244		t.Reset(d) // Reuse the timer for the full timeout.
245		select {
246		case rv, isopen := <-cp.connections:
247			path = "avail2"
248			if !isopen {
249				return nil, errClosedPool
250			}
251			atomic.AddUint64(&cp.connCount, 1)
252			return rv, nil
253		case cp.createsem <- true:
254			path = "create"
255			// Build a connection if we can't get a real one.
256			// This can potentially be an overflow connection, or
257			// a pooled connection.
258			rv, err := cp.mkConn(cp.host, cp.auth, cp.tlsConfig, cp.bucket)
259			if err != nil {
260				// On error, release our create hold
261				<-cp.createsem
262			} else {
263				atomic.AddUint64(&cp.connCount, 1)
264			}
265			return rv, err
266		case <-t.C:
267			return nil, ErrTimeout
268		}
269	}
270}
271
272func (cp *connectionPool) Get() (*memcached.Client, error) {
273	return cp.GetWithTimeout(ConnPoolTimeout)
274}
275
276func (cp *connectionPool) Return(c *memcached.Client) {
277	if c == nil {
278		return
279	}
280
281	if cp == nil {
282		c.Close()
283	}
284
285	if c.IsHealthy() {
286		defer func() {
287			if recover() != nil {
288				// This happens when the pool has already been
289				// closed and we're trying to return a
290				// connection to it anyway.  Just close the
291				// connection.
292				c.Close()
293			}
294		}()
295
296		select {
297		case cp.connections <- c:
298		default:
299			<-cp.createsem
300			c.Close()
301		}
302	} else {
303		<-cp.createsem
304		c.Close()
305	}
306}
307
308// give the ability to discard a connection from a pool
309// useful for ditching connections to the wrong node after a rebalance
310func (cp *connectionPool) Discard(c *memcached.Client) {
311	<-cp.createsem
312	c.Close()
313}
314
315// asynchronous connection closer
316func (cp *connectionPool) connCloser() {
317	var connCount uint64
318
319	t := time.NewTimer(ConnCloserInterval)
320	defer t.Stop()
321
322	for {
323		connCount = cp.connCount
324
325		// we don't exist anymore! bail out!
326		select {
327		case <-cp.bailOut:
328			return
329		case <-t.C:
330		}
331		t.Reset(ConnCloserInterval)
332
333		// no overflow connections open or sustained requests for connections
334		// nothing to do until the next cycle
335		if len(cp.connections) <= cp.poolSize ||
336			ConnCloserInterval/ConnPoolAvailWaitTime < time.Duration(cp.connCount-connCount) {
337			continue
338		}
339
340		// close overflow connections now that they are not needed
341		for c := range cp.connections {
342			select {
343			case <-cp.bailOut:
344				return
345			default:
346			}
347
348			// bail out if close did not work out
349			if !cp.connCleanup(c) {
350				return
351			}
352			if len(cp.connections) <= cp.poolSize {
353				break
354			}
355		}
356	}
357}
358
359// close connection with recovery on error
360func (cp *connectionPool) connCleanup(c *memcached.Client) (rv bool) {
361
362	// just in case we are closing a connection after
363	// bailOut has been sent but we haven't yet read it
364	defer func() {
365		if recover() != nil {
366			rv = false
367		}
368	}()
369	rv = true
370
371	c.Close()
372	<-cp.createsem
373	return
374}
375
376func (cp *connectionPool) StartTapFeed(args *memcached.TapArguments) (*memcached.TapFeed, error) {
377	if cp == nil {
378		return nil, errNoPool
379	}
380	mc, err := cp.Get()
381	if err != nil {
382		return nil, err
383	}
384
385	// A connection can't be used after TAP; Dont' count it against the
386	// connection pool capacity
387	<-cp.createsem
388
389	return mc.StartTapFeed(*args)
390}
391
// DEFAULT_WINDOW_SIZE is 20 MiB (20 << 20 = 20971520). NOTE(review): the unit
// is presumably bytes; the constant is not referenced in this file.
const DEFAULT_WINDOW_SIZE = 20 << 20
393
394func (cp *connectionPool) StartUprFeed(name string, sequence uint32, dcp_buffer_size uint32, data_chan_size int) (*memcached.UprFeed, error) {
395	if cp == nil {
396		return nil, errNoPool
397	}
398	mc, err := cp.Get()
399	if err != nil {
400		return nil, err
401	}
402
403	// A connection can't be used after it has been allocated to UPR;
404	// Dont' count it against the connection pool capacity
405	<-cp.createsem
406
407	uf, err := mc.NewUprFeed()
408	if err != nil {
409		return nil, err
410	}
411
412	if err := uf.UprOpen(name, sequence, dcp_buffer_size); err != nil {
413		return nil, err
414	}
415
416	if err := uf.StartFeedWithConfig(data_chan_size); err != nil {
417		return nil, err
418	}
419
420	return uf, nil
421}
422