1package couchbase
2
3import (
4	"errors"
5	"fmt"
6	"time"
7
8	"encoding/binary"
9
10	"github.com/couchbase/indexing/secondary/dcp/transport"
11	"github.com/couchbase/indexing/secondary/dcp/transport/client"
12)
13
var (
	// errClosedPool is reported by GetWithTimeout once the pool's idle
	// connection channel has been closed.
	errClosedPool = errors.New("the pool is closed")

	// errNoPool is reported when an operation is invoked on a nil pool.
	errNoPool = errors.New("no pool")
)
16
17// GenericMcdAuthHandler is a kind of AuthHandler that performs
18// special auth exchange (like non-standard auth, possibly followed by
19// select-bucket).
20type GenericMcdAuthHandler interface {
21	AuthHandler
22	AuthenticateMemcachedConn(string, *memcached.Client) error
23}
24
var (
	// ConnPoolTimeout is the default timeout (30 days) for retrieving a
	// connection from the pool; Get uses it.
	ConnPoolTimeout = 30 * 24 * time.Hour

	// ConnPoolAvailWaitTime is how long to wait for an existing connection
	// to be returned to the pool before considering creating a new one.
	ConnPoolAvailWaitTime = time.Millisecond
)
32
// connectionPool is a bounded pool of memcached connections to a single
// host.
type connectionPool struct {
	// host is the address passed to mkConn when creating connections; it
	// is also reported to ConnPoolCallback.
	host        string
	// mkConn creates and authenticates a new connection;
	// newConnectionPool sets it to defaultMkConn.
	mkConn      func(host string, ah AuthHandler) (*memcached.Client, error)
	// auth is handed to mkConn for every connection it creates.
	auth        AuthHandler
	// connections holds idle connections ready for reuse (capacity
	// poolSize). Close closes this channel.
	connections chan *memcached.Client
	// createsem bounds how many connections count against the pool at
	// once (capacity poolSize+poolOverflow): a token is sent when a
	// connection is created and received when one is discarded or handed
	// off permanently (e.g. to a DCP/TAP feed).
	createsem   chan bool
}
40
41func newConnectionPool(host string, ah AuthHandler, poolSize, poolOverflow int) *connectionPool {
42	return &connectionPool{
43		host:        host,
44		connections: make(chan *memcached.Client, poolSize),
45		createsem:   make(chan bool, poolSize+poolOverflow),
46		mkConn:      defaultMkConn,
47		auth:        ah,
48	}
49}
50
// ConnPoolCallback, when non-nil, is notified every time GetWithTimeout
// (and therefore Get) is called on a non-nil pool. It receives the pool's
// host, the acquisition stage that produced the result ("short-circuit",
// "avail1", "avail2" or "create"; a timeout is also reported as
// "short-circuit"), the time the request started, and the returned error
// (nil on success).
var ConnPoolCallback func(host string, source string, start time.Time, err error)
53
54func defaultMkConn(
55	host string, ah AuthHandler) (conn *memcached.Client, err error) {
56
57	conn, err = memcached.Connect("tcp", host)
58	if err != nil {
59		return nil, err
60	}
61
62	defer func() {
63		if err != nil {
64			conn.Close()
65			conn = nil
66			return
67		}
68	}()
69
70	if gah, ok := ah.(GenericMcdAuthHandler); ok {
71		err = gah.AuthenticateMemcachedConn(host, conn)
72		return
73	}
74	name, pass := ah.GetCredentials()
75	if name != "default" {
76		_, err = conn.Auth(name, pass)
77	}
78	return
79}
80
81func (cp *connectionPool) Close() (err error) {
82	defer func() { err, _ = recover().(error) }()
83	close(cp.connections)
84	for c := range cp.connections {
85		c.Close()
86	}
87	return
88}
89
90func (cp *connectionPool) GetWithTimeout(d time.Duration) (rv *memcached.Client, err error) {
91	if cp == nil {
92		return nil, errNoPool
93	}
94
95	path := ""
96
97	if ConnPoolCallback != nil {
98		defer func(path *string, start time.Time) {
99			ConnPoolCallback(cp.host, *path, start, err)
100		}(&path, time.Now())
101	}
102
103	path = "short-circuit"
104
105	// short-circuit available connetions.
106	select {
107	case rv, isopen := <-cp.connections:
108		if !isopen {
109			return nil, errClosedPool
110		}
111		return rv, nil
112	default:
113	}
114
115	t := time.NewTimer(ConnPoolAvailWaitTime)
116	defer t.Stop()
117
118	// Try to grab an available connection within 1ms
119	select {
120	case rv, isopen := <-cp.connections:
121		path = "avail1"
122		if !isopen {
123			return nil, errClosedPool
124		}
125		return rv, nil
126	case <-t.C:
127		// No connection came around in time, let's see
128		// whether we can get one or build a new one first.
129		t.Reset(d) // Reuse the timer for the full timeout.
130		select {
131		case rv, isopen := <-cp.connections:
132			path = "avail2"
133			if !isopen {
134				return nil, errClosedPool
135			}
136			return rv, nil
137		case cp.createsem <- true:
138			path = "create"
139			// Build a connection if we can't get a real one.
140			// This can potentially be an overflow connection, or
141			// a pooled connection.
142			rv, err := cp.mkConn(cp.host, cp.auth)
143			if err != nil {
144				// On error, release our create hold
145				<-cp.createsem
146			}
147			return rv, err
148		case <-t.C:
149			return nil, ErrTimeout
150		}
151	}
152}
153
154func (cp *connectionPool) Get() (*memcached.Client, error) {
155	return cp.GetWithTimeout(ConnPoolTimeout)
156}
157
158func (cp *connectionPool) Return(c *memcached.Client) {
159	if c == nil {
160		return
161	}
162
163	if cp == nil {
164		c.Close()
165	}
166
167	if c.IsHealthy() {
168		defer func() {
169			if recover() != nil {
170				// This happens when the pool has already been
171				// closed and we're trying to return a
172				// connection to it anyway.  Just close the
173				// connection.
174				c.Close()
175			}
176		}()
177
178		select {
179		case cp.connections <- c:
180		default:
181			// Overflow connection.
182			<-cp.createsem
183			c.Close()
184		}
185	} else {
186		<-cp.createsem
187		c.Close()
188	}
189}
190
191func (cp *connectionPool) StartTapFeed(args *memcached.TapArguments) (*memcached.TapFeed, error) {
192	if cp == nil {
193		return nil, errNoPool
194	}
195	mc, err := cp.Get()
196	if err != nil {
197		return nil, err
198	}
199
200	// A connection can't be used after TAP; Dont' count it against the
201	// connection pool capacity
202	<-cp.createsem
203
204	return mc.StartTapFeed(*args)
205}
206
// DEFAULT_WINDOW_SIZE is the window size StartDcpFeed passes to DcpOpen:
// 20 MiB, assuming the window is measured in bytes.
const DEFAULT_WINDOW_SIZE uint32 = 20 << 20
208
209func (cp *connectionPool) StartDcpFeed(
210	name DcpFeedName, sequence, flags uint32,
211	outch chan *memcached.DcpEvent,
212	opaque uint16,
213	supvch chan []interface{},
214	config map[string]interface{}) (*memcached.DcpFeed, error) {
215
216	if cp == nil {
217		return nil, errNoPool
218	}
219
220	mc, err := cp.Get() // Don't call Return() on this
221	if err != nil {
222		return nil, err
223	}
224	// A connection can't be used after it has been allocated to DCP;
225	// Dont' count it against the connection pool capacity
226	<-cp.createsem
227
228	dcpf, err := memcached.NewDcpFeed(mc, string(name), outch, opaque, supvch, config)
229	if err == nil {
230		err = dcpf.DcpOpen(
231			string(name), sequence, flags, DEFAULT_WINDOW_SIZE, opaque,
232		)
233		if err == nil {
234			return dcpf, err
235		}
236	}
237	mc.Close()
238	return nil, err
239}
240
241func (cp *connectionPool) GetDcpConn(name DcpFeedName) (*memcached.Client, error) {
242	mc, err := cp.Get() // Don't call Return() on this
243	if err != nil {
244		return nil, err
245	}
246
247	rq := &transport.MCRequest{
248		Opcode: transport.DCP_OPEN,
249		Key:    []byte(string(name)),
250		Opaque: 0,
251	}
252	rq.Extras = make([]byte, 8)
253	binary.BigEndian.PutUint32(rq.Extras[:4], 0)
254	binary.BigEndian.PutUint32(rq.Extras[4:], 1) // we are consumer
255
256	mc.SetMcdConnectionDeadline()
257	defer mc.ResetMcdConnectionDeadline()
258
259	if err := mc.Transmit(rq); err != nil {
260		return nil, err
261	}
262
263	_, err = mc.Receive()
264	if err != nil {
265		return nil, err
266	}
267
268	return mc, nil
269}
270
271func GetSeqs(mc *memcached.Client, seqnos []uint64, buf []byte) error {
272	res := &transport.MCResponse{}
273	rq := &transport.MCRequest{
274		Opcode: transport.DCP_GET_SEQNO,
275		Opaque: 0,
276	}
277
278	rq.Extras = make([]byte, 4)
279	binary.BigEndian.PutUint32(rq.Extras, 1) // Only active vbuckets
280
281	mc.SetMcdConnectionDeadline()
282	defer mc.ResetMcdConnectionDeadline()
283
284	if err := mc.Transmit(rq); err != nil {
285		return err
286	}
287
288	if err := mc.ReceiveInBuf(res, buf); err != nil {
289		return err
290	}
291
292	if res.Status != transport.SUCCESS {
293		return fmt.Errorf("failed %d", res.Status)
294	}
295
296	if len(res.Body)%10 != 0 {
297		fmsg := "invalid body length %v, in get-seqnos\n"
298		err := fmt.Errorf(fmsg, len(res.Body))
299		return err
300	}
301	for i := 0; i < 1024; i++ {
302		seqnos[i] = 0
303	}
304
305	for i := 0; i < len(res.Body); i += 10 {
306		vbno := int(binary.BigEndian.Uint16(res.Body[i : i+2]))
307		seqno := binary.BigEndian.Uint64(res.Body[i+2 : i+10])
308		seqnos[vbno] = seqno
309	}
310
311	return nil
312}
313
314//Get seqnos of vbuckets in active/replica/pending state
315func GetSeqsAllVbStates(mc *memcached.Client, seqnos []uint64, buf []byte) error {
316	res := &transport.MCResponse{}
317	rq := &transport.MCRequest{
318		Opcode: transport.DCP_GET_SEQNO,
319		Opaque: 0,
320	}
321
322	mc.SetMcdConnectionDeadline()
323	defer mc.ResetMcdConnectionDeadline()
324
325	if err := mc.Transmit(rq); err != nil {
326		return err
327	}
328
329	if err := mc.ReceiveInBuf(res, buf); err != nil {
330		return err
331	}
332
333	if res.Status != transport.SUCCESS {
334		return fmt.Errorf("failed %d", res.Status)
335	}
336
337	if len(res.Body)%10 != 0 {
338		fmsg := "invalid body length %v, in get-seqnos\n"
339		err := fmt.Errorf(fmsg, len(res.Body))
340		return err
341	}
342	for i := 0; i < 1024; i++ {
343		seqnos[i] = 0
344	}
345
346	for i := 0; i < len(res.Body); i += 10 {
347		vbno := int(binary.BigEndian.Uint16(res.Body[i : i+2]))
348		seqno := binary.BigEndian.Uint64(res.Body[i+2 : i+10])
349		seqnos[vbno] = seqno
350	}
351
352	return nil
353}
354