1#!/usr/bin/env python
2from threading import Thread, RLock
3from time import sleep
4import socket
5
6from lib.mc_bin_client import MemcachedClient
7from lib.perf_engines.sys_helper import SocketHelper, synchronized
8
9from obs import Observer
10from obs_req import ObserveRequestKey, ObserveRequest
11from obs_res import ObserveResponse
12from obs_def import ObservePktFmt, ObserveStatus, ObserveKeyState
13from obs_helper import VbucketHelper
14
# Initial and maximum sleep, in seconds, for the exponential backoff used
# between observe polling rounds (see self.backoff / self.max_backoff;
# overridable via the 'obs-backoff' / 'obs-max-backoff' config keys).
BACKOFF = 0.2
MAX_BACKOFF = 1
17
18class McsodaObserver(Observer, Thread):
19
20    ctl = None
21    cfg = None
22    store = None
23    awareness = None
24    conns = {}
25    obs_keys = {}   # {server: [keys]}
26    callback = None
27
28    #TODO: handle persist_count != 1
29    #TODO: socket timeout, fine-grained exceptions
30    #TODO: network helper
31    #TODO: wait call timeout
32
33    def __init__(self, ctl, cfg, store, callback):
34        self.ctl = ctl
35        self.cfg = cfg
36        self.store = store
37        self.callback = callback
38        self.conn_lock = RLock()
39        self._build_conns()
40        self.backoff = self.cfg.get('obs-backoff', BACKOFF)
41        self.max_backoff = self.cfg.get('obs-max-backoff', MAX_BACKOFF)
42        super(McsodaObserver, self).__init__()
43
44    def run(self):
45        while self.ctl['run_ok']:
46            self.observe()
47            try:
48                self.observable_filter(ObserveStatus.OBS_UNKNOWN).next()
49                print "<%s> sleep for %f seconds" % (self.__class__.__name__, self.backoff)
50                sleep(self.backoff)
51                self.backoff = min(self.backoff * 2, self.max_backoff)
52            except StopIteration:
53                self.measure_client_latency()
54                self.clear_observables()
55                if self.callback:
56                    self.callback(self.store)
57                self.backoff = self.cfg.get('obs-backoff', BACKOFF)
58        print "<%s> stopped running" % (self.__class__.__name__)
59
60    @synchronized("conn_lock")
61    def _build_conns(self):
62        """build separate connections based on store"""
63        if not self.store:
64            print "<%s> failed to build connections, invalid store object"\
65                % self.__class__.__name__
66            return False
67
68        if self.store.__class__.__name__ == "StoreMemcachedBinary":
69            conn = MemcachedClient(self.store.conn.host, self.store.conn.port)
70            server_str = "{0}:{1}".format(self.store.conn.host, self.store.conn.port)
71            self.conns[server_str] = conn
72        elif self.store.__class__.__name__ == "StoreMembaseBinary":
73            for memcached in self.store.awareness.memcacheds.itervalues():
74                conn = MemcachedClient(memcached.host, memcached.port)
75                server_str = "{0}:{1}".format(conn.host, conn.port)
76                self.conns[server_str] = conn
77            self.awareness = self.store.awareness
78        else:
79            print "<%s> error: unsupported store object %s" %\
80                  (self.__class__.__name__, store.__class__.__name__)
81            return False
82
83        return True
84
85    @synchronized("conn_lock")
86    def _refresh_conns(self):
87        """blocking call to refresh connections based on topology change"""
88        if not self.store:
89            print "<%s> failed to refresh connections, invalid store object"\
90                % self.__class__.__name__
91            return False
92
93        print "<%s> refreshing connections" % self.__class__.__name__
94
95        if self.store.__class__.__name__ == "StoreMembaseBinary":
96            old_keys = set(self.conns)
97            new_keys = set(self.store.awareness.memcacheds)
98
99            for del_server in old_keys.difference(new_keys):
100                print "<%s> _refresh_conns: delete server: %s" \
101                    % (self.__class__.__name__, del_server)
102                del self.conns[del_server]
103
104            for add_server in new_keys.difference(old_keys):
105                print "<%s> _refresh_conns: add server: %s" \
106                    % (self.__class__.__name__, add_server)
107                self._add_conn(add_server)
108
109            self.awareness = self.store.awareness
110
111        return True
112
113    @synchronized("conn_lock")
114    def _add_conn(self, server):
115        if not self.store:
116            print "<%s> failed to add conn, invalid store object"\
117                % self.__class__.__name__
118            return False
119
120        if self.store.__class__.__name__ == "StoreMembaseBinary":
121            print "<%s> _add_conn: %s"\
122                % (self.__class__.__name__, server)
123            host, port = server.split(":")
124            conn = MemcachedClient(host, int(port))
125            self.conns[server] = conn
126
127        return True
128
129    @synchronized("conn_lock")
130    def _reconnect(self, conn):
131        if not conn or\
132            conn.__class__.__name__ != "MemcachedClient":
133            print "<%s> failed to reconnect, invalid connection object"\
134                % self.__class__.__name__
135            return False
136
137        return conn.reconnect()
138
139    def _send(self):
140        self.obs_keys.clear()   # {server: [keys]}
141
142        observables = self.observable_filter(ObserveStatus.OBS_UNKNOWN)
143        with self._observables.mutex:
144            for obs in observables:
145                vbucketid = VbucketHelper.get_vbucket_id(obs.key, self.cfg.get("vbuckets", 0))
146                obs_key = ObserveRequestKey(obs.key, vbucketid)
147                if obs.persist_count > 0:
148                    persist_server = self._get_server_str(vbucketid)
149                    vals = self.obs_keys.get(persist_server, [])
150                    vals.append(obs_key)
151                    self.obs_keys[persist_server] = vals
152                    if not obs.persist_servers:
153                        obs.persist_servers.add(persist_server)
154                        self._observables.put(obs.key, obs)
155                if obs.repl_count > 0:
156                    repl_servers = self._get_server_str(vbucketid, repl=True)
157                    if len(repl_servers) < obs.repl_count:
158                        print "<%s> not enough number of replication servers to observe"\
159                            % self.__class__.__name__
160                        obs.status = ObserveStatus.OBS_ERROR # mark out this key
161                        self._observables.put(obs.key, obs)
162                        continue
163                    if not obs.repl_servers:
164                        obs.repl_servers.update(repl_servers)
165                        self._observables.put(obs.key, obs)
166                    for server in obs.repl_servers:
167                        vals = self.obs_keys.get(server, [])
168                        vals.append(obs_key)
169                        self.obs_keys[server] = vals
170
171        reqs = []
172        for server, keys in self.obs_keys.iteritems():
173            req = ObserveRequest(keys)
174            pkt = req.pack()
175            try:
176                self.conns[server].s.send(pkt)
177            except KeyError as e:
178                print "<%s> failed to send observe pkt : %s" % (self.__class__.__name__, e)
179                self._add_conn(server)
180                return None
181            except Exception as e:
182                print "<%s> failed to send observe pkt : %s" % (self.__class__.__name__, e)
183                self._refresh_conns()
184                return None
185            reqs.append(req)
186
187        print "reqs::"
188        print reqs
189        return reqs
190
191    def _recv(self):
192
193        responses = {}      # {server: [responses]}
194        for server in self.obs_keys.iterkeys():
195            hdr = ''
196            while len(hdr) < ObservePktFmt.OBS_RES_HDR_LEN:
197                try:
198                    hdr += self.conns[server].s.recv(ObservePktFmt.OBS_RES_HDR_LEN)
199                except KeyError as e:
200                    print "<%s> failed to recv observe pkt : %s" % (self.__class__.__name__, e)
201                    self._add_conn(server)
202                    return None
203                except Exception as e:
204                    print "<%s> failed to recv observe pkt: %s" % (self.__class__.__name__, e)
205                    self._refresh_conns()
206                    return None
207            res = ObserveResponse()
208
209            if not res.unpack_hdr(hdr):
210                if res.status == ERR_NOT_MY_VBUCKET:
211                    self._refresh_conns()
212                return None
213
214            body = ''
215            while len(body) < res.body_len:
216                body += self.conns[server].s.recv(res.body_len)
217            res.unpack_body(body)
218
219            # TODO: error check
220
221            self.save_latency_stats(res.persist_stat/1000)
222
223            print "res::<%s>" % server
224            print res
225            vals = responses.get(server, [])
226            vals.append(res)
227            responses[server] = vals
228
229        return responses
230
231    def _get_server_str(self, vbucketid, repl=False):
232        """retrieve server string {ip:port} based on vbucketid"""
233        if self.awareness:
234            if repl:
235                server = self.awareness.vBucketMapReplica[vbucketid]
236            else:
237                server = self.awareness.vBucketMap[vbucketid]
238            return server
239        elif len(self.conns) and not repl:
240            return self.conns.iterkeys().next()
241
242        return None
243
244    def block_for_persistence(self, key, cas, server="", timeout=0):
245        """
246        observe a key until it has been persisted
247        """
248        self.backoff = self.cfg.get('obs-backoff', BACKOFF)
249
250        while True:
251
252            status, new_cas = self.observe_single(key, server, timeout)
253
254            if status < 0:
255                return False
256
257            if new_cas != cas:
258                print "<%s> block_for_persistence: key: %s, "\
259                    "cas: %s has been modified"\
260                    % (self.__class__.__name__, key, cas)
261                return False
262
263            elif status == ObserveKeyState.OBS_PERSISITED:
264                return True
265            elif status == ObserveKeyState.OBS_FOUND:
266                sleep(self.backoff)
267                self.backoff = min(self.backoff * 2, self.max_backoff)
268                continue
269            elif status == ObserveKeyState.OBS_NOT_FOUND:
270                print "<%s> block_for_persistence: key: %s, cas: %s does not" \
271                    " exist any more" % (self.__class__.__name__, key, cas)
272                return False
273            else:
274                print "<%s> block_for_persistence: invalid key state: %x" \
275                    % (self.__class__.__name__, res_key.key_state)
276                return False
277
278        return False # unreachable
279
280    def block_for_replication(self, key, cas, num=1, timeout=0, persist=False):
281        """
282        observe a key until it has been replicated to @param num of servers
283
284        @param persist : block until item has been persisted to disk
285        """
286        if not isinstance(num, int) or num <= 0:
287            print "<%s> block_for_replication: invalid num %s" \
288                % (self.__class__.__name__, num)
289            return False
290
291        vbucketid = \
292            VbucketHelper.get_vbucket_id(key, self.cfg.get("vbuckets", 0))
293
294        repl_servers = self._get_server_str(vbucketid, repl=True)
295
296        if persist and not self.block_for_persistence(key, cas):
297            return False
298
299        self.backoff = self.cfg.get('obs-backoff', BACKOFF)
300
301        print "<%s> block_for_replication: repl_servers: %s,"\
302            " key: %s, cas: %s, vbucketid: %s" \
303            % (self.__class__.__name__, repl_servers, key, cas, vbucketid)
304
305        while len(repl_servers) >= num > 0:
306
307            for server in repl_servers:
308
309                if num == 0:
310                    break
311
312                status, new_cas = self.observe_single(key, server, timeout)
313
314                if status < 0:
315                    repl_servers.remove(server)
316                    continue
317
318                if new_cas and new_cas != cas:
319                    # Due to the current protocol limitations,
320                    # assume key is unique and new, skip this server
321                    repl_servers.remove(server)
322                    continue
323
324                if status == ObserveKeyState.OBS_PERSISITED:
325                    num -= 1
326                    repl_servers.remove(server)
327                    continue
328                elif status == ObserveKeyState.OBS_FOUND:
329                    if not persist:
330                        num -= 1
331                        repl_servers.remove(server)
332                        continue
333                elif status == ObserveKeyState.OBS_NOT_FOUND:
334                    pass
335
336                if len(repl_servers) == 1:
337                    sleep(self.backoff)
338                    self.backoff = min(self.backoff * 2, self.max_backoff)
339
340        if num > 0:
341            return False
342
343        return True
344
345    def observe_single(self, key, server="", timeout=0):
346        """
347        send an observe command and get the response back
348
349        parse the response afterwards
350
351        @return (status, cas)
352
353        @status -1 : network error
354        @status -2 : protocol error
355        @status ObserveKeyState
356        """
357        cas = ""
358        if not key:
359            print "<%s> observe_single: invalid key" % self.__class__.__name__
360            return -1
361
362        vbucketid = \
363            VbucketHelper.get_vbucket_id(key, self.cfg.get("vbuckets", 0))
364        if not server:
365            server = self._get_server_str(vbucketid)
366        req_key = ObserveRequestKey(key, vbucketid)
367
368        req = ObserveRequest([req_key])
369        pkt = req.pack()
370
371        try:
372            skt = self.conns[server].s
373        except KeyError:
374            print "<%s> observe_single: KeyError: %s" \
375                % (self.__class__.__name__, server)
376            self._add_conn(server)
377            return -1, cas
378
379        try:
380            SocketHelper.send_bytes(skt, pkt, timeout)
381        except IOError:
382            print "<%s> observe_single: IOError: " \
383                  "failed to send observe pkt : %s" \
384                  % (self.__class__.__name__, pkt)
385            self._reconnect(self.conns[server])
386            self._refresh_conns()
387            return -1, cas
388        except socket.timeout:
389            print "<%s> observe_single: timeout: " \
390                "failed to send observe pkt : %s" \
391                % (self.__class__.__name__, pkt)
392            return -1, cas
393        except Exception as e:
394            print "<%s> observe_single: failed to send observe pkt : %s" \
395                % (self.__class__.__name__, e)
396            return -1, cas
397
398        try:
399            hdr = SocketHelper.recv_bytes(
400                skt, ObservePktFmt.OBS_RES_HDR_LEN, timeout)
401            res = ObserveResponse()
402            if not res.unpack_hdr(hdr):
403                if res.status == ERR_NOT_MY_VBUCKET:
404                    self._refresh_conns()
405                return -1, cas
406            body = SocketHelper.recv_bytes(skt, res.body_len, timeout)
407            res.unpack_body(body)
408        except IOError:
409            print "<%s> observe_single: IOError: failed to recv observe pkt" \
410                % self.__class__.__name__
411            self._reconnect(self.conns[server])
412            self._refresh_conns()
413            return -1, cas
414        except socket.timeout:
415            print "<%s> observe_single: timeout: failed to recv observe pkt" \
416                % self.__class__.__name__
417            return -1, cas
418        except Exception as e:
419            print "<%s> observe_single: failed to recv observe pkt : %s" \
420                % (self.__class__.__name__, e)
421            return -1, cas
422
423        if not res:
424            print "<%s> observe_single: empty response" \
425                % self.__class__.__name__
426            return -1, cas
427
428        key_len = len(res.keys)
429        if key_len != 1:
430            # we are not supposed to receive responses for more than one key,
431            # otherwise, it's a server side protocol error
432            print "<%s> observe_single: invalid number of keys in response: %d"\
433                    % (self.s.__name__, key_len)
434            return -2, cas
435
436        res_key = res.keys[0]
437        cas = res_key.cas
438
439        if res_key.key != key:
440            print "<%s> observe_single: invalid key %s in response"\
441                % self.__class__.__name__
442            return -2, cas
443
444        return res_key.key_state, cas
445
446    def measure_client_latency(self):
447        observables = self.observable_filter(ObserveStatus.OBS_SUCCESS)
448        for obs in observables:
449            persist_dur = obs.persist_end_time - obs.start_time
450            repl_dur = obs.repl_end_time - obs.start_time
451            print "<%s> saving client latency, "\
452                  "key: %s, cas: %s, persist_dur: %f, repl_dur: %f"\
453                  % (self.__class__.__name__, obs.key, obs.cas,
454                     persist_dur, repl_dur)
455            if persist_dur > 0:
456                self.save_latency_stats(persist_dur, obs.start_time, False)
457            if repl_dur > 0:
458                self.save_latency_stats(persist_dur, obs.start_time,
459                                        server=False, repl=True)
460
461    def save_latency_stats(self, latency, time=0, server=True, repl=False):
462        if not latency:
463            return False    # TODO: simply skip 0
464
465        if server:
466            self.store.add_timing_sample("obs-persist-server", float(latency))
467        else:
468            if repl:
469                cmd = "obs-repl-client" # TODO: # of replicas
470            else:
471                cmd = "obs-persist-client"
472            self.store.add_timing_sample(cmd, float(latency))
473
474        if self.store.sc:
475            self.store.save_stats(time)
476
477        return True
478